@@ -17,17 +17,15 @@ translation:
1717 Vectorized operations::Parallelized Numba : 并行化的 Numba
1818 Vectorized operations::Vectorized code with JAX : 使用 JAX 的向量化代码
1919 Vectorized operations::JAX plus vmap : JAX 加 vmap
20- Vectorized operations::JAX plus vmap::Version 1 : 版本 1
21- Vectorized operations::vmap version 2 : vmap 版本 2
22- Vectorized operations::Summary : 总结
20+ Vectorized operations::Summary : vmap 版本 2
2321 Sequential operations : 顺序运算
2422 Sequential operations::Numba Version : Numba 版本
2523 Sequential operations::JAX Version : JAX 版本
2624 Sequential operations::Summary : 总结
2725 Overall recommendations : 总体建议
2826---
2927
30- (parallel )=
28+ (numpy_numba_jax )=
3129``` {raw} jupyter
3230<div id="qe-notebook-header" align="right" style="text-align:right;">
3331 <a href="https://quantecon.org/" title="quantecon.org">
@@ -69,7 +67,6 @@ tags: [hide-output]
6967我们将使用以下导入。
7068
7169``` {code-cell} ipython3
72- import random
7370from functools import partial
7471
7572import numpy as np
@@ -472,14 +469,16 @@ def qm(x0, n, α=4.0):
472469``` {code-cell} ipython3
473470n = 10_000_000
474471
475- with qe.Timer(precision=8):
472+ with qe.Timer():
473+ # First run
476474 x = qm(0.1, n)
477475```
478476
479477让我们再次运行以消除编译时间:
480478
481479``` {code-cell} ipython3
482- with qe.Timer(precision=8):
480+ with qe.Timer():
481+ # Second run
483482 x = qm(0.1, n)
484483```
485484
@@ -491,15 +490,62 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代
491490
492491### JAX 版本
493492
494- 现在让我们使用 ` lax.scan ` 创建一个 JAX 版本:
493+ 现在让我们使用 ` at[t].set ` 风格的语法创建一个 JAX 版本,正如 {ref} ` JAX 讲座中讨论的 <jax_at_workaround> ` ,这为不可变数组提供了一种变通方法。
495494
496- (我们将 ` n ` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
495+ 我们将使用 ` lax.fori_loop ` ,它是一种可以被 XLA 编译的 for 循环版本。
497496
498497``` {code-cell} ipython3
499498cpu = jax.devices("cpu")[0]
500499
501- @partial(jax.jit, static_argnums=(1,), device=cpu)
502- def qm_jax(x0, n, α=4.0):
500+ @partial(jax.jit, static_argnames=("n",), device=cpu)
501+ def qm_jax_fori(x0, n, α=4.0):
502+
503+ x = jnp.empty(n + 1).at[0].set(x0)
504+
505+ def update(t, x):
506+ return x.at[t + 1].set(α * x[t] * (1 - x[t]))
507+
508+ x = lax.fori_loop(0, n, update, x)
509+ return x
510+
511+ ```
512+
513+ * 我们将 ` n ` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。
514+ * 我们通过 ` device=cpu ` 将计算固定在 CPU 上,因为这种顺序工作负载由许多小操作组成,几乎没有机会利用 GPU 并行性。
515+
516+ 虽然 ` at[t].set ` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新。
517+
518+ 让我们使用相同的参数计时:
519+
520+ ``` {code-cell} ipython3
521+ with qe.Timer():
522+ # First run
523+ x_jax = qm_jax_fori(0.1, n)
524+ # Hold interpreter
525+ x_jax.block_until_ready()
526+ ```
527+
528+ 让我们再次运行以消除编译开销:
529+
530+ ``` {code-cell} ipython3
531+ with qe.Timer():
532+ # Second run
533+ x_jax = qm_jax_fori(0.1, n)
534+ # Hold interpreter
535+ x_jax.block_until_ready()
536+ ```
537+
538+ JAX 对于这种顺序运算也相当高效。
539+
540+
541+ 我们还有另一种实现循环的方式,使用 ` lax.scan ` 。
542+
543+ 这种替代方案可以说更符合 JAX 的函数式方法——尽管语法难以记忆。
544+
545+
546+ ``` {code-cell} ipython3
547+ @partial(jax.jit, static_argnames=("n",), device=cpu)
548+ def qm_jax_scan(x0, n, α=4.0):
503549 def update(x, t):
504550 x_new = α * x * (1 - x)
505551 return x_new, x_new
@@ -510,33 +556,27 @@ def qm_jax(x0, n, α=4.0):
510556
511557这段代码不易阅读,但本质上,` lax.scan ` 反复调用 ` update ` 并将返回值 ` x_new ` 累积到一个数组中。
512558
513- ``` {note}
514- 细心的读者会注意到,我们在 `jax.jit` 装饰器中指定了 `device=cpu`。
515-
516- 该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。
517-
518- 因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。
519-
520- 好奇的读者可以尝试删除此选项,看看性能如何变化。
521- ```
522-
523559让我们使用相同的参数计时:
524560
525561``` {code-cell} ipython3
526- with qe.Timer(precision=8):
527- x_jax = qm_jax(0.1, n).block_until_ready()
562+ with qe.Timer():
563+ # First run
564+ x_jax = qm_jax_scan(0.1, n)
565+ # Hold interpreter
566+ x_jax.block_until_ready()
528567```
529568
530569让我们再次运行以消除编译开销:
531570
532571``` {code-cell} ipython3
533- with qe.Timer(precision=8):
534- x_jax = qm_jax(0.1, n).block_until_ready()
572+ with qe.Timer():
573+ # Second run
574+ x_jax = qm_jax_scan(0.1, n)
575+ # Hold interpreter
576+ x_jax.block_until_ready()
535577```
536578
537- JAX 对于这种顺序运算也相当高效。
538-
539- JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。
579+ JAX 和 Numba 在编译后都能提供出色的性能。
540580
541581### 总结
542582
@@ -546,11 +586,11 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
546586
547587这正是大多数程序员思考该算法的方式。
548588
549- 另一方面,JAX 版本需要使用 ` lax.scan ` ,这明显不够直观 。
589+ 另一方面,JAX 版本需要使用 ` lax.fori_loop ` 或 ` lax. scan` ,两者都比标准 Python 循环更不直观 。
550590
551- 此外, JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难 。
591+ 虽然 JAX 的 ` at[t].set ` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读 。
552592
553- 对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面 ,Numba 是明显的赢家。
593+ 对于这类顺序运算,在代码清晰度和实现便利性方面 ,Numba 是明显的赢家。
554594
555595## 总体建议
556596
@@ -568,11 +608,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
568608
569609代码自然易读——只需一个带装饰器的 Python 循环——且性能出色。
570610
571- JAX 可以通过 ` lax.scan ` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限。
572-
573- 话虽如此,` lax.scan ` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。
611+ JAX 可以通过 ` lax.fori_loop ` 或 ` lax.scan ` 处理顺序问题,但语法不够直观。
574612
613+ ``` {note}
614+ `lax.fori_loop` 和 `lax.scan` 有一个重要优势:它们支持对循环进行自动微分,而 Numba 无法做到这一点。
575615如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
616+ ```
576617
577618在实践中,许多问题往往同时涉及两种模式。
578619
0 commit comments