Skip to content

Commit b5edbd6

Browse files
Merge remote-tracking branch 'origin/main' into translation-sync-2026-04-13T12-48-54-pr-530
# Conflicts: # .translate/state/jax_intro.md.yml # .translate/state/numpy_vs_numba_vs_jax.md.yml # lectures/jax_intro.md # lectures/numba.md # lectures/numpy_vs_numba_vs_jax.md Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com>
2 parents f6569d2 + 2006d74 commit b5edbd6

File tree

4 files changed

+43
-131
lines changed

4 files changed

+43
-131
lines changed

.translate/state/jax_intro.md.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
1+
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
22
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE

.translate/state/numpy_vs_numba_vs_jax.md.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
1+
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
22
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE

lectures/jax_intro.md

Lines changed: 22 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ translation:
2121
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
2222
JAX as a NumPy Replacement::Differences::Precision: 精度
2323
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
24-
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
24+
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
2525
Functional Programming: 函数式编程
2626
Functional Programming::Pure functions: 纯函数
2727
Functional Programming::Examples: 示例
28-
Functional Programming::Why Functional Programming?: 为什么要函数式编程
28+
Functional Programming::Why Functional Programming?: 为什么使用函数式编程
2929
Random numbers: 随机数
3030
Random numbers::Random number generation: 随机数生成
3131
Random numbers::Why explicit random state?: 为什么要显式随机状态?
@@ -37,10 +37,6 @@ translation:
3737
JIT Compilation::Compiling the Whole Function: 编译整个函数
3838
JIT Compilation::How JIT compilation works: JIT 编译的工作原理
3939
JIT Compilation::Compiling non-pure functions: 编译非纯函数
40-
Vectorization with vmap: 使用 vmap 进行向量化
41-
Vectorization with vmap::A simple example: 一个简单的示例
42-
Vectorization with vmap::Combining transformations: 组合变换
43-
Automatic differentiation: a preview: 自动微分:预览
4440
Exercises: 练习
4541
---
4642

@@ -77,17 +73,17 @@ import numpy as np
7773
import quantecon as qe
7874
```
7975

76+
注意我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的接口。
77+
8078
## JAX 作为 NumPy 的替代品
8179

82-
让我们来看看 JAX NumPy 之间的异同
80+
JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API
8381

84-
### 相似之处
82+
这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。
8583

86-
上面我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的数组操作接口。
87-
88-
JAX 的一个吸引人之处在于,这个接口在尽可能的情况下遵循 NumPy API。
84+
让我们来看看 JAX 和 NumPy 之间的异同。
8985

90-
因此,我们通常可以将 JAX 作为 NumPy 的直接替代品使用。
86+
### 相似之处
9187

9288
以下是使用 `jnp` 进行的一些标准数组操作:
9389

@@ -107,7 +103,7 @@ print(jnp.sum(a))
107103
print(jnp.dot(a, a))
108104
```
109105

110-
但需要注意的是,数组对象 `a` 并不是 NumPy 数组:
106+
然而,数组对象 `a` 并不是 NumPy 数组:
111107

112108
```{code-cell} ipython3
113109
a
@@ -117,7 +113,7 @@ a
117113
type(a)
118114
```
119115

120-
即使是数组上的标量值映射也会返回 JAX 数组而非标量
116+
即使是数组上的标量值映射也会返回 JAX 数组,而不是标量
121117

122118
```{code-cell} ipython3
123119
jnp.sum(a)
@@ -130,18 +126,16 @@ jnp.sum(a)
130126
(jax_speed)=
131127
#### 速度!
132128

133-
一个主要差异是 JAX 更快——有时快得多。
134-
135-
为了说明这一点,假设我们想在许多点处计算余弦函数。
129+
假设我们想在许多点上计算余弦函数。
136130

137131
```{code-cell}
138132
n = 50_000_000
139-
x = np.linspace(0, 10, n) # NumPy array
133+
x = np.linspace(0, 10, n)
140134
```
141135

142136
##### 使用 NumPy
143137

144-
让我们用 NumPy 来试试
138+
让我们先用 NumPy 试试:
145139

146140
```{code-cell}
147141
with qe.Timer():
@@ -159,7 +153,7 @@ with qe.Timer():
159153

160154
这里
161155

162-
* NumPy 使用预编译的二进制文件将余弦函数应用于浮点数数组
156+
* NumPy 使用预编译的二进制文件对浮点数数组应用余弦函数
163157
* 该二进制文件在本地机器的 CPU 上运行
164158

165159
##### 使用 JAX
@@ -181,8 +175,11 @@ with qe.Timer():
181175
```
182176

183177
```{note}
184-
上面的 `block_until_ready` 方法会阻塞解释器,直到计算结果返回。
185-
这对于计时是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前继续运行。
178+
这里,为了测量实际速度,我们使用 `block_until_ready` 方法来阻塞解释器,直到计算结果返回。
179+
180+
这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。
181+
182+
对于非计时代码,可以删除包含 `block_until_ready` 的那一行。
186183
```
187184

188185
再来计时一次。
@@ -277,8 +274,7 @@ a[0] = 1
277274
a
278275
```
279276

280-
在 JAX 中,这会失败 😱。
281-
277+
在 JAX 中,这会失败!
282278

283279
```{code-cell} ipython3
284280
a = jnp.linspace(0, 1, 3)
@@ -290,21 +286,13 @@ try:
290286
a[0] = 1
291287
except Exception as e:
292288
print(e)
293-
294289
```
295290

296-
JAX 的设计者选择将数组设为不可变的,因为
297-
298-
1. JAX 使用*函数式编程风格*,并且
299-
2. 函数式编程通常避免可变数据
300-
301-
我们将在 {ref}`下面 <jax_func>` 讨论这些思想。
291+
JAX 的设计者选择将数组设为不可变的,因为 JAX 使用函数式编程风格,我们将在下面讨论这一点。
302292

303-
304-
(jax_at_workaround)=
305293
#### 变通方法
306294

307-
JAX 确实通过 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 提供了原地数组修改的直接替代方案
295+
我们注意到 JAX 确实提供了一种替代原地数组修改的方式,使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html)
308296

309297
```{code-cell} ipython3
310298
a = jnp.linspace(0, 1, 3)
@@ -326,8 +314,6 @@ a
326314

327315
(尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。)
328316

329-
330-
(jax_func)=
331317
## 函数式编程
332318

333319
来自 JAX 的文档:
@@ -860,40 +846,6 @@ fast_batch_mm_diff(X)
860846

861847
`jit``vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
862848

863-
864-
## 自动微分:预览
865-
866-
JAX 可以使用自动微分来计算梯度。
867-
868-
这对于优化和求解非线性系统非常有用。
869-
870-
以下是一个简单的示例,涉及函数 $f(x) = x^2 / 2$:
871-
872-
```{code-cell} ipython3
873-
def f(x):
874-
return (x**2) / 2
875-
876-
f_prime = jax.grad(f)
877-
```
878-
879-
```{code-cell} ipython3
880-
f_prime(10.0)
881-
```
882-
883-
让我们绘制函数和导数,注意 $f'(x) = x$。
884-
885-
```{code-cell} ipython3
886-
fig, ax = plt.subplots()
887-
x_grid = jnp.linspace(-4, 4, 200)
888-
ax.plot(x_grid, f(x_grid), label="$f$")
889-
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
890-
ax.legend(loc='upper center')
891-
plt.show()
892-
```
893-
894-
自动微分是一个有许多经济学和金融学应用的深层主题。我们在{doc}`自动微分讲座 <autodiff>`中提供了更深入的讨论。
895-
896-
897849
## 练习
898850

899851

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ tags: [hide-output]
6767
我们将使用以下导入。
6868

6969
```{code-cell} ipython3
70+
import random
7071
from functools import partial
7172
7273
import numpy as np
@@ -466,62 +467,15 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代
466467

467468
### JAX 版本
468469

469-
现在让我们使用 `at[t].set` 风格的语法创建一个 JAX 版本,正如 {ref}`JAX 讲座中讨论的 <jax_at_workaround>`,这为不可变数组提供了一种变通方法。
470+
现在让我们使用 `lax.scan` 创建一个 JAX 版本
470471

471-
我们将使用 `lax.fori_loop`,它是一种可以被 XLA 编译的 for 循环版本。
472+
(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
472473

473474
```{code-cell} ipython3
474475
cpu = jax.devices("cpu")[0]
475476
476-
@partial(jax.jit, static_argnames=("n",), device=cpu)
477-
def qm_jax_fori(x0, n, α=4.0):
478-
479-
x = jnp.empty(n + 1).at[0].set(x0)
480-
481-
def update(t, x):
482-
return x.at[t + 1].set(α * x[t] * (1 - x[t]))
483-
484-
x = lax.fori_loop(0, n, update, x)
485-
return x
486-
487-
```
488-
489-
* 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。
490-
* 我们通过 `device=cpu` 将计算固定在 CPU 上,因为这种顺序工作负载由许多小操作组成,几乎没有机会利用 GPU 并行性。
491-
492-
虽然 `at[t].set` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新。
493-
494-
让我们使用相同的参数计时:
495-
496-
```{code-cell} ipython3
497-
with qe.Timer():
498-
# First run
499-
x_jax = qm_jax_fori(0.1, n)
500-
# Hold interpreter
501-
x_jax.block_until_ready()
502-
```
503-
504-
让我们再次运行以消除编译开销:
505-
506-
```{code-cell} ipython3
507-
with qe.Timer():
508-
# Second run
509-
x_jax = qm_jax_fori(0.1, n)
510-
# Hold interpreter
511-
x_jax.block_until_ready()
512-
```
513-
514-
JAX 对于这种顺序运算也相当高效。
515-
516-
517-
我们还有另一种实现循环的方式,使用 `lax.scan`
518-
519-
这种替代方案可以说更符合 JAX 的函数式方法——尽管语法难以记忆。
520-
521-
522-
```{code-cell} ipython3
523-
@partial(jax.jit, static_argnames=("n",), device=cpu)
524-
def qm_jax_scan(x0, n, α=4.0):
477+
@partial(jax.jit, static_argnames=('n',), device=cpu)
478+
def qm_jax(x0, n, α=4.0):
525479
def update(x, t):
526480
x_new = α * x * (1 - x)
527481
return x_new, x_new
@@ -532,12 +486,16 @@ def qm_jax_scan(x0, n, α=4.0):
532486

533487
这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。
534488

489+
```{note}
490+
我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。
491+
```
492+
535493
让我们使用相同的参数计时:
536494

537495
```{code-cell} ipython3
538496
with qe.Timer():
539497
# First run
540-
x_jax = qm_jax_scan(0.1, n)
498+
x_jax = qm_jax(0.1, n)
541499
# Hold interpreter
542500
x_jax.block_until_ready()
543501
```
@@ -547,11 +505,13 @@ with qe.Timer():
547505
```{code-cell} ipython3
548506
with qe.Timer():
549507
# Second run
550-
x_jax = qm_jax_scan(0.1, n)
508+
x_jax = qm_jax(0.1, n)
551509
# Hold interpreter
552510
x_jax.block_until_ready()
553511
```
554512

513+
JAX 对于这种顺序运算也相当高效。
514+
555515
JAX 和 Numba 在编译后都能提供出色的性能。
556516

557517
### 总结
@@ -562,9 +522,9 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
562522

563523
这正是大多数程序员思考该算法的方式。
564524

565-
另一方面,JAX 版本需要使用 `lax.fori_loop``lax.scan`两者都比标准 Python 循环更不直观
525+
另一方面,JAX 版本需要使用 `lax.scan`这明显不够直观
566526

567-
虽然 JAX `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读
527+
此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难
568528

569529
对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
570530

@@ -582,13 +542,13 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
582542

583543
对于**顺序操作**,Numba 具有明显优势。
584544

585-
代码自然易读——只需一个带装饰器的 Python 循环——且性能出色
545+
代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色
586546

587-
JAX 可以通过 `lax.fori_loop``lax.scan` 处理顺序问题,但语法不够直观。
547+
JAX 可以通过 `lax.scan` 处理顺序问题,但语法不够直观。
588548

589549
```{note}
590-
`lax.fori_loop` 和 `lax.scan` 有一个重要优势:它们支持对循环进行自动微分,而 Numba 无法做到这一点。
591-
如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
550+
`lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。
551+
如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
592552
```
593553

594554
在实践中,许多问题涉及两种模式的混合。

0 commit comments

Comments
 (0)