@@ -24,6 +24,7 @@ translation:
2424 Sequential operations::Numba Version : Numba 版本
2525 Sequential operations::JAX Version : JAX 版本
2626 Sequential operations::Summary : 总结
27+ Overall recommendations : 总体建议
2728---
2829
2930(parallel)=
@@ -69,7 +70,10 @@ tags: [hide-output]
6970
7071``` {code-cell} ipython3
7172import random
73+ from functools import partial
74+
7275import numpy as np
76+ import numba
7377import quantecon as qe
7478import matplotlib.pyplot as plt
7579import matplotlib as mpl # i18n
@@ -80,6 +84,7 @@ from mpl_toolkits.mplot3d.axes3d import Axes3D
8084from matplotlib import cm
8185import jax
8286import jax.numpy as jnp
87+ from jax import lax
8388```
8489
8590## 向量化运算
@@ -117,7 +122,7 @@ ax.plot_surface(x,
117122 y,
118123 f(x, y),
119124 rstride=2, cstride=2,
120- cmap=cm.jet ,
125+ cmap=cm.viridis ,
121126 alpha=0.7,
122127 linewidth=0.25)
123128ax.set_zlim(-0.5, 1.0)
@@ -143,7 +148,6 @@ for x in grid:
143148 m = z
144149```
145150
146-
147151### NumPy 向量化
148152
149153如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。
@@ -168,14 +172,11 @@ print(f"NumPy result: {z_max_numpy:.6f}")
168172
169173(并行化效率不高,因为二进制文件在看到数组 ` x ` 和 ` y ` 的大小之前就已经被编译了。)
170174
171-
172175### 与 Numba 的比较
173176
174177现在让我们看看能否使用简单循环的 Numba 获得更好的性能。
175178
176179``` {code-cell} ipython3
177- import numba
178-
179180@numba.jit
180181def compute_max_numba(grid):
181182 m = -np.inf
@@ -189,9 +190,9 @@ def compute_max_numba(grid):
189190grid = np.linspace(-3, 3, 3_000)
190191
191192with qe.Timer(precision=8):
192- z_max_numpy = compute_max_numba(grid)
193+ z_max_numba = compute_max_numba(grid)
193194
194- print(f"Numba result: {z_max_numpy :.6f}")
195+ print(f"Numba result: {z_max_numba :.6f}")
195196```
196197
197198让我们再次运行以消除编译时间。
@@ -207,7 +208,6 @@ with qe.Timer(precision=8):
207208
208209另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。
209210
210-
211211### 并行化的 Numba
212212
213213现在让我们使用 ` prange ` 尝试 Numba 的并行化:
@@ -282,7 +282,6 @@ with qe.Timer(precision=8):
282282
283283对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。
284284
285-
286285### 使用 JAX 的向量化代码
287286
288287表面上,JAX 中的向量化代码与 NumPy 代码类似。
@@ -303,7 +302,7 @@ def f(x, y):
303302
304303``` {code-cell} ipython3
305304grid = jnp.linspace(-3, 3, 3_000)
306- x_mesh, y_mesh = np .meshgrid(grid, grid)
305+ x_mesh, y_mesh = jnp .meshgrid(grid, grid)
307306
308307with qe.Timer(precision=8):
309308 z_max = jnp.max(f(x_mesh, y_mesh))
@@ -320,11 +319,10 @@ with qe.Timer(precision=8):
320319 z_max.block_until_ready()
321320```
322321
323- 编译完成后,由于 GPU 加速, JAX 明显快于 NumPy。
322+ 编译完成后,JAX 明显快于 NumPy,尤其是在 GPU 上 。
324323
325324编译开销是一次性成本,当函数被反复调用时,这种开销是值得的。
326325
327-
328326### JAX 加 vmap
329327
330328NumPy 代码和 JAX 代码都存在一个问题:
@@ -386,7 +384,6 @@ with qe.Timer(precision=8):
386384
387385当我们处理更大的问题时,将进一步探讨这些想法。
388386
389-
390387### vmap 版本 2
391388
392389我们可以使用 vmap 进一步提高内存效率。
@@ -421,7 +418,7 @@ def compute_max_vmap_v2(grid):
421418with qe.Timer(precision=8):
422419 z_max = compute_max_vmap_v2(grid).block_until_ready()
423420
424- print(f"JAX vmap v1 result: {z_max:.6f}")
421+ print(f"JAX vmap v2 result: {z_max:.6f}")
425422```
426423
427424让我们再次运行以消除编译时间:
@@ -433,7 +430,6 @@ with qe.Timer(precision=8):
433430
434431如果您像我们一样在 GPU 上运行,应该能看到又一个不小的速度提升。
435432
436-
437433### 总结
438434
439435在我们看来,JAX 是向量化运算的赢家。
@@ -448,15 +444,13 @@ with qe.Timer(precision=8):
448444
449445对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
450446
451-
452447## 顺序运算
453448
454449某些运算本质上是顺序的——因此难以或不可能向量化。
455450
456451在这种情况下,NumPy 是一个较差的选择,我们只剩下 Numba 或 JAX 可以选择。
457452
458- 为了比较这两种选择,我们将重新回顾在{doc}` Numba 讲座 <numba> ` 中看到的迭代二次映射问题。
459-
453+ 为了比较这两种选择,我们将重新回顾在 {doc}` Numba 讲座 <numba> ` 中看到的迭代二次映射问题。
460454
461455### Numba 版本
462456
@@ -501,9 +495,6 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代
501495(我们将 ` n ` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
502496
503497``` {code-cell} ipython3
504- from jax import lax
505- from functools import partial
506-
507498cpu = jax.devices("cpu")[0]
508499
509500@partial(jax.jit, static_argnums=(1,), device=cpu)
@@ -546,7 +537,6 @@ JAX 对于这种顺序运算也相当高效。
546537
547538JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。
548539
549-
550540### 总结
551541
552542虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但** 在代码可读性和易用性方面存在显著差异** 。
@@ -559,4 +549,30 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
559549
560550此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。
561551
562- 对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
552+ 对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
553+
554+ ## 总体建议
555+
556+ 让我们退一步,总结一下各方案的权衡取舍。
557+
558+ 对于** 向量化操作** ,JAX 是最强的选择。
559+
560+ 得益于 JIT 编译和跨 CPU 与 GPU 的高效并行化,它在速度上与 NumPy 持平甚至超越 NumPy。
561+
562+ ` vmap ` 变换可以减少内存使用,并且通常比基于传统网格(meshgrid)的向量化方式产生更清晰的代码。
563+
564+ 此外,JAX 函数支持自动微分,我们将在 {doc}` autodiff ` 中进行探讨。
565+
566+ 对于** 顺序操作** ,Numba 具有明显优势。
567+
568+ 代码自然易读——只需一个带装饰器的 Python 循环——且性能出色。
569+
570+ JAX 可以通过 ` lax.scan ` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限。
571+
572+ 话虽如此,` lax.scan ` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。
573+
574+ 如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
575+
576+ 在实践中,许多问题往往同时涉及两种模式。
577+
578+ 一个实用的经验法则是:新项目默认使用 JAX,尤其是在硬件加速或可微分性可能有用的情况下;当需要一个快速且可读的紧凑顺序循环时,则选用 Numba。
0 commit comments