Skip to content

Commit 4b45330

Browse files
committed
Update translation: lectures/jax_intro.md
1 parent 1e8dde1 commit 4b45330

1 file changed

Lines changed: 149 additions & 71 deletions

File tree

lectures/jax_intro.md

Lines changed: 149 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@ translation:
1515
JAX as a NumPy Replacement: JAX 作为 NumPy 的替代品
1616
JAX as a NumPy Replacement::Similarities: 相似之处
1717
JAX as a NumPy Replacement::Differences: 差异
18+
JAX as a NumPy Replacement::Differences::Speed!: 速度!
19+
JAX as a NumPy Replacement::Differences::Speed!::With NumPy: 使用 NumPy
20+
JAX as a NumPy Replacement::Differences::Speed!::With JAX: 使用 JAX
21+
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
1822
JAX as a NumPy Replacement::Differences::Precision: 精度
1923
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
20-
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
24+
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
2125
Functional Programming: 函数式编程
2226
Functional Programming::Pure functions: 纯函数
2327
Functional Programming::Examples: 示例
@@ -26,18 +30,12 @@ translation:
2630
Random numbers::Why explicit random state?: 为什么要显式随机状态?
2731
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
2832
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
29-
JIT compilation: JIT 编译
30-
JIT compilation::A simple example: 一个简单的示例
31-
JIT compilation::A simple example::With NumPy: 使用 NumPy
32-
JIT compilation::A simple example::With JAX: 使用 JAX
33-
JIT compilation::A simple example::Changing array sizes: 更改数组大小
34-
JIT compilation::Evaluating a more complicated function: 评估更复杂的函数
35-
JIT compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy
36-
JIT compilation::Evaluating a more complicated function::With JAX: 使用 JAX
37-
JIT compilation::How JIT compilation works: JIT 编译的工作原理
38-
JIT compilation::Compiling the whole function: 编译整个函数
39-
JIT compilation::Compiling non-pure functions: 编译非纯函数
40-
JIT compilation::Summary: 总结
33+
JIT Compilation: JIT 编译
34+
JIT Compilation::With NumPy: 一个简单的示例
35+
JIT Compilation::With JAX: 评估更复杂的函数
36+
JIT Compilation::Compiling the Whole Function: JIT 编译的工作原理
37+
JIT Compilation::How JIT compilation works: 编译整个函数
38+
JIT Compilation::Compiling non-pure functions: 编译非纯函数
4139
Vectorization with `vmap`: 使用 `vmap` 进行向量化
4240
Vectorization with `vmap`::A simple example: 一个简单的示例
4341
Vectorization with `vmap`::Combining transformations: 组合变换
@@ -49,13 +47,16 @@ translation:
4947

5048
本讲座简要介绍 [Google JAX](https://github.com/jax-ml/jax)
5149

50+
```{include} _admonition/gpu.md
51+
```
52+
5253
JAX 是一个高性能科学计算库,提供以下功能:
5354

5455
* 类似 [NumPy](https://en.wikipedia.org/wiki/NumPy) 的接口,可以在 CPU 和 GPU 上自动并行化,
5556
* 一个即时编译器,用于加速大量数值运算,以及
5657
* [自动微分](https://en.wikipedia.org/wiki/Automatic_differentiation)
5758

58-
JAX 也在日益维护和提供[更多专业化的科学计算例程](https://docs.jax.dev/en/latest/jax.scipy.html),例如那些最初在 [SciPy](https://en.wikipedia.org/wiki/SciPy) 中找到的例程。
59+
JAX 也在日益维护和提供 [更多专业化的科学计算例程](https://docs.jax.dev/en/latest/jax.scipy.html),例如那些最初在 [SciPy](https://en.wikipedia.org/wiki/SciPy) 中找到的例程。
5960

6061
除了 Anaconda 中已有的内容外,本讲座还需要以下库:
6162

@@ -65,36 +66,27 @@ JAX 也在日益维护和提供[更多专业化的科学计算例程](https://do
6566
!pip install jax quantecon
6667
```
6768

68-
```{include} _admonition/gpu.md
69-
```
70-
71-
## JAX 作为 NumPy 的替代品
72-
73-
JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。
74-
75-
这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。
76-
77-
让我们来看看 JAX 和 NumPy 之间的异同。
78-
79-
### 相似之处
80-
8169
我们将使用以下导入:
8270

8371
```{code-cell} ipython3
8472
import jax
8573
import jax.numpy as jnp
8674
import matplotlib.pyplot as plt
87-
import matplotlib.patches as mpatches
8875
import numpy as np
8976
import quantecon as qe
90-
import matplotlib as mpl # i18n
91-
import matplotlib.font_manager # i18n
92-
FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" # i18n
93-
mpl.font_manager.fontManager.addfont(FONTPATH) # i18n
94-
mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n
9577
```
9678

97-
注意我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的接口。
79+
## JAX 作为 NumPy 的替代品
80+
81+
让我们来看看 JAX 和 NumPy 之间的异同。
82+
83+
### 相似之处
84+
85+
上面我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的数组操作接口。
86+
87+
JAX 的一个吸引人之处在于,这个接口在尽可能的情况下遵循 NumPy API。
88+
89+
因此,我们通常可以将 JAX 作为 NumPy 的直接替代品使用。
9890

9991
以下是使用 `jnp` 进行的一些标准数组操作:
10092

@@ -110,15 +102,11 @@ print(a)
110102
print(jnp.sum(a))
111103
```
112104

113-
```{code-cell} ipython3
114-
print(jnp.mean(a))
115-
```
116-
117105
```{code-cell} ipython3
118106
print(jnp.dot(a, a))
119107
```
120108

121-
然而,数组对象 `a` 并不是 NumPy 数组:
109+
但需要注意的是,数组对象 `a` 并不是 NumPy 数组:
122110

123111
```{code-cell} ipython3
124112
a
@@ -128,37 +116,131 @@ a
128116
type(a)
129117
```
130118

131-
即使是数组上的标量值映射也会返回 JAX 数组。
119+
即使是数组上的标量值映射也会返回 JAX 数组而非标量!
132120

133121
```{code-cell} ipython3
134122
jnp.sum(a)
135123
```
136124

137-
对高维数组的操作也与 NumPy 类似:
125+
### 差异
126+
127+
现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
128+
129+
(jax_speed)=
130+
#### 速度!
138131

139-
```{code-cell} ipython3
140-
A = jnp.ones((2, 2))
141-
B = jnp.identity(2)
142-
A @ B
132+
一个主要差异是 JAX 更快——有时快得多。
133+
134+
为了说明这一点,假设我们想在许多点处计算余弦函数。
135+
136+
```{code-cell}
137+
n = 50_000_000
138+
x = np.linspace(0, 10, n) # NumPy array
143139
```
144140

145-
JAX 的数组接口也提供了 `linalg` 子包:
141+
##### 使用 NumPy
146142

147-
```{code-cell} ipython3
148-
jnp.linalg.inv(B) # Inverse of identity is identity
143+
让我们用 NumPy 来试试。
144+
145+
```{code-cell}
146+
with qe.Timer():
147+
# First NumPy timing
148+
y = np.cos(x)
149149
```
150150

151-
```{code-cell} ipython3
152-
jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
151+
再来一次。
152+
153+
```{code-cell}
154+
with qe.Timer():
155+
# Second NumPy timing
156+
y = np.cos(x)
153157
```
154158

155-
### 差异
159+
这里
156160

157-
现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
161+
* NumPy 使用预编译的二进制文件将余弦函数应用于浮点数数组
162+
* 该二进制文件在本地机器的 CPU 上运行
163+
164+
##### 使用 JAX
165+
166+
现在让我们用 JAX 来试试。
167+
168+
```{code-cell}
169+
x = jnp.linspace(0, 10, n)
170+
```
171+
172+
让我们对同样的过程计时。
173+
174+
```{code-cell}
175+
with qe.Timer():
176+
# First run
177+
y = jnp.cos(x)
178+
# Hold the interpreter until the array operation finishes
179+
y.block_until_ready()
180+
```
181+
182+
```{note}
183+
上面,`block_until_ready` 方法会阻塞解释器,直到计算结果返回。
184+
这对于计时执行是必要的,因为 JAX 使用异步调度,
185+
允许 Python 解释器在数值计算之前运行。
186+
```
187+
188+
现在让我们再次计时。
189+
190+
```{code-cell}
191+
with qe.Timer():
192+
# Second run
193+
y = jnp.cos(x)
194+
# Hold interpreter
195+
y.block_until_ready()
196+
```
197+
198+
在 GPU 上,这段代码的运行速度远快于其 NumPy 等价代码。
199+
200+
另外,通常情况下,由于 JIT 编译,第二次运行比第一次更快。
201+
202+
这是因为即使是像 `jnp.cos` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。
203+
204+
为什么 JAX 要对 `jnp.cos` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本呢?
205+
206+
原因是 JIT 编译器希望针对所使用数组的*大小*(以及数据类型)进行专门优化。
207+
208+
大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件匹配。
209+
210+
#### 大小实验
211+
212+
我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。
213+
214+
```{code-cell}
215+
x = jnp.linspace(0, 10, n + 1)
216+
```
217+
218+
```{code-cell}
219+
with qe.Timer():
220+
# First run
221+
y = jnp.cos(x)
222+
# Hold interpreter
223+
y.block_until_ready()
224+
```
225+
226+
227+
```{code-cell}
228+
with qe.Timer():
229+
# Second run
230+
y = jnp.cos(x)
231+
# Hold interpreter
232+
y.block_until_ready()
233+
```
234+
235+
运行时间先增加后减少(这在 GPU 上会更明显)。
236+
237+
这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。
238+
239+
下面将进一步讨论 JIT 编译。
158240

159241
#### 精度
160242

161-
NumPy 和 JAX 之间的一个差异是 JAX 默认使用 32 位浮点数。
243+
NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。
162244

163245
这是因为 JAX 经常用于 GPU 计算,而大多数 GPU 计算使用 32 位浮点数。
164246

@@ -196,40 +278,34 @@ a[0] = 1
196278
a
197279
```
198280

199-
在 JAX 中,这会失败:
281+
在 JAX 中,这会失败 😱。
282+
200283

201284
```{code-cell} ipython3
202285
a = jnp.linspace(0, 1, 3)
203286
a
204287
```
205288

206289
```{code-cell} ipython3
207-
:tags: [raises-exception]
290+
try:
291+
a[0] = 1
292+
except Exception as e:
293+
print(e)
208294
209-
a[0] = 1
210295
```
211296

212-
与不可变性一致,JAX 不支持原地操作:
213-
214-
```{code-cell} ipython3
215-
a = np.array((2, 1))
216-
a.sort() # Unlike NumPy, does not mutate a
217-
a
218-
```
297+
JAX 的设计者选择将数组设为不可变的,因为
219298

220-
```{code-cell} ipython3
221-
a = jnp.array((2, 1))
222-
a_new = a.sort() # Instead, the sort method returns a new sorted array
223-
a, a_new
224-
```
299+
1. JAX 使用*函数式编程风格*,并且
300+
2. 函数式编程通常避免可变数据
225301

226-
JAX 的设计者选择将数组设为不可变的,因为 JAX 使用 [函数式编程](https://en.wikipedia.org/wiki/Functional_programming) 风格
302+
我们将在 {ref}`下面 <jax_func>` 讨论这些思想
227303

228-
这个设计选择有重要的含义,我们接下来将对此进行探讨!
229304

305+
(jax_at_workaround)=
230306
#### 变通方法
231307

232-
我们注意到 JAX 确实提供了一种使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 进行原地数组修改的版本
308+
JAX 确实通过 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 提供了原地数组修改的直接替代方案
233309

234310
```{code-cell} ipython3
235311
a = jnp.linspace(0, 1, 3)
@@ -251,6 +327,8 @@ a
251327

252328
(尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。)
253329

330+
331+
(jax_func)=
254332
## 函数式编程
255333

256334
来自 JAX 的文档:

0 commit comments

Comments
 (0)