@@ -15,9 +15,9 @@ translation:
1515 JAX as a NumPy Replacement : JAX 作为 NumPy 的替代品
1616 JAX as a NumPy Replacement::Similarities : 相似之处
1717 JAX as a NumPy Replacement::Differences : 差异
18- JAX as a NumPy Replacement::Differences::Precision : 精度
19- JAX as a NumPy Replacement::Differences::Immutability : 不可变性
20- JAX as a NumPy Replacement::Differences::A workaround : 变通方法
18+ JAX as a NumPy Replacement::Differences::Speed! : 精度
19+ JAX as a NumPy Replacement::Differences::Precision : 不可变性
20+ JAX as a NumPy Replacement::Differences::Immutability : 变通方法
2121 Functional Programming : 函数式编程
2222 Functional Programming::Pure functions : 纯函数
2323 Functional Programming::Examples : 示例
@@ -26,18 +26,12 @@ translation:
2626 Random numbers::Why explicit random state? : 为什么要显式随机状态?
2727 Random numbers::Why explicit random state?::NumPy's approach : NumPy 的方法
2828 Random numbers::Why explicit random state?::JAX's approach : JAX 的方法
29- JIT compilation : JIT 编译
30- JIT compilation::A simple example : 一个简单的示例
31- JIT compilation::A simple example::With NumPy : 使用 NumPy
32- JIT compilation::A simple example::With JAX : 使用 JAX
33- JIT compilation::A simple example::Changing array sizes : 更改数组大小
34- JIT compilation::Evaluating a more complicated function : 评估更复杂的函数
35- JIT compilation::Evaluating a more complicated function::With NumPy : 使用 NumPy
36- JIT compilation::Evaluating a more complicated function::With JAX : 使用 JAX
37- JIT compilation::Compiling the Whole Function : 编译整个函数
38- JIT compilation::Compiling non-pure functions : 编译非纯函数
39- JIT compilation::Summary : 总结
40- Gradients : 梯度
29+ JIT Compilation : JIT 编译
30+ JIT Compilation::With NumPy : 使用 NumPy
31+ JIT Compilation::With JAX : 使用 JAX
32+ JIT Compilation::Compiling the Whole Function : 编译整个函数
33+ JIT Compilation::How JIT compilation works : JIT 编译的工作原理
34+ JIT Compilation::Compiling non-pure functions : 编译非纯函数
4135 Exercises : 练习
4236---
4337
@@ -148,7 +142,6 @@ jnp.linalg.inv(B) # Inverse of identity is identity
148142jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
149143```
150144
151-
152145### 差异
153146
154147现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
249242
250243(尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。)
251244
252-
253245## 函数式编程
254246
255247来自 JAX 的文档:
279271* 不会改变全局状态
280272* 不会修改传递给函数的数据(不可变数据)
281273
282-
283-
284274### 示例
285275
286276以下是一个* 非纯* 函数的示例:
@@ -316,7 +306,6 @@ def add_tax_pure(prices, tax_rate):
316306
317307现在我们理解了什么是纯函数,让我们探索 JAX 处理随机数的方法如何维护这种纯粹性。
318308
319-
320309## 随机数
321310
322311与 NumPy 或 Matlab 中的随机数相比,JAX 中的随机数有很大不同。
@@ -327,7 +316,6 @@ def add_tax_pure(prices, tax_rate):
327316
328317此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。
329318
330-
331319### 随机数生成
332320
333321在 JAX 中,随机数生成器的状态被显式控制。
@@ -405,7 +393,6 @@ key = jax.random.PRNGKey(seed)
405393matrices = gen_random_matrices(key)
406394```
407395
408-
409396### 为什么要显式随机状态?
410397
411398为什么 JAX 需要这种相对冗长的随机数生成方法?
@@ -433,7 +420,6 @@ print(np.random.randn()) # Updates state of random number generator
433420* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出
434421* 它有副作用:它修改了全局随机数生成器状态
435422
436-
437423#### JAX 的方法
438424
439425如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。
@@ -475,126 +461,21 @@ JAX 的显式性带来了显著的好处:
475461
476462最后一点将在下一节中进行扩展。
477463
478-
479464## JIT 编译
480465
481466JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。
482467
483- ### 一个简单的示例
484-
485- 假设我们想在许多点上求余弦函数的值。
486-
487- ``` {code-cell}
488- n = 50_000_000
489- x = np.linspace(0, 10, n)
490- ```
491-
492- #### 使用 NumPy
493-
494- 让我们先用 NumPy 试试:
495-
496- ``` {code-cell}
497- with qe.Timer():
498- y = np.cos(x)
499- ```
500-
501- 再试一次。
502-
503- ``` {code-cell}
504- with qe.Timer():
505- y = np.cos(x)
506- ```
507-
508- 这里 NumPy 使用预先构建的二进制文件,该文件由精心编写的低级代码编译而成,用于对浮点数数组应用余弦函数。
509-
510- 这个二进制文件随 NumPy 一起发布。
511-
512- #### 使用 JAX
513-
514- 现在让我们用 JAX 试试。
515-
516- ``` {code-cell}
517- x = jnp.linspace(0, 10, n)
518- ```
519-
520- 让我们对相同的过程计时。
521-
522- ``` {code-cell}
523- with qe.Timer():
524- y = jnp.cos(x)
525- jax.block_until_ready(y);
526- ```
527-
528- ``` {note}
529- 这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。
530-
531- 这是必要的,因为 JAX 使用异步调度,允许 Python 解释器领先于数值计算。
532-
533- 对于非计时代码,您可以删除包含 `block_until_ready` 的那一行。
534- ```
535-
536-
537- 再次计时。
538-
539-
540- ``` {code-cell}
541- with qe.Timer():
542- y = jnp.cos(x)
543- jax.block_until_ready(y);
544- ```
545-
546- 在 GPU 上,这段代码的运行速度远快于其 NumPy 等价代码。
547-
548- 此外,通常第二次运行比第一次快,这是由于 JIT 编译的原因。
549-
550- 这是因为即使是像 ` jnp.cos ` 这样的内置函数也经过了 JIT 编译——第一次运行包含了编译时间。
551-
552- 为什么 JAX 要对像 ` jnp.cos ` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样提供预编译版本?
553-
554- 原因是 JIT 编译器希望针对正在使用的数组的* 大小* (以及数据类型)进行专门优化。
555-
556- 大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。
557-
558- 这就是为什么 JAX 要等到看到数组大小后再进行编译——这需要 JIT 编译方法,而不是提供预编译的二进制文件。
559-
560-
561- #### 更改数组大小
562-
563- 这里我们更改输入大小并观察运行时间。
564-
565- ``` {code-cell}
566- x = jnp.linspace(0, 10, n + 1)
567- ```
568-
569- ``` {code-cell}
570- with qe.Timer():
571- y = jnp.cos(x)
572- jax.block_until_ready(y);
573- ```
574-
575-
576- ``` {code-cell}
577- with qe.Timer():
578- y = jnp.cos(x)
579- jax.block_until_ready(y);
580- ```
581-
582- 通常,运行时间会先增加然后再减少(这在 GPU 上会更加明显)。
583-
584- 这是因为 JIT 编译器针对数组大小进行专门优化以利用并行化——因此当数组大小改变时会生成新的编译代码。
468+ 当我们在 {ref}` 上面 <jax_speed> ` 对一个大型数组应用 ` cos ` 时,我们看到了 JAX 的 JIT 编译器结合并行硬件的强大之处。
585469
586-
587- ### 评估更复杂的函数
588-
589- 让我们用一个更复杂的函数尝试同样的操作。
470+ 让我们用一个更复杂的函数尝试同样的操作:
590471
591472``` {code-cell}
592473def f(x):
593- y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2
474+ y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
594475 return y
595476```
596477
597- #### 使用 NumPy
478+ ### 使用 NumPy
598479
599480我们先用 NumPy 试试:
600481
@@ -605,12 +486,11 @@ x = np.linspace(0, 10, n)
605486
606487``` {code-cell}
607488with qe.Timer():
489+ # Time NumPy code
608490 y = f(x)
609491```
610492
611-
612-
613- #### 使用 JAX
493+ ### 使用 JAX
614494
615495现在让我们用 JAX 再试一次。
616496
@@ -620,34 +500,36 @@ with qe.Timer():
620500def f(x):
621501 y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
622502 return y
623- ```
624503
625- 现在让我们计时。
626504
627- ``` {code-cell}
628505x = jnp.linspace(0, 10, n)
629506```
630507
508+ 现在让我们计时。
509+
631510``` {code-cell}
632511with qe.Timer():
512+ # First call
633513 y = f(x)
514+ # Hold interpreter
634515 jax.block_until_ready(y);
635516```
636517
637518``` {code-cell}
638519with qe.Timer():
520+ # Second call
639521 y = f(x)
522+ # Hold interpreter
640523 jax.block_until_ready(y);
641524```
642525
643526结果与 ` cos ` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
644527
645- 此外,使用 JAX,我们还有另一个技巧:
646-
528+ 然而,使用 JAX,我们还有另一个技巧——我们可以对* 整个* 函数进行 JIT 编译,而不仅仅是单个操作。
647529
648530### 编译整个函数
649531
650- JAX 即时(JIT)编译器可以通过将线性代数运算融合到单个优化内核中来加速函数内部的执行 。
532+ JAX 即时(JIT)编译器可以通过将数组操作融合到单个优化内核中来加速函数内部的执行 。
651533
652534让我们用函数 ` f ` 来试试这个:
653535
@@ -657,21 +539,24 @@ f_jax = jax.jit(f)
657539
658540``` {code-cell}
659541with qe.Timer():
542+ # First run
660543 y = f_jax(x)
544+ # Hold interpreter
661545 jax.block_until_ready(y);
662546```
663547
664548``` {code-cell}
665549with qe.Timer():
550+ # Second run
666551 y = f_jax(x)
552+ # Hold interpreter
667553 jax.block_until_ready(y);
668554```
669555
670556运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。
671557
672558例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
673559
674-
675560顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
676561
677562``` {code-cell} ipython3
@@ -680,6 +565,14 @@ def f(x):
680565 pass # put function body here
681566```
682567
568+ ### JIT 编译的工作原理
569+
570+ 当我们对一个函数应用 ` jax.jit ` 时,JAX 会对其进行* 追踪* :它不会立即执行操作,而是将操作序列记录为计算图,并将该图交给 [ XLA] ( https://openxla.org/xla ) 编译器。
571+
572+ XLA 随后将这些操作融合并优化为一个针对可用硬件(CPU、GPU 或 TPU)定制的单个编译内核。
573+
574+ 对 JIT 编译函数的第一次调用会产生编译开销,但后续具有相同输入形状和类型的调用将复用缓存的编译代码并以全速运行。
575+
683576### 编译非纯函数
684577
685578现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。
@@ -728,65 +621,6 @@ f(x)
728621
729622这个故事的寓意:使用 JAX 时请编写纯函数!
730623
731-
732- ### 总结
733-
734- 现在我们可以理解为什么开发者和编译器都受益于纯函数。
735-
736- 我们喜欢纯函数,因为它们:
737-
738- * 有助于测试:每个函数可以独立运行
739- * 促进确定性行为,从而实现可复现性
740- * 防止由于修改共享状态而产生的错误
741-
742- 编译器喜欢纯函数和函数式编程,因为:
743-
744- * 数据依赖关系是显式的,有助于优化复杂计算
745- * 纯函数更容易进行微分(自动微分)
746- * 纯函数更容易并行化和优化(不依赖于共享可变状态)
747-
748-
749- ## 梯度
750-
751- JAX 可以使用自动微分来计算梯度。
752-
753- 这对于优化和求解非线性系统非常有用。
754-
755- 我们将在本讲座系列后面看到重要的应用。
756-
757- 现在,这里有一个非常简单的说明,涉及函数:
758-
759- ``` {code-cell} ipython3
760- def f(x):
761- return (x**2) / 2
762- ```
763-
764- 让我们求导数:
765-
766- ``` {code-cell} ipython3
767- f_prime = jax.grad(f)
768- ```
769-
770- ``` {code-cell} ipython3
771- f_prime(10.0)
772- ```
773-
774- 让我们绘制函数和导数,注意 $f'(x) = x$。
775-
776- ``` {code-cell} ipython3
777- import matplotlib.pyplot as plt
778-
779- fig, ax = plt.subplots()
780- x_grid = jnp.linspace(-4, 4, 200)
781- ax.plot(x_grid, f(x_grid), label="$f$")
782- ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
783- ax.legend(loc='upper center')
784- plt.show()
785- ```
786-
787- 我们将进一步探索 JAX 自动微分的内容推迟到 {doc}` jax:autodiff ` 。
788-
789-
790624## 练习
791625
792626
@@ -871,4 +705,4 @@ with qe.Timer():
871705```
872706
873707``` {solution-end}
874- ```
708+ ```
0 commit comments