@@ -43,13 +43,16 @@ translation:
4343
4444本讲座简要介绍 [ Google JAX] ( https://github.com/jax-ml/jax ) 。
4545
46+ ``` {include} _admonition/gpu.md
47+ ```
48+
4649JAX 是一个高性能科学计算库,提供以下功能:
4750
4851* 类似 [ NumPy] ( https://en.wikipedia.org/wiki/NumPy ) 的接口,可以在 CPU 和 GPU 上自动并行化,
4952* 一个即时编译器,用于加速大量数值运算,以及
5053* [ 自动微分] ( https://en.wikipedia.org/wiki/Automatic_differentiation ) 。
5154
52- JAX 也在日益维护和提供[ 更多专业化的科学计算例程] ( https://docs.jax.dev/en/latest/jax.scipy.html ) ,例如那些最初在 [ SciPy] ( https://en.wikipedia.org/wiki/SciPy ) 中找到的例程。
55+ JAX 也在日益维护和提供 [ 更多专业化的科学计算例程] ( https://docs.jax.dev/en/latest/jax.scipy.html ) ,例如那些最初在 [ SciPy] ( https://en.wikipedia.org/wiki/SciPy ) 中找到的例程。
5356
5457除了 Anaconda 中已有的内容外,本讲座还需要以下库:
5558
@@ -59,36 +62,28 @@ JAX 也在日益维护和提供[更多专业化的科学计算例程](https://do
5962!pip install jax quantecon
6063```
6164
62- ``` {include} _admonition/gpu.md
63- ```
64-
65- ## JAX 作为 NumPy 的替代品
66-
67- JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。
68-
69- 这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。
70-
71- 让我们来看看 JAX 和 NumPy 之间的异同。
72-
73- ### 相似之处
74-
7565我们将使用以下导入:
7666
7767``` {code-cell} ipython3
7868import jax
7969import jax.numpy as jnp
8070import matplotlib.pyplot as plt
81- import matplotlib as mpl
82- import matplotlib.font_manager
83- FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf"
84- mpl.font_manager.fontManager.addfont(FONTPATH)
85- mpl.rcParams['font.family'] = ['Source Han Serif SC']
8671import numpy as np
8772import quantecon as qe
8873```
8974
9075注意我们导入了 ` jax.numpy as jnp ` ,它提供了类似 NumPy 的接口。
9176
77+ ## JAX 作为 NumPy 的替代品
78+
79+ JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。
80+
81+ 这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。
82+
83+ 让我们来看看 JAX 和 NumPy 之间的异同。
84+
85+ ### 相似之处
86+
9287以下是使用 ` jnp ` 进行的一些标准数组操作:
9388
9489``` {code-cell} ipython3
@@ -103,10 +98,6 @@ print(a)
10398print(jnp.sum(a))
10499```
105100
106- ``` {code-cell} ipython3
107- print(jnp.mean(a))
108- ```
109-
110101``` {code-cell} ipython3
111102print(jnp.dot(a, a))
112103```
121112type(a)
122113```
123114
124- 即使是数组上的标量值映射也会返回 JAX 数组。
115+ 即使是数组上的标量值映射也会返回 JAX 数组,而不是标量!
125116
126117``` {code-cell} ipython3
127118jnp.sum(a)
128119```
129120
130- 对高维数组的操作也与 NumPy 类似:
121+ ### 差异
131122
132- ``` {code-cell} ipython3
133- A = jnp.ones((2, 2))
134- B = jnp.identity(2)
135- A @ B
123+ 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
124+
125+ (jax_speed)=
126+ #### 速度!
127+
128+ 假设我们想在许多点上计算余弦函数。
129+
130+ ``` {code-cell}
131+ n = 50_000_000
132+ x = np.linspace(0, 10, n)
136133```
137134
138- JAX 的数组接口也提供了 ` linalg ` 子包:
135+ ##### 使用 NumPy
139136
140- ``` {code-cell} ipython3
141- jnp.linalg.inv(B) # Inverse of identity is identity
137+ 让我们先用 NumPy 试试:
138+
139+ ``` {code-cell}
140+ with qe.Timer():
141+ # First NumPy timing
142+ y = np.cos(x)
142143```
143144
144- ``` {code-cell} ipython3
145- eigvals, eigvecs = jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
146- eigvals
145+ 再来一次。
146+
147+ ``` {code-cell}
148+ with qe.Timer():
149+ # Second NumPy timing
150+ y = np.cos(x)
147151```
148152
149- ### 差异
153+ 这里
150154
151- 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
155+ * NumPy 使用预编译的二进制文件对浮点数数组应用余弦函数
156+ * 该二进制文件在本地机器的 CPU 上运行
157+
158+ ##### 使用 JAX
159+
160+ 现在让我们用 JAX 试试。
161+
162+ ``` {code-cell}
163+ x = jnp.linspace(0, 10, n)
164+ ```
165+
166+ 让我们对相同的过程计时。
167+
168+ ``` {code-cell}
169+ with qe.Timer():
170+ # First run
171+ y = jnp.cos(x)
172+ # Hold the interpreter until the array operation finishes
173+ jax.block_until_ready(y);
174+ ```
175+
176+ ``` {note}
177+ 这里,为了测量实际速度,我们使用 `block_until_ready` 方法来阻塞解释器,直到计算结果返回。
178+
179+ 这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。
180+
181+ 对于非计时代码,可以删除包含 `block_until_ready` 的那一行。
182+ ```
183+
184+ 再来计时一次。
185+
186+ ``` {code-cell}
187+ with qe.Timer():
188+ # Second run
189+ y = jnp.cos(x)
190+ # Hold interpreter
191+ jax.block_until_ready(y);
192+ ```
193+
194+ 在 GPU 上,此代码的运行速度远快于其 NumPy 等效代码。
195+
196+ 此外,通常第二次运行比第一次更快,这是由于 JIT 编译的缘故。
197+
198+ 这是因为即使是像 ` jnp.cos ` 这样的内置函数也是经过 JIT 编译的——第一次运行包含了编译时间。
199+
200+ 为什么 JAX 要对像 ` jnp.cos ` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本?
201+
202+ 原因是 JIT 编译器希望针对所使用数组的* 大小* (以及数据类型)进行专门优化。
203+
204+ 大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。
205+
206+ 我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。
207+
208+ ``` {code-cell}
209+ x = jnp.linspace(0, 10, n + 1)
210+ ```
211+
212+ ``` {code-cell}
213+ with qe.Timer():
214+ # First run
215+ y = jnp.cos(x)
216+ # Hold interpreter
217+ jax.block_until_ready(y);
218+ ```
219+
220+ ``` {code-cell}
221+ with qe.Timer():
222+ # Second run
223+ y = jnp.cos(x)
224+ # Hold interpreter
225+ jax.block_until_ready(y);
226+ ```
227+
228+ 运行时间先增加后减少(这在 GPU 上会更明显)。
229+
230+ 这与上面的讨论一致——更改数组大小后的第一次运行显示了编译开销。
231+
232+ 关于 JIT 编译的进一步讨论见下文。
152233
153- (jax_speed)=
154234#### 速度!
155235
156236假设我们想在许多点上求余弦函数的值。
@@ -301,30 +381,13 @@ try:
301381 a[0] = 1
302382except Exception as e:
303383 print(e)
304-
305384```
306385
307- 与不可变性一致,JAX 不支持原地操作:
308-
309- ``` {code-cell} ipython3
310- a = np.array((2, 1))
311- a.sort() # Unlike NumPy, does not mutate a
312- a
313- ```
314-
315- ``` {code-cell} ipython3
316- a = jnp.array((2, 1))
317- a_new = a.sort() # Instead, the sort method returns a new sorted array
318- a, a_new
319- ```
320-
321- JAX 的设计者选择将数组设为不可变的,因为 JAX 使用 [ 函数式编程] ( https://en.wikipedia.org/wiki/Functional_programming ) 风格。
322-
323- 这个设计选择有重要的含义,我们接下来将对此进行探讨!
386+ JAX 的设计者选择将数组设为不可变的,因为 JAX 使用函数式编程风格,我们将在下面讨论这一点。
324387
325388#### 变通方法
326389
327- 我们注意到 JAX 确实提供了一种使用 [ ` at ` 方法] ( https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html ) 进行原地数组修改的版本 。
390+ 我们注意到 JAX 确实提供了一种替代原地数组修改的方式,使用 [ ` at ` 方法] ( https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html ) 。
328391
329392``` {code-cell} ipython3
330393a = jnp.linspace(0, 1, 3)
@@ -408,15 +471,29 @@ def add_tax_pure(prices, tax_rate):
408471
409472这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。
410473
411- ### 为什么使用函数式编程?
474+ ### 为什么要函数式编程?
475+
476+ 在 QuantEcon,我们热爱纯函数,因为它们:
477+
478+ * 有助于测试:每个函数可以独立运行
479+ * 促进确定性行为,从而提高可重复性
480+ * 防止由于修改共享状态而产生的错误
481+
482+ JAX 编译器热爱纯函数和函数式编程,因为:
483+
484+ * 数据依赖关系是显式的,有助于优化复杂计算
485+ * 纯函数更易于微分(自动微分)
486+ * 纯函数更易于并行化和优化(不依赖于共享的可变状态)
487+
488+ 另一种理解方式如下:
412489
413- JAX 将函数表示为计算图,然后对其进行编译或变换(例如,求导 )。
490+ JAX 将函数表示为计算图,然后对其进行编译或变换(例如,微分 )。
414491
415492这些计算图描述了给定的一组输入如何被转换为输出。
416493
417- 它们在构造上就是纯粹的 。
494+ JAX 的计算图在构造上是纯粹的 。
418495
419- JAX 使用函数式编程风格,使得用户构建的函数能够直接映射到 JAX 所支持的图论表示中。
496+ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射到 JAX 所支持的图论表示中。
420497
421498## 随机数
422499
@@ -649,7 +726,7 @@ JAX 的显式性带来了显著的好处:
649726
650727JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。
651728
652- 当我们在 {ref}` 上面 <jax_speed>` 对一个大型数组应用 ` cos ` 时,我们看到了 JAX 的 JIT 编译器结合并行硬件的强大之处。
729+ 我们在 {ref}` 上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 ` cos ` 函数 。
653730
654731让我们用一个更复杂的函数尝试同样的操作:
655732
@@ -709,12 +786,11 @@ with qe.Timer():
709786
710787结果与 ` cos ` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
711788
712- 然而,使用 JAX,我们还有另一个技巧——我们可以对* 整个* 函数进行 JIT 编译,而不仅仅是单个操作。
713-
789+ 然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
714790
715791### 编译整个函数
716792
717- JAX 即时(JIT)编译器可以通过将数组操作融合到单个优化内核中来加速函数内部的执行 。
793+ JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行 。
718794
719795让我们用函数 ` f ` 来试试这个:
720796
@@ -742,7 +818,6 @@ with qe.Timer():
742818
743819例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
744820
745-
746821顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
747822
748823``` {code-cell} ipython3
@@ -807,21 +882,64 @@ f(x)
807882
808883这个故事的寓意:使用 JAX 时请编写纯函数!
809884
810- ### 总结
885+ ## 使用 ` vmap ` 进行向量化
811886
812- 现在我们可以理解为什么开发者和编译器都受益于纯函数 。
887+ JAX 的另一个强大变换是 ` jax.vmap ` ,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行 。
813888
814- 我们喜欢纯函数,因为它们:
889+ 这避免了手动编写向量化代码或使用显式循环的需要。
815890
816- * 有助于测试:每个函数可以独立运行
817- * 促进确定性行为,从而实现可复现性
818- * 防止由于修改共享状态而产生的错误
891+ ### 一个简单的示例
819892
820- 编译器喜欢纯函数和函数式编程,因为:
893+ 假设我们有一个函数,用于计算一组数字的均值与中位数之差。
821894
822- * 数据依赖关系是显式的,有助于优化复杂计算
823- * 纯函数更容易进行微分(自动微分)
824- * 纯函数更容易并行化和优化(不依赖于共享可变状态)
895+ ``` {code-cell} ipython3
896+ def mm_diff(x):
897+ return jnp.mean(x) - jnp.median(x)
898+ ```
899+
900+ 我们可以将其应用于单个向量:
901+
902+ ``` {code-cell} ipython3
903+ x = jnp.array([1.0, 2.0, 5.0])
904+ mm_diff(x)
905+ ```
906+
907+ 现在假设我们有一个矩阵,想要对每一行计算这些统计量。
908+
909+ 不使用 ` vmap ` 时,我们需要显式循环:
910+
911+ ``` {code-cell} ipython3
912+ X = jnp.array([[1.0, 2.0, 5.0],
913+ [4.0, 5.0, 6.0],
914+ [1.0, 8.0, 9.0]])
915+
916+ for row in X:
917+ print(mm_diff(row))
918+ ```
919+
920+ 然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。
921+
922+ 使用 ` vmap ` 可以将计算保留在加速器上,并与其他 JAX 变换(如 ` jit ` 和 ` grad ` )组合使用:
923+
924+ ``` {code-cell} ipython3
925+ batch_mm_diff = jax.vmap(mm_diff)
926+ batch_mm_diff(X)
927+ ```
928+
929+ 函数 ` mm_diff ` 是针对单个数组编写的,而 ` vmap ` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。
930+
931+ ### 组合变换
932+
933+ JAX 的优势之一在于各变换可以自然地组合使用。
934+
935+ 例如,我们可以对向量化函数进行 JIT 编译:
936+
937+ ``` {code-cell} ipython3
938+ fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
939+ fast_batch_mm_diff(X)
940+ ```
941+
942+ ` jit ` 、` vmap ` 以及(我们接下来将看到的)` grad ` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
825943
826944## 练习
827945
0 commit comments