Skip to content

Commit 2006d74

Browse files
authored
Merge pull request #54 from QuantEcon/translation-sync-2026-04-13T00-09-37-pr-528
🌐 [translation-sync] Minor edits
2 parents 895d241 + b299b21 commit 2006d74

File tree

6 files changed

+83
-215
lines changed

6 files changed

+83
-215
lines changed

.translate/state/jax_intro.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 8d73de367a7f160dac777aa557f1c26069f84ea5
2-
synced-at: "2026-04-12"
1+
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
2+
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 7

.translate/state/numba.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: be6eeaee8db0c8bfea65b89d57ca8aecf7f96dff
2-
synced-at: "2026-04-12"
1+
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
2+
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 5

.translate/state/numpy_vs_numba_vs_jax.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 94dd7d22385ec46d740db1fc2cddf05c29377594
2-
synced-at: "2026-04-12"
1+
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
2+
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 3

lectures/jax_intro.md

Lines changed: 10 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ translation:
1818
JAX as a NumPy Replacement::Differences::Speed!: 速度!
1919
JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy
2020
JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX
21+
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
2122
JAX as a NumPy Replacement::Differences::Precision: 精度
2223
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
2324
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
@@ -31,13 +32,11 @@ translation:
3132
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
3233
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
3334
JIT Compilation: JIT 编译
34-
JIT Compilation::Evaluating a more complicated function: 评估更复杂的函数
35-
JIT Compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy
36-
JIT Compilation::Evaluating a more complicated function::With JAX: 使用 JAX
37-
JIT Compilation::Compiling the whole function: 编译整个函数
35+
JIT Compilation::With NumPy: 使用 NumPy
36+
JIT Compilation::With JAX: 使用 JAX
37+
JIT Compilation::Compiling the Whole Function: 编译整个函数
3838
JIT Compilation::How JIT compilation works: JIT 编译的工作原理
3939
JIT Compilation::Compiling non-pure functions: 编译非纯函数
40-
JIT Compilation::Summary: 总结
4140
Exercises: 练习
4241
---
4342

@@ -205,6 +204,8 @@ with qe.Timer():
205204

206205
大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。
207206

207+
#### 大小实验
208+
208209
我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。
209210

210211
```{code-cell}
@@ -233,105 +234,6 @@ with qe.Timer():
233234

234235
关于 JIT 编译的进一步讨论见下文。
235236

236-
(jax_speed)=
237-
#### 速度!
238-
239-
假设我们想在许多点上求余弦函数的值。
240-
241-
```{code-cell}
242-
n = 50_000_000
243-
x = np.linspace(0, 10, n)
244-
```
245-
246-
##### 使用 NumPy
247-
248-
让我们先用 NumPy 试试:
249-
250-
```{code-cell}
251-
with qe.Timer():
252-
y = np.cos(x)
253-
```
254-
255-
再来一次。
256-
257-
```{code-cell}
258-
with qe.Timer():
259-
y = np.cos(x)
260-
```
261-
262-
这里:
263-
264-
* NumPy 使用预编译的二进制文件对浮点数组应用余弦函数
265-
* 该二进制文件在本地机器的 CPU 上运行
266-
267-
##### 使用 JAX
268-
269-
现在让我们用 JAX 试试。
270-
271-
```{code-cell}
272-
x = jnp.linspace(0, 10, n)
273-
```
274-
275-
对相同的过程计时。
276-
277-
```{code-cell}
278-
with qe.Timer():
279-
y = jnp.cos(x)
280-
jax.block_until_ready(y);
281-
```
282-
283-
```{note}
284-
这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。
285-
286-
这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。
287-
288-
对于非计时代码,可以省略包含 `block_until_ready` 的那行。
289-
```
290-
291-
再计时一次。
292-
293-
```{code-cell}
294-
with qe.Timer():
295-
y = jnp.cos(x)
296-
jax.block_until_ready(y);
297-
```
298-
299-
在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。
300-
301-
此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。
302-
303-
这是因为即使是 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。
304-
305-
为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本?
306-
307-
原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。
308-
309-
大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。
310-
311-
我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。
312-
313-
```{code-cell}
314-
x = jnp.linspace(0, 10, n + 1)
315-
```
316-
317-
```{code-cell}
318-
with qe.Timer():
319-
y = jnp.cos(x)
320-
jax.block_until_ready(y);
321-
```
322-
323-
```{code-cell}
324-
with qe.Timer():
325-
y = jnp.cos(x)
326-
jax.block_until_ready(y);
327-
```
328-
329-
运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。
330-
331-
这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。
332-
333-
关于 JIT 编译的进一步讨论见下文。
334-
335237
#### 精度
336238

337239
NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。
@@ -731,19 +633,15 @@ JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高
731633

732634
我们在 {ref}`上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。
733635

734-
让我们用一个更复杂的函数尝试同样的操作。
735-
736-
### 评估更复杂的函数
737-
738-
考虑以下函数:
636+
让我们用一个更复杂的函数尝试同样的操作:
739637

740638
```{code-cell}
741639
def f(x):
742640
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
743641
return y
744642
```
745643

746-
#### 使用 NumPy
644+
### 使用 NumPy
747645

748646
我们先用 NumPy 试试:
749647

@@ -758,7 +656,7 @@ with qe.Timer():
758656
y = f(x)
759657
```
760658

761-
#### 使用 JAX
659+
### 使用 JAX
762660

763661
现在让我们用 JAX 再试一次。
764662

@@ -793,7 +691,7 @@ with qe.Timer():
793691

794692
结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
795693

796-
然而,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。
694+
然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
797695

798696
### 编译整个函数
799697

lectures/numba.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ n = 10_000_000
132132
133133
with qe.Timer() as timer1:
134134
# Time Python base version
135-
x = qm(0.1, int(n))
135+
x = qm(0.1, n)
136136
137137
```
138138

@@ -160,7 +160,7 @@ qm_numba = jit(qm)
160160
```{code-cell} ipython3
161161
with qe.Timer() as timer2:
162162
# Time jitted version
163-
x = qm_numba(0.1, int(n))
163+
x = qm_numba(0.1, n)
164164
```
165165

166166
这已经是非常大的速度提升。
@@ -172,7 +172,7 @@ with qe.Timer() as timer2:
172172
```{code-cell} ipython3
173173
with qe.Timer() as timer3:
174174
# Second run
175-
x = qm_numba(0.1, int(n))
175+
x = qm_numba(0.1, n)
176176
```
177177

178178
以下是速度提升

0 commit comments

Comments
 (0)