@@ -17,16 +17,15 @@ translation:
1717 Vectorized operations::Parallelized Numba : 并行化的 Numba
1818 Vectorized operations::Vectorized code with JAX : 使用 JAX 的向量化代码
1919 Vectorized operations::JAX plus vmap : JAX 加 vmap
20- Vectorized operations::JAX plus vmap::Version 1 : 版本 1
21- Vectorized operations::vmap version 2 : vmap 版本 2
2220 Vectorized operations::Summary : 总结
2321 Sequential operations : 顺序运算
2422 Sequential operations::Numba Version : Numba 版本
2523 Sequential operations::JAX Version : JAX 版本
2624 Sequential operations::Summary : 总结
25+ Overall recommendations : 总体建议
2726---
2827
29- (parallel )=
28+ (numpy_numba_jax )=
3029``` {raw} jupyter
3130<div id="qe-notebook-header" align="right" style="text-align:right;">
3231 <a href="https://quantecon.org/" title="quantecon.org">
@@ -144,14 +143,13 @@ for x in grid:
144143 m = z
145144```
146145
147-
148146### NumPy 向量化
149147
150148如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。
151149
152150这里我们使用 ` np.meshgrid ` 来创建二维输入网格 ` x ` 和 ` y ` ,使得 ` f(x, y) ` 能生成乘积网格上的所有计算结果。
153151
154- (这一策略可以追溯到 Matlab 。)
152+ (这一策略可以追溯到 MATLAB 。)
155153
156154``` {code-cell} ipython3
157155grid = np.linspace(-3, 3, 3_000)
@@ -169,7 +167,6 @@ print(f"NumPy result: {z_max_numpy:.6f}")
169167
170168(并行化效率不高,因为二进制文件在看到数组 ` x ` 和 ` y ` 的大小之前就已经被编译了。)
171169
172-
173170### 与 Numba 的比较
174171
175172现在让我们看看能否使用简单循环的 Numba 获得更好的性能。
@@ -208,7 +205,6 @@ with qe.Timer(precision=8):
208205
209206另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。
210207
211-
212208### 并行化的 Numba
213209
214210现在让我们使用 ` prange ` 尝试 Numba 的并行化:
@@ -231,24 +227,44 @@ def compute_max_numba_parallel(grid):
231227
232228```
233229
234- 这通常会返回不正确的结果 :
230+ 这将返回 ` -inf ` ——即 ` m ` 的初始值,仿佛它从未被更新过 :
235231
236232``` {code-cell} ipython3
237233z_max_parallel_incorrect = compute_max_numba_parallel(grid)
238234print(f"Numba result: {z_max_parallel_incorrect} 😱")
239235```
240236
241- 原因是变量 ` m ` 被多个线程共享,但没有得到正确控制。
237+ 要理解原因,请回忆 ` prange ` 会将外层循环拆分到各个线程中。
238+
239+ 每个线程都会得到自己的 ` m ` 私有副本,初始化为 ` -np.inf ` ,并在其负责的迭代块中正确地更新它。
242240
243- 当多个线程同时尝试读写 ` m ` 时,它们会相互干扰 。
241+ 但在循环结束时,Numba 需要将各线程的 ` m ` 副本合并为一个单一的值——即 ** 归约 ** 操作 。
244242
245- 线程读取了 ` m ` 的过时值,或者相互覆盖了更新——或者 ` m ` 始终保持其初始值而从未被更新 。
243+ 对于它能识别的模式,例如 ` m += z ` (求和)或 ` m = max(m, z) ` (求最大值),Numba 知道合并算子 。
246244
247- 这里有一个更仔细编写的版本。
245+ 但它无法将 ` if z > m: m = z ` 识别为最大值归约,因此各线程的结果永远不会被合并,` m ` 始终保持其初始值。
246+
247+ 最简单的修复方法是将条件判断替换为 Numba 能识别的 ` max ` :
248248
249249``` {code-cell} ipython3
250250@numba.jit(parallel=True)
251251def compute_max_numba_parallel(grid):
252+ n = len(grid)
253+ m = -np.inf
254+ for i in numba.prange(n):
255+ for j in range(n):
256+ x = grid[i]
257+ y = grid[j]
258+ z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
259+ m = max(m, z)
260+ return m
261+ ```
262+
263+ 另一种方法是使循环体在不同 ` i ` 之间完全独立,并自行处理归约:
264+
265+ ``` {code-cell} ipython3
266+ @numba.jit(parallel=True)
267+ def compute_max_numba_parallel_v2(grid):
252268 n = len(grid)
253269 row_maxes = np.empty(n)
254270 for i in numba.prange(n):
@@ -263,9 +279,7 @@ def compute_max_numba_parallel(grid):
263279 return np.max(row_maxes)
264280```
265281
266- 现在 ` for i in numba.prange(n) ` 所作用的代码块在不同的 ` i ` 之间是独立的。
267-
268- 每个线程写入数组 ` row_maxes ` 的不同元素,并行化是安全的。
282+ 在这里,每个线程写入 ` row_maxes ` 的不同元素,因此我们通过 ` np.max ` 自行处理归约。
269283
270284``` {code-cell} ipython3
271285z_max_parallel = compute_max_numba_parallel(grid)
@@ -283,7 +297,6 @@ with qe.Timer(precision=8):
283297
284298对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。
285299
286-
287300### 使用 JAX 的向量化代码
288301
289302表面上,JAX 中的向量化代码与 NumPy 代码类似。
@@ -304,7 +317,7 @@ def f(x, y):
304317
305318``` {code-cell} ipython3
306319grid = jnp.linspace(-3, 3, 3_000)
307- x_mesh, y_mesh = np .meshgrid(grid, grid)
320+ x_mesh, y_mesh = jnp .meshgrid(grid, grid)
308321
309322with qe.Timer(precision=8):
310323 z_max = jnp.max(f(x_mesh, y_mesh))
@@ -321,14 +334,13 @@ with qe.Timer(precision=8):
321334 z_max.block_until_ready()
322335```
323336
324- 编译完成后,由于 GPU 加速, JAX 明显快于 NumPy。
337+ 编译完成后,JAX 明显快于 NumPy,在 GPU 上尤为如此 。
325338
326339编译开销是一次性成本,当函数被反复调用时,这种开销是值得的。
327340
328-
329341### JAX 加 vmap
330342
331- NumPy 代码和 JAX 代码都存在一个问题:
343+ NumPy 代码和上述 JAX 代码都存在一个问题:
332344
333345虽然扁平数组占用内存较少
334346
@@ -346,9 +358,9 @@ x_mesh.nbytes + y_mesh.nbytes
346358
347359幸运的是,JAX 提供了一种使用 [ jax.vmap] ( https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html ) 的不同方法。
348360
349- #### 版本 1
361+ ` vmap ` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。
350362
351- 以下是我们应用 ` vmap ` 的一种方式 。
363+ 以下是我们将其应用于当前问题的方式 。
352364
353365``` {code-cell} ipython3
354366# 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y)
@@ -375,32 +387,19 @@ with qe.Timer(precision=8):
375387 z_max.block_until_ready()
376388```
377389
378- 通过避免使用大型输入数组 ` x_mesh ` 和 ` y_mesh ` ,这个 ` vmap ` 版本使用的内存少得多。
379-
380- 在 CPU 上运行时,其运行时间与网格版本相似。
381-
382- 在 GPU 上运行时,通常速度要快得多。
383-
384- 实际上,使用 ` vmap ` 还有另一个优势:它允许我们将向量化分阶段进行。
385-
386- 这往往会产生比传统向量化代码更易于理解的代码。
390+ 通过避免使用大型输入数组 ` x_mesh ` 和 ` y_mesh ` ,这个 ` vmap ` 版本使用的内存少得多,运行时间也相近。
387391
388- 当我们处理更大的问题时,将进一步探讨这些想法 。
392+ 但我们仍然留有一些速度提升的空间未被利用 。
389393
394+ 上面的代码计算了完整的二维数组 ` f(x,y) ` ,然后再取最大值。
390395
391- ### vmap 版本 2
396+ 此外, ` jnp.max ` 调用位于 JIT 编译函数 ` f ` 之外,因此编译器无法将这些操作融合为单个内核。
392397
393- 我们可以使用 vmap 进一步提高内存效率。
394-
395- 在前一个版本中,虽然我们避免了大型输入数组,但在计算最大值之前仍然会创建大型输出数组 ` f(x,y) ` 。
396-
397- 让我们尝试一种略有不同的方法,将求最大值操作移到内部。
398-
399- 由于这一改变,我们永远不会计算二维数组 ` f(x,y) ` 。
398+ 我们可以通过将最大值操作移到内部并将所有内容包装在一个 ` @jax.jit ` 中来解决这两个问题:
400399
401400``` {code-cell} ipython3
402401@jax.jit
403- def compute_max_vmap_v2 (grid):
402+ def compute_max_vmap (grid):
404403 # 构建一个沿每行取最大值的函数
405404 f_vec_x_max = lambda y: jnp.max(f(grid, y))
406405 # 向量化该函数,以便我们可以同时对所有行调用
@@ -416,25 +415,26 @@ def compute_max_vmap_v2(grid):
416415
417416我们将此函数应用于所有行,然后取各行最大值中的最大值。
418417
418+ 由于将最大值操作移到内部,我们永远不会构建完整的二维数组 ` f(x,y) ` ,从而节省了更多内存。
419+
420+ 并且由于所有内容都在单个 ` @jax.jit ` 下,编译器可以将所有操作融合为一个优化的内核。
421+
419422让我们试试。
420423
421424``` {code-cell} ipython3
422425with qe.Timer(precision=8):
423- z_max = compute_max_vmap_v2 (grid).block_until_ready()
426+ z_max = compute_max_vmap (grid).block_until_ready()
424427
425- print(f"JAX vmap v1 result: {z_max:.6f}")
428+ print(f"JAX vmap result: {z_max:.6f}")
426429```
427430
428431让我们再次运行以消除编译时间:
429432
430433``` {code-cell} ipython3
431434with qe.Timer(precision=8):
432- z_max = compute_max_vmap_v2 (grid).block_until_ready()
435+ z_max = compute_max_vmap (grid).block_until_ready()
433436```
434437
435- 如果您像我们一样在 GPU 上运行,应该能看到又一个不小的速度提升。
436-
437-
438438### 总结
439439
440440在我们看来,JAX 是向量化运算的赢家。
@@ -449,15 +449,13 @@ with qe.Timer(precision=8):
449449
450450对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
451451
452-
453452## 顺序运算
454453
455454某些运算本质上是顺序的——因此难以或不可能向量化。
456455
457456在这种情况下,NumPy 是一个较差的选择,我们只剩下 Numba 或 JAX 可以选择。
458457
459- 为了比较这两种选择,我们将重新回顾在{doc}` Numba 讲座 <numba> ` 中看到的迭代二次映射问题。
460-
458+ 为了比较这两种选择,我们将重新回顾在 {doc}` Numba 讲座 <numba> ` 中看到的迭代二次映射问题。
461459
462460### Numba 版本
463461
@@ -545,8 +543,7 @@ with qe.Timer(precision=8):
545543
546544JAX 对于这种顺序运算也相当高效。
547545
548- JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。
549-
546+ JAX 和 Numba 在编译后都能提供出色的性能。
550547
551548### 总结
552549
@@ -560,4 +557,31 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
560557
561558此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。
562559
563- 对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
560+ 对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
561+
562+ ## 总体建议
563+
564+ 让我们退一步,总结一下各方的权衡取舍。
565+
566+ 对于** 向量化操作** ,JAX 是最强的选择。
567+
568+ 得益于 JIT 编译和在 CPU 与 GPU 上的高效并行化,它在速度上与 NumPy 持平或超越 NumPy。
569+
570+ ` vmap ` 变换降低了内存使用量,并且通常比传统的基于网格的向量化产生更清晰的代码。
571+
572+ 此外,JAX 函数支持自动微分,我们将在 {doc}` autodiff ` 中进一步探讨。
573+
574+ 对于** 顺序操作** ,Numba 具有明显优势。
575+
576+ 代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色。
577+
578+ JAX 可以通过 ` lax.scan ` 处理顺序问题,但语法不够直观。
579+
580+ ``` {note}
581+ `lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。
582+ 如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
583+ ```
584+
585+ 在实践中,许多问题涉及两种模式的混合。
586+
587+ 一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba。
0 commit comments