Skip to content

Commit 78596cd

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent eb4aa1e commit 78596cd

1 file changed

Lines changed: 39 additions & 23 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ translation:
2424
Sequential operations::Numba Version: Numba 版本
2525
Sequential operations::JAX Version: JAX 版本
2626
Sequential operations::Summary: 总结
27+
Overall recommendations: 总体建议
2728
---
2829

2930
(parallel)=
@@ -69,7 +70,10 @@ tags: [hide-output]
6970

7071
```{code-cell} ipython3
7172
import random
73+
from functools import partial
74+
7275
import numpy as np
76+
import numba
7377
import quantecon as qe
7478
import matplotlib.pyplot as plt
7579
import matplotlib as mpl # i18n
@@ -80,6 +84,7 @@ from mpl_toolkits.mplot3d.axes3d import Axes3D
8084
from matplotlib import cm
8185
import jax
8286
import jax.numpy as jnp
87+
from jax import lax
8388
```
8489

8590
## 向量化运算
@@ -117,7 +122,7 @@ ax.plot_surface(x,
117122
y,
118123
f(x, y),
119124
rstride=2, cstride=2,
120-
cmap=cm.jet,
125+
cmap=cm.viridis,
121126
alpha=0.7,
122127
linewidth=0.25)
123128
ax.set_zlim(-0.5, 1.0)
@@ -143,7 +148,6 @@ for x in grid:
143148
m = z
144149
```
145150

146-
147151
### NumPy 向量化
148152

149153
如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。
@@ -168,14 +172,11 @@ print(f"NumPy result: {z_max_numpy:.6f}")
168172

169173
(并行化效率不高,因为二进制文件在看到数组 `x``y` 的大小之前就已经被编译了。)
170174

171-
172175
### 与 Numba 的比较
173176

174177
现在让我们看看能否使用简单循环的 Numba 获得更好的性能。
175178

176179
```{code-cell} ipython3
177-
import numba
178-
179180
@numba.jit
180181
def compute_max_numba(grid):
181182
m = -np.inf
@@ -189,9 +190,9 @@ def compute_max_numba(grid):
189190
grid = np.linspace(-3, 3, 3_000)
190191
191192
with qe.Timer(precision=8):
192-
z_max_numpy = compute_max_numba(grid)
193+
z_max_numba = compute_max_numba(grid)
193194
194-
print(f"Numba result: {z_max_numpy:.6f}")
195+
print(f"Numba result: {z_max_numba:.6f}")
195196
```
196197

197198
让我们再次运行以消除编译时间。
@@ -207,7 +208,6 @@ with qe.Timer(precision=8):
207208

208209
另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。
209210

210-
211211
### 并行化的 Numba
212212

213213
现在让我们使用 `prange` 尝试 Numba 的并行化:
@@ -282,7 +282,6 @@ with qe.Timer(precision=8):
282282

283283
对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。
284284

285-
286285
### 使用 JAX 的向量化代码
287286

288287
表面上,JAX 中的向量化代码与 NumPy 代码类似。
@@ -303,7 +302,7 @@ def f(x, y):
303302

304303
```{code-cell} ipython3
305304
grid = jnp.linspace(-3, 3, 3_000)
306-
x_mesh, y_mesh = np.meshgrid(grid, grid)
305+
x_mesh, y_mesh = jnp.meshgrid(grid, grid)
307306
308307
with qe.Timer(precision=8):
309308
z_max = jnp.max(f(x_mesh, y_mesh))
@@ -320,11 +319,10 @@ with qe.Timer(precision=8):
320319
z_max.block_until_ready()
321320
```
322321

323-
编译完成后,由于 GPU 加速,JAX 明显快于 NumPy。
322+
编译完成后,JAX 明显快于 NumPy,尤其是在 GPU 上
324323

325324
编译开销是一次性成本,当函数被反复调用时,这种开销是值得的。
326325

327-
328326
### JAX 加 vmap
329327

330328
NumPy 代码和 JAX 代码都存在一个问题:
@@ -386,7 +384,6 @@ with qe.Timer(precision=8):
386384

387385
当我们处理更大的问题时,将进一步探讨这些想法。
388386

389-
390387
### vmap 版本 2
391388

392389
我们可以使用 vmap 进一步提高内存效率。
@@ -421,7 +418,7 @@ def compute_max_vmap_v2(grid):
421418
with qe.Timer(precision=8):
422419
z_max = compute_max_vmap_v2(grid).block_until_ready()
423420
424-
print(f"JAX vmap v1 result: {z_max:.6f}")
421+
print(f"JAX vmap v2 result: {z_max:.6f}")
425422
```
426423

427424
让我们再次运行以消除编译时间:
@@ -433,7 +430,6 @@ with qe.Timer(precision=8):
433430

434431
如果您像我们一样在 GPU 上运行,应该能看到又一个不小的速度提升。
435432

436-
437433
### 总结
438434

439435
在我们看来,JAX 是向量化运算的赢家。
@@ -448,15 +444,13 @@ with qe.Timer(precision=8):
448444

449445
对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
450446

451-
452447
## 顺序运算
453448

454449
某些运算本质上是顺序的——因此难以或不可能向量化。
455450

456451
在这种情况下,NumPy 是一个较差的选择,我们只剩下 Numba 或 JAX 可以选择。
457452

458-
为了比较这两种选择,我们将重新回顾在{doc}`Numba 讲座 <numba>`中看到的迭代二次映射问题。
459-
453+
为了比较这两种选择,我们将重新回顾在 {doc}`Numba 讲座 <numba>` 中看到的迭代二次映射问题。
460454

461455
### Numba 版本
462456

@@ -501,9 +495,6 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代
501495
(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
502496

503497
```{code-cell} ipython3
504-
from jax import lax
505-
from functools import partial
506-
507498
cpu = jax.devices("cpu")[0]
508499
509500
@partial(jax.jit, static_argnums=(1,), device=cpu)
@@ -546,7 +537,6 @@ JAX 对于这种顺序运算也相当高效。
546537

547538
JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。
548539

549-
550540
### 总结
551541

552542
虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但**在代码可读性和易用性方面存在显著差异**
@@ -559,4 +549,30 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
559549

560550
此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。
561551

562-
对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
552+
对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
553+
554+
## 总体建议
555+
556+
让我们退一步,总结一下各方案的权衡取舍。
557+
558+
对于**向量化操作**,JAX 是最强的选择。
559+
560+
得益于 JIT 编译和跨 CPU 与 GPU 的高效并行化,它在速度上与 NumPy 持平甚至超越 NumPy。
561+
562+
`vmap` 变换可以减少内存使用,并且通常比基于传统网格(meshgrid)的向量化方式产生更清晰的代码。
563+
564+
此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进行探讨。
565+
566+
对于**顺序操作**,Numba 具有明显优势。
567+
568+
代码自然易读——只需一个带装饰器的 Python 循环——且性能出色。
569+
570+
JAX 可以通过 `lax.scan` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限。
571+
572+
话虽如此,`lax.scan` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。
573+
574+
如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
575+
576+
在实践中,许多问题往往同时涉及两种模式。
577+
578+
一个实用的经验法则是:新项目默认使用 JAX,尤其是在硬件加速或可微分性可能有用的情况下;当需要一个快速且可读的紧凑顺序循环时,则选用 Numba。

0 commit comments

Comments
 (0)