Skip to content

Commit 5f70fd1

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent 0afe707 commit 5f70fd1

1 file changed

Lines changed: 75 additions & 34 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,15 @@ translation:
1717
Vectorized operations::Parallelized Numba: 并行化的 Numba
1818
Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码
1919
Vectorized operations::JAX plus vmap: JAX 加 vmap
20-
Vectorized operations::JAX plus vmap::Version 1: 版本 1
21-
Vectorized operations::vmap version 2: vmap 版本 2
22-
Vectorized operations::Summary: 总结
20+
Vectorized operations::Summary: vmap 版本 2
2321
Sequential operations: 顺序运算
2422
Sequential operations::Numba Version: Numba 版本
2523
Sequential operations::JAX Version: JAX 版本
2624
Sequential operations::Summary: 总结
2725
Overall recommendations: 总体建议
2826
---
2927

30-
(parallel)=
28+
(numpy_numba_jax)=
3129
```{raw} jupyter
3230
<div id="qe-notebook-header" align="right" style="text-align:right;">
3331
<a href="https://quantecon.org/" title="quantecon.org">
@@ -69,7 +67,6 @@ tags: [hide-output]
6967
我们将使用以下导入。
7068

7169
```{code-cell} ipython3
72-
import random
7370
from functools import partial
7471
7572
import numpy as np
@@ -472,14 +469,16 @@ def qm(x0, n, α=4.0):
472469
```{code-cell} ipython3
473470
n = 10_000_000
474471
475-
with qe.Timer(precision=8):
472+
with qe.Timer():
473+
# First run
476474
x = qm(0.1, n)
477475
```
478476

479477
让我们再次运行以消除编译时间:
480478

481479
```{code-cell} ipython3
482-
with qe.Timer(precision=8):
480+
with qe.Timer():
481+
# Second run
483482
x = qm(0.1, n)
484483
```
485484

@@ -491,15 +490,62 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代
491490

492491
### JAX 版本
493492

494-
现在让我们使用 `lax.scan` 创建一个 JAX 版本
493+
现在让我们使用 `at[t].set` 风格的语法创建一个 JAX 版本,正如 {ref}`JAX 讲座中讨论的 <jax_at_workaround>`,这为不可变数组提供了一种变通方法。
495494

496-
(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
495+
我们将使用 `lax.fori_loop`,它是一种可以被 XLA 编译的 for 循环版本。
497496

498497
```{code-cell} ipython3
499498
cpu = jax.devices("cpu")[0]
500499
501-
@partial(jax.jit, static_argnums=(1,), device=cpu)
502-
def qm_jax(x0, n, α=4.0):
500+
@partial(jax.jit, static_argnames=("n",), device=cpu)
501+
def qm_jax_fori(x0, n, α=4.0):
502+
503+
x = jnp.empty(n + 1).at[0].set(x0)
504+
505+
def update(t, x):
506+
return x.at[t + 1].set(α * x[t] * (1 - x[t]))
507+
508+
x = lax.fori_loop(0, n, update, x)
509+
return x
510+
511+
```
512+
513+
* 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。
514+
* 我们通过 `device=cpu` 将计算固定在 CPU 上,因为这种顺序工作负载由许多小操作组成,几乎没有机会利用 GPU 并行性。
515+
516+
虽然 `at[t].set` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新。
517+
518+
让我们使用相同的参数计时:
519+
520+
```{code-cell} ipython3
521+
with qe.Timer():
522+
# First run
523+
x_jax = qm_jax_fori(0.1, n)
524+
# Hold interpreter
525+
x_jax.block_until_ready()
526+
```
527+
528+
让我们再次运行以消除编译开销:
529+
530+
```{code-cell} ipython3
531+
with qe.Timer():
532+
# Second run
533+
x_jax = qm_jax_fori(0.1, n)
534+
# Hold interpreter
535+
x_jax.block_until_ready()
536+
```
537+
538+
JAX 对于这种顺序运算也相当高效。
539+
540+
541+
我们还有另一种实现循环的方式,使用 `lax.scan`
542+
543+
这种替代方案可以说更符合 JAX 的函数式方法——尽管语法难以记忆。
544+
545+
546+
```{code-cell} ipython3
547+
@partial(jax.jit, static_argnames=("n",), device=cpu)
548+
def qm_jax_scan(x0, n, α=4.0):
503549
def update(x, t):
504550
x_new = α * x * (1 - x)
505551
return x_new, x_new
@@ -510,33 +556,27 @@ def qm_jax(x0, n, α=4.0):
510556

511557
这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。
512558

513-
```{note}
514-
细心的读者会注意到,我们在 `jax.jit` 装饰器中指定了 `device=cpu`。
515-
516-
该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。
517-
518-
因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。
519-
520-
好奇的读者可以尝试删除此选项,看看性能如何变化。
521-
```
522-
523559
让我们使用相同的参数计时:
524560

525561
```{code-cell} ipython3
526-
with qe.Timer(precision=8):
527-
x_jax = qm_jax(0.1, n).block_until_ready()
562+
with qe.Timer():
563+
# First run
564+
x_jax = qm_jax_scan(0.1, n)
565+
# Hold interpreter
566+
x_jax.block_until_ready()
528567
```
529568

530569
让我们再次运行以消除编译开销:
531570

532571
```{code-cell} ipython3
533-
with qe.Timer(precision=8):
534-
x_jax = qm_jax(0.1, n).block_until_ready()
572+
with qe.Timer():
573+
# Second run
574+
x_jax = qm_jax_scan(0.1, n)
575+
# Hold interpreter
576+
x_jax.block_until_ready()
535577
```
536578

537-
JAX 对于这种顺序运算也相当高效。
538-
539-
JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。
579+
JAX 和 Numba 在编译后都能提供出色的性能。
540580

541581
### 总结
542582

@@ -546,11 +586,11 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
546586

547587
这正是大多数程序员思考该算法的方式。
548588

549-
另一方面,JAX 版本需要使用 `lax.scan`这明显不够直观
589+
另一方面,JAX 版本需要使用 `lax.fori_loop``lax.scan`两者都比标准 Python 循环更不直观
550590

551-
此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难
591+
虽然 JAX `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读
552592

553-
对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。
593+
对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
554594

555595
## 总体建议
556596

@@ -568,11 +608,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
568608

569609
代码自然易读——只需一个带装饰器的 Python 循环——且性能出色。
570610

571-
JAX 可以通过 `lax.scan` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限。
572-
573-
话虽如此,`lax.scan` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。
611+
JAX 可以通过 `lax.fori_loop``lax.scan` 处理顺序问题,但语法不够直观。
574612

613+
```{note}
614+
`lax.fori_loop` 和 `lax.scan` 有一个重要优势:它们支持对循环进行自动微分,而 Numba 无法做到这一点。
575615
如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
616+
```
576617

577618
在实践中,许多问题往往同时涉及两种模式。
578619

0 commit comments

Comments
 (0)