Skip to content

Commit b299b21

Browse files
Fix duplicate Speed! section in jax_intro.md, add Size Experiment heading
- Add #### 大小实验 heading to match English #### Size Experiment structure - Remove duplicate #### 速度! section (was duplicated from overlapping translation PRs) - Add Size Experiment entry to heading map in frontmatter Agent-Logs-Url: https://github.com/QuantEcon/lecture-python-programming.zh-cn/sessions/9216d7c8-4733-4e5e-a1c5-74a4d6018ce1 Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com>
1 parent eacc764 commit b299b21

File tree

1 file changed

+3
-98
lines changed

1 file changed

+3
-98
lines changed

lectures/jax_intro.md

Lines changed: 3 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ translation:
1818
JAX as a NumPy Replacement::Differences::Speed!: 速度!
1919
JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy
2020
JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX
21+
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
2122
JAX as a NumPy Replacement::Differences::Precision: 精度
2223
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
2324
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
@@ -203,6 +204,8 @@ with qe.Timer():
203204

204205
大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。
205206

207+
#### 大小实验
208+
206209
我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。
207210

208211
```{code-cell}
@@ -231,104 +234,6 @@ with qe.Timer():
231234

232235
关于 JIT 编译的进一步讨论见下文。
233236

234-
#### 速度!
235-
236-
假设我们想在许多点上求余弦函数的值。
237-
238-
```{code-cell}
239-
n = 50_000_000
240-
x = np.linspace(0, 10, n)
241-
```
242-
243-
##### 使用 NumPy
244-
245-
让我们先用 NumPy 试试:
246-
247-
```{code-cell}
248-
with qe.Timer():
249-
y = np.cos(x)
250-
```
251-
252-
再来一次。
253-
254-
```{code-cell}
255-
with qe.Timer():
256-
y = np.cos(x)
257-
```
258-
259-
这里:
260-
261-
* NumPy 使用预编译的二进制文件对浮点数组应用余弦函数
262-
* 该二进制文件在本地机器的 CPU 上运行
263-
264-
##### 使用 JAX
265-
266-
现在让我们用 JAX 试试。
267-
268-
```{code-cell}
269-
x = jnp.linspace(0, 10, n)
270-
```
271-
272-
对相同的过程计时。
273-
274-
```{code-cell}
275-
with qe.Timer():
276-
y = jnp.cos(x)
277-
jax.block_until_ready(y);
278-
```
279-
280-
```{note}
281-
这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。
282-
283-
这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。
284-
285-
对于非计时代码,可以省略包含 `block_until_ready` 的那行。
286-
```
287-
288-
再计时一次。
289-
290-
```{code-cell}
291-
with qe.Timer():
292-
y = jnp.cos(x)
293-
jax.block_until_ready(y);
294-
```
295-
296-
在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。
297-
298-
此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。
299-
300-
这是因为即使是 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。
301-
302-
为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本?
303-
304-
原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。
305-
306-
大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。
307-
308-
我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。
309-
310-
```{code-cell}
311-
x = jnp.linspace(0, 10, n + 1)
312-
```
313-
314-
```{code-cell}
315-
with qe.Timer():
316-
y = jnp.cos(x)
317-
jax.block_until_ready(y);
318-
```
319-
320-
```{code-cell}
321-
with qe.Timer():
322-
y = jnp.cos(x)
323-
jax.block_until_ready(y);
324-
```
325-
326-
运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。
327-
328-
这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。
329-
330-
关于 JIT 编译的进一步讨论见下文。
331-
332237
#### 精度
333238

334239
NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。

0 commit comments

Comments
 (0)