@@ -18,6 +18,7 @@ translation:
1818 JAX as a NumPy Replacement::Differences::Speed! : 速度!
1919 JAX as a NumPy Replacement::Differences::Speed!::With NumPy : 使用 NumPy
2020 JAX as a NumPy Replacement::Differences::Speed!::With JAX : 使用 JAX
21+ JAX as a NumPy Replacement::Differences::Size Experiment : 大小实验
2122 JAX as a NumPy Replacement::Differences::Precision : 精度
2223 JAX as a NumPy Replacement::Differences::Immutability : 不可变性
2324 JAX as a NumPy Replacement::Differences::A workaround : 变通方法
@@ -31,13 +32,11 @@ translation:
3132 Random numbers::Why explicit random state?::NumPy's approach : NumPy 的方法
3233 Random numbers::Why explicit random state?::JAX's approach : JAX 的方法
3334 JIT Compilation : JIT 编译
34- JIT Compilation::Evaluating a more complicated function : 评估更复杂的函数
35- JIT Compilation::Evaluating a more complicated function::With NumPy : 使用 NumPy
36- JIT Compilation::Evaluating a more complicated function::With JAX : 使用 JAX
37- JIT Compilation::Compiling the whole function : 编译整个函数
35+ JIT Compilation::With NumPy : 使用 NumPy
36+ JIT Compilation::With JAX : 使用 JAX
37+ JIT Compilation::Compiling the Whole Function : 编译整个函数
3838 JIT Compilation::How JIT compilation works : JIT 编译的工作原理
3939 JIT Compilation::Compiling non-pure functions : 编译非纯函数
40- JIT Compilation::Summary : 总结
4140 Exercises : 练习
4241---
4342
@@ -205,6 +204,8 @@ with qe.Timer():
205204
206205大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。
207206
207+ #### 大小实验
208+
208209我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。
209210
210211``` {code-cell}
@@ -233,105 +234,6 @@ with qe.Timer():
233234
234235关于 JIT 编译的进一步讨论见下文。
235236
236- (jax_speed)=
237- #### 速度!
238-
239- 假设我们想在许多点上求余弦函数的值。
240-
241- ``` {code-cell}
242- n = 50_000_000
243- x = np.linspace(0, 10, n)
244- ```
245-
246- ##### 使用 NumPy
247-
248- 让我们先用 NumPy 试试:
249-
250- ``` {code-cell}
251- with qe.Timer():
252- y = np.cos(x)
253- ```
254-
255- 再来一次。
256-
257- ``` {code-cell}
258- with qe.Timer():
259- y = np.cos(x)
260- ```
261-
262- 这里:
263-
264- * NumPy 使用预编译的二进制文件对浮点数组应用余弦函数
265- * 该二进制文件在本地机器的 CPU 上运行
266-
267- ##### 使用 JAX
268-
269- 现在让我们用 JAX 试试。
270-
271- ``` {code-cell}
272- x = jnp.linspace(0, 10, n)
273- ```
274-
275- 对相同的过程计时。
276-
277- ``` {code-cell}
278- with qe.Timer():
279- y = jnp.cos(x)
280- jax.block_until_ready(y);
281- ```
282-
283- ``` {note}
284- 这里,为了测量实际速度,我们使用 `block_until_ready` 方法让解释器等待,直到计算结果返回。
285-
286- 这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。
287-
288- 对于非计时代码,可以省略包含 `block_until_ready` 的那行。
289- ```
290-
291- 再计时一次。
292-
293- ``` {code-cell}
294- with qe.Timer():
295- y = jnp.cos(x)
296- jax.block_until_ready(y);
297- ```
298-
299- 在 GPU 上,这段代码的运行速度远快于等效的 NumPy 代码。
300-
301- 此外,通常第二次运行比第一次更快,这是由于 JIT 编译的原因。
302-
303- 这是因为即使是 ` jnp.cos ` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。
304-
305- 为什么 JAX 要对 ` jnp.cos ` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本?
306-
307- 原因是 JIT 编译器希望针对所使用数组的* 大小* (以及数据类型)进行专门优化。
308-
309- 大小对于生成优化代码很重要,因为高效并行化需要将任务大小与可用硬件相匹配。
310-
311- 我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。
312-
313- ``` {code-cell}
314- x = jnp.linspace(0, 10, n + 1)
315- ```
316-
317- ``` {code-cell}
318- with qe.Timer():
319- y = jnp.cos(x)
320- jax.block_until_ready(y);
321- ```
322-
323- ``` {code-cell}
324- with qe.Timer():
325- y = jnp.cos(x)
326- jax.block_until_ready(y);
327- ```
328-
329- 运行时间先增加后再次下降(在 GPU 上这一现象会更明显)。
330-
331- 这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。
332-
333- 关于 JIT 编译的进一步讨论见下文。
334-
335237#### 精度
336238
337239NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。
@@ -731,19 +633,15 @@ JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高
731633
732634我们在 {ref}` 上文 <jax_speed> ` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 ` cos ` 函数。
733635
734- 让我们用一个更复杂的函数尝试同样的操作。
735-
736- ### 评估更复杂的函数
737-
738- 考虑以下函数:
636+ 让我们用一个更复杂的函数尝试同样的操作:
739637
740638``` {code-cell}
741639def f(x):
742640 y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
743641 return y
744642```
745643
746- #### 使用 NumPy
644+ ### 使用 NumPy
747645
748646我们先用 NumPy 试试:
749647
@@ -758,7 +656,7 @@ with qe.Timer():
758656 y = f(x)
759657```
760658
761- #### 使用 JAX
659+ ### 使用 JAX
762660
763661现在让我们用 JAX 再试一次。
764662
@@ -793,7 +691,7 @@ with qe.Timer():
793691
794692结果与 ` cos ` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
795693
796- 然而,使用 JAX,我们还有另一个技巧——我们可以对 * 整个 * 函数进行 JIT 编译,而不仅仅是单个操作。
694+ 然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
797695
798696### 编译整个函数
799697
0 commit comments