@@ -15,9 +15,13 @@ translation:
1515 JAX as a NumPy Replacement : JAX 作为 NumPy 的替代品
1616 JAX as a NumPy Replacement::Similarities : 相似之处
1717 JAX as a NumPy Replacement::Differences : 差异
18+ JAX as a NumPy Replacement::Differences::Speed! : 速度!
19+ JAX as a NumPy Replacement::Differences::Speed!::With NumPy : 使用 NumPy
20+ JAX as a NumPy Replacement::Differences::Speed!::With JAX : 使用 JAX
21+ JAX as a NumPy Replacement::Differences::Size Experiment : 大小实验
1822 JAX as a NumPy Replacement::Differences::Precision : 精度
1923 JAX as a NumPy Replacement::Differences::Immutability : 不可变性
20- JAX as a NumPy Replacement::Differences::A workaround : 变通方法
24+ JAX as a NumPy Replacement::Differences::A Workaround : 变通方法
2125 Functional Programming : 函数式编程
2226 Functional Programming::Pure functions : 纯函数
2327 Functional Programming::Examples : 示例
@@ -26,18 +30,12 @@ translation:
2630 Random numbers::Why explicit random state? : 为什么要显式随机状态?
2731 Random numbers::Why explicit random state?::NumPy's approach : NumPy 的方法
2832 Random numbers::Why explicit random state?::JAX's approach : JAX 的方法
29- JIT compilation : JIT 编译
30- JIT compilation::A simple example : 一个简单的示例
31- JIT compilation::A simple example::With NumPy : 使用 NumPy
32- JIT compilation::A simple example::With JAX : 使用 JAX
33- JIT compilation::A simple example::Changing array sizes : 更改数组大小
34- JIT compilation::Evaluating a more complicated function : 评估更复杂的函数
35- JIT compilation::Evaluating a more complicated function::With NumPy : 使用 NumPy
36- JIT compilation::Evaluating a more complicated function::With JAX : 使用 JAX
37- JIT compilation::How JIT compilation works : JIT 编译的工作原理
38- JIT compilation::Compiling the whole function : 编译整个函数
39- JIT compilation::Compiling non-pure functions : 编译非纯函数
40- JIT compilation::Summary : 总结
33+ JIT Compilation : JIT 编译
34+ JIT Compilation::With NumPy : 一个简单的示例
35+ JIT Compilation::With JAX : 评估更复杂的函数
36+ JIT Compilation::Compiling the Whole Function : JIT 编译的工作原理
37+ JIT Compilation::How JIT compilation works : 编译整个函数
38+ JIT Compilation::Compiling non-pure functions : 编译非纯函数
4139 Vectorization with `vmap` : 使用 `vmap` 进行向量化
4240 Vectorization with `vmap`::A simple example : 一个简单的示例
4341 Vectorization with `vmap`::Combining transformations : 组合变换
@@ -49,13 +47,16 @@ translation:
4947
5048本讲座简要介绍 [ Google JAX] ( https://github.com/jax-ml/jax ) 。
5149
50+ ``` {include} _admonition/gpu.md
51+ ```
52+
5253JAX 是一个高性能科学计算库,提供以下功能:
5354
5455* 类似 [ NumPy] ( https://en.wikipedia.org/wiki/NumPy ) 的接口,可以在 CPU 和 GPU 上自动并行化,
5556* 一个即时编译器,用于加速大量数值运算,以及
5657* [ 自动微分] ( https://en.wikipedia.org/wiki/Automatic_differentiation ) 。
5758
58- JAX 也在日益维护和提供[ 更多专业化的科学计算例程] ( https://docs.jax.dev/en/latest/jax.scipy.html ) ,例如那些最初在 [ SciPy] ( https://en.wikipedia.org/wiki/SciPy ) 中找到的例程。
59+ JAX 也在日益维护和提供 [ 更多专业化的科学计算例程] ( https://docs.jax.dev/en/latest/jax.scipy.html ) ,例如那些最初在 [ SciPy] ( https://en.wikipedia.org/wiki/SciPy ) 中找到的例程。
5960
6061除了 Anaconda 中已有的内容外,本讲座还需要以下库:
6162
@@ -65,36 +66,27 @@ JAX 也在日益维护和提供[更多专业化的科学计算例程](https://do
6566!pip install jax quantecon
6667```
6768
68- ``` {include} _admonition/gpu.md
69- ```
70-
71- ## JAX 作为 NumPy 的替代品
72-
73- JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。
74-
75- 这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。
76-
77- 让我们来看看 JAX 和 NumPy 之间的异同。
78-
79- ### 相似之处
80-
8169我们将使用以下导入:
8270
8371``` {code-cell} ipython3
8472import jax
8573import jax.numpy as jnp
8674import matplotlib.pyplot as plt
87- import matplotlib.patches as mpatches
8875import numpy as np
8976import quantecon as qe
90- import matplotlib as mpl # i18n
91- import matplotlib.font_manager # i18n
92- FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" # i18n
93- mpl.font_manager.fontManager.addfont(FONTPATH) # i18n
94- mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n
9577```
9678
97- 注意我们导入了 ` jax.numpy as jnp ` ,它提供了类似 NumPy 的接口。
79+ ## JAX 作为 NumPy 的替代品
80+
81+ 让我们来看看 JAX 和 NumPy 之间的异同。
82+
83+ ### 相似之处
84+
85+ 上面我们导入了 ` jax.numpy as jnp ` ,它提供了类似 NumPy 的数组操作接口。
86+
87+ JAX 的一个吸引人之处在于,这个接口在尽可能的情况下遵循 NumPy API。
88+
89+ 因此,我们通常可以将 JAX 作为 NumPy 的直接替代品使用。
9890
9991以下是使用 ` jnp ` 进行的一些标准数组操作:
10092
@@ -110,15 +102,11 @@ print(a)
110102print(jnp.sum(a))
111103```
112104
113- ``` {code-cell} ipython3
114- print(jnp.mean(a))
115- ```
116-
117105``` {code-cell} ipython3
118106print(jnp.dot(a, a))
119107```
120108
121- 然而 ,数组对象 ` a ` 并不是 NumPy 数组:
109+ 但需要注意的是 ,数组对象 ` a ` 并不是 NumPy 数组:
122110
123111``` {code-cell} ipython3
124112a
128116type(a)
129117```
130118
131- 即使是数组上的标量值映射也会返回 JAX 数组。
119+ 即使是数组上的标量值映射也会返回 JAX 数组而非标量!
132120
133121``` {code-cell} ipython3
134122jnp.sum(a)
135123```
136124
137- 对高维数组的操作也与 NumPy 类似:
125+ ### 差异
126+
127+ 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
128+
129+ (jax_speed)=
130+ #### 速度!
138131
139- ``` {code-cell} ipython3
140- A = jnp.ones((2, 2))
141- B = jnp.identity(2)
142- A @ B
132+ 一个主要差异是 JAX 更快——有时快得多。
133+
134+ 为了说明这一点,假设我们想在许多点处计算余弦函数。
135+
136+ ``` {code-cell}
137+ n = 50_000_000
138+ x = np.linspace(0, 10, n) # NumPy array
143139```
144140
145- JAX 的数组接口也提供了 ` linalg ` 子包:
141+ ##### 使用 NumPy
146142
147- ``` {code-cell} ipython3
148- jnp.linalg.inv(B) # Inverse of identity is identity
143+ 让我们用 NumPy 来试试。
144+
145+ ``` {code-cell}
146+ with qe.Timer():
147+ # First NumPy timing
148+ y = np.cos(x)
149149```
150150
151- ``` {code-cell} ipython3
152- jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
151+ 再来一次。
152+
153+ ``` {code-cell}
154+ with qe.Timer():
155+ # Second NumPy timing
156+ y = np.cos(x)
153157```
154158
155- ### 差异
159+ 这里
156160
157- 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。
161+ * NumPy 使用预编译的二进制文件将余弦函数应用于浮点数数组
162+ * 该二进制文件在本地机器的 CPU 上运行
163+
164+ ##### 使用 JAX
165+
166+ 现在让我们用 JAX 来试试。
167+
168+ ``` {code-cell}
169+ x = jnp.linspace(0, 10, n)
170+ ```
171+
172+ 让我们对同样的过程计时。
173+
174+ ``` {code-cell}
175+ with qe.Timer():
176+ # First run
177+ y = jnp.cos(x)
178+ # Hold the interpreter until the array operation finishes
179+ y.block_until_ready()
180+ ```
181+
182+ ``` {note}
183+ 上面,`block_until_ready` 方法会阻塞解释器,直到计算结果返回。
184+ 这对于计时执行是必要的,因为 JAX 使用异步调度,
185+ 允许 Python 解释器在数值计算之前运行。
186+ ```
187+
188+ 现在让我们再次计时。
189+
190+ ``` {code-cell}
191+ with qe.Timer():
192+ # Second run
193+ y = jnp.cos(x)
194+ # Hold interpreter
195+ y.block_until_ready()
196+ ```
197+
198+ 在 GPU 上,这段代码的运行速度远快于其 NumPy 等价代码。
199+
200+ 另外,通常情况下,由于 JIT 编译,第二次运行比第一次更快。
201+
202+ 这是因为即使是像 ` jnp.cos ` 这样的内置函数也会被 JIT 编译——第一次运行包含了编译时间。
203+
204+ 为什么 JAX 要对 ` jnp.cos ` 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本呢?
205+
206+ 原因是 JIT 编译器希望针对所使用数组的* 大小* (以及数据类型)进行专门优化。
207+
208+ 大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件匹配。
209+
210+ #### 大小实验
211+
212+ 我们可以通过改变输入大小并观察运行时间来验证 JAX 针对数组大小进行专门优化的说法。
213+
214+ ``` {code-cell}
215+ x = jnp.linspace(0, 10, n + 1)
216+ ```
217+
218+ ``` {code-cell}
219+ with qe.Timer():
220+ # First run
221+ y = jnp.cos(x)
222+ # Hold interpreter
223+ y.block_until_ready()
224+ ```
225+
226+
227+ ``` {code-cell}
228+ with qe.Timer():
229+ # Second run
230+ y = jnp.cos(x)
231+ # Hold interpreter
232+ y.block_until_ready()
233+ ```
234+
235+ 运行时间先增加后减少(这在 GPU 上会更明显)。
236+
237+ 这与上面的讨论一致——改变数组大小后的第一次运行显示了编译开销。
238+
239+ 下面将进一步讨论 JIT 编译。
158240
159241#### 精度
160242
161- NumPy 和 JAX 之间的一个差异是 JAX 默认使用 32 位浮点数。
243+ NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。
162244
163245这是因为 JAX 经常用于 GPU 计算,而大多数 GPU 计算使用 32 位浮点数。
164246
@@ -196,40 +278,34 @@ a[0] = 1
196278a
197279```
198280
199- 在 JAX 中,这会失败:
281+ 在 JAX 中,这会失败 😱。
282+
200283
201284``` {code-cell} ipython3
202285a = jnp.linspace(0, 1, 3)
203286a
204287```
205288
206289``` {code-cell} ipython3
207- :tags: [raises-exception]
290+ try:
291+ a[0] = 1
292+ except Exception as e:
293+ print(e)
208294
209- a[0] = 1
210295```
211296
212- 与不可变性一致,JAX 不支持原地操作:
213-
214- ``` {code-cell} ipython3
215- a = np.array((2, 1))
216- a.sort() # Unlike NumPy, does not mutate a
217- a
218- ```
297+ JAX 的设计者选择将数组设为不可变的,因为
219298
220- ``` {code-cell} ipython3
221- a = jnp.array((2, 1))
222- a_new = a.sort() # Instead, the sort method returns a new sorted array
223- a, a_new
224- ```
299+ 1 . JAX 使用* 函数式编程风格* ,并且
300+ 2 . 函数式编程通常避免可变数据
225301
226- JAX 的设计者选择将数组设为不可变的,因为 JAX 使用 [ 函数式编程 ] ( https://en.wikipedia.org/wiki/Functional_programming ) 风格 。
302+ 我们将在 {ref} ` 下面 <jax_func> ` 讨论这些思想 。
227303
228- 这个设计选择有重要的含义,我们接下来将对此进行探讨!
229304
305+ (jax_at_workaround)=
230306#### 变通方法
231307
232- 我们注意到 JAX 确实提供了一种使用 [ ` at ` 方法] ( https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html ) 进行原地数组修改的版本 。
308+ JAX 确实通过 [ ` at ` 方法] ( https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html ) 提供了原地数组修改的直接替代方案 。
233309
234310``` {code-cell} ipython3
235311a = jnp.linspace(0, 1, 3)
251327
252328(尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。)
253329
330+
331+ (jax_func)=
254332## 函数式编程
255333
256334来自 JAX 的文档:
0 commit comments