Skip to content

Commit 22a145b

Browse files
authored
🌐 [translation-sync] Add lax.fori_loop example and improve jax_intro (#107)
* Update translation: lectures/jax_intro.md * Update translation: .translate/state/jax_intro.md.yml * Update translation: lectures/numpy_vs_numba_vs_jax.md * Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml
1 parent 0fe603e commit 22a145b

File tree

4 files changed

+96
-47
lines changed

4 files changed

+96
-47
lines changed

.translate/state/jax_intro.md.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
1+
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
22
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE

.translate/state/numpy_vs_numba_vs_jax.md.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
1+
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
22
synced-at: "2026-04-13"
33
model: claude-sonnet-4-6
44
mode: UPDATE

lectures/jax_intro.md

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ translation:
1818
JAX as a NumPy Replacement::Differences::Speed!: سرعت!
1919
JAX as a NumPy Replacement::Differences::Speed!::With NumPy: با NumPy
2020
JAX as a NumPy Replacement::Differences::Speed!::With JAX: با JAX
21+
JAX as a NumPy Replacement::Differences::Size Experiment: آزمایش اندازه
2122
JAX as a NumPy Replacement::Differences::Precision: دقت
2223
JAX as a NumPy Replacement::Differences::Immutability: تغییرناپذیری
23-
JAX as a NumPy Replacement::Differences::A workaround: راه‌حل جایگزین
24+
JAX as a NumPy Replacement::Differences::A Workaround: راه‌حل جایگزین
2425
Functional Programming: برنامه‌نویسی تابعی
2526
Functional Programming::Pure functions: توابع خالص
2627
Functional Programming::Examples: مثال‌ها
@@ -76,18 +77,18 @@ import numpy as np
7677
import quantecon as qe
7778
```
7879

79-
توجه کنید که `jax.numpy as jnp` را import می‌کنیم که یک رابط شبیه NumPy فراهم می‌کند.
80-
8180
## JAX به عنوان جایگزین NumPy
8281

83-
یکی از ویژگی‌های جذاب JAX این است که، هر زمان که امکان‌پذیر باشد، عملیات پردازش آرایه‌های آن با API NumPy مطابقت دارد.
84-
85-
این بدان معناست که در بسیاری از موارد، می‌توانیم از JAX به عنوان جایگزین مستقیم NumPy استفاده کنیم.
86-
8782
بیایید به شباهت‌ها و تفاوت‌های بین JAX و NumPy نگاه کنیم.
8883

8984
### شباهت‌ها
9085

86+
در بالا `jax.numpy as jnp` را وارد کردیم که یک رابط شبیه به NumPy برای عملیات آرایه فراهم می‌کند.
87+
88+
یکی از ویژگی‌های جذاب JAX این است که، هر زمان که امکان‌پذیر باشد، این رابط با API NumPy مطابقت دارد.
89+
90+
در نتیجه، اغلب می‌توانیم از JAX به عنوان جایگزین مستقیم NumPy استفاده کنیم.
91+
9192
در اینجا برخی عملیات استاندارد آرایه با استفاده از `jnp` آمده است:
9293

9394
```{code-cell} ipython3
@@ -106,7 +107,7 @@ print(jnp.sum(a))
106107
print(jnp.dot(a, a))
107108
```
108109

109-
با این حال، شیء آرایه `a` یک آرایه NumPy نیست:
110+
با این حال، باید به خاطر داشت که شیء آرایه `a` یک آرایه NumPy نیست:
110111

111112
```{code-cell} ipython3
112113
a
@@ -129,11 +130,13 @@ jnp.sum(a)
129130
(jax_speed)=
130131
#### سرعت!
131132

132-
فرض کنیم می‌خواهیم تابع کسینوس را در نقاط بسیاری ارزیابی کنیم.
133+
یکی از تفاوت‌های عمده این است که JAX سریع‌تر است --- و گاهی بسیار سریع‌تر.
134+
135+
برای نشان دادن این موضوع، فرض کنیم می‌خواهیم تابع کسینوس را در نقاط بسیاری ارزیابی کنیم.
133136

134137
```{code-cell}
135138
n = 50_000_000
136-
x = np.linspace(0, 10, n)
139+
x = np.linspace(0, 10, n) # NumPy array
137140
```
138141

139142
##### با NumPy
@@ -174,27 +177,23 @@ with qe.Timer():
174177
# First run
175178
y = jnp.cos(x)
176179
# Hold the interpreter until the array operation finishes
177-
jax.block_until_ready(y);
180+
y.block_until_ready()
178181
```
179182

180183
```{note}
181-
در اینجا، برای اندازه‌گیری سرعت واقعی، از متد `block_until_ready` استفاده می‌کنیم
182-
تا مفسر را تا زمانی که نتایج محاسبات بازگردانده شوند نگه داریم.
183-
184-
این ضروری است زیرا JAX از ارسال ناهمزمان استفاده می‌کند که
184+
در بالا، متد `block_until_ready` مفسر را تا زمانی که نتایج محاسبات بازگردانده شوند نگه می‌دارد.
185+
این برای زمان‌بندی اجرا ضروری است زیرا JAX از ارسال ناهمزمان استفاده می‌کند که
185186
به مفسر Python اجازه می‌دهد جلوتر از محاسبات عددی حرکت کند.
186-
187-
برای کدهایی که زمان‌بندی نمی‌شوند، می‌توانید خط حاوی `block_until_ready` را حذف کنید.
188187
```
189188

190-
و بیایید دوباره زمان‌بندی کنیم.
189+
اکنون بیایید دوباره زمان‌بندی کنیم.
191190

192191
```{code-cell}
193192
with qe.Timer():
194193
# Second run
195194
y = jnp.cos(x)
196195
# Hold interpreter
197-
jax.block_until_ready(y);
196+
y.block_until_ready()
198197
```
199198

200199
روی GPU، این کد بسیار سریع‌تر از معادل NumPy خود اجرا می‌شود.
@@ -209,6 +208,8 @@ with qe.Timer():
209208

210209
اندازه برای تولید کد بهینه اهمیت دارد زیرا موازی‌سازی کارآمد نیازمند تطابق اندازه کار با سخت‌افزار موجود است.
211210

211+
#### آزمایش اندازه
212+
212213
می‌توانیم ادعا که JAX بر اندازه آرایه تخصص پیدا می‌کند را با تغییر اندازه ورودی و مشاهده زمان‌های اجرا تأیید کنیم.
213214

214215
```{code-cell}
@@ -220,15 +221,15 @@ with qe.Timer():
220221
# First run
221222
y = jnp.cos(x)
222223
# Hold interpreter
223-
jax.block_until_ready(y);
224+
y.block_until_ready()
224225
```
225226

226227
```{code-cell}
227228
with qe.Timer():
228229
# Second run
229230
y = jnp.cos(x)
230231
# Hold interpreter
231-
jax.block_until_ready(y);
232+
y.block_until_ready()
232233
```
233234

234235
زمان اجرا افزایش می‌یابد و سپس دوباره کاهش می‌یابد (این روی GPU واضح‌تر خواهد بود).
@@ -277,7 +278,7 @@ a[0] = 1
277278
a
278279
```
279280

280-
در JAX این کار شکست می‌خورد!
281+
در JAX این کار شکست می‌خورد 😱.
281282

282283
```{code-cell} ipython3
283284
a = jnp.linspace(0, 1, 3)
@@ -292,11 +293,18 @@ except Exception as e:
292293
293294
```
294295

295-
طراحان JAX تصمیم گرفتند آرایه‌ها را تغییرناپذیر کنند زیرا JAX از سبک برنامه‌نویسی تابعی استفاده می‌کند که در ادامه آن را بررسی می‌کنیم.
296+
طراحان JAX تصمیم گرفتند آرایه‌ها را تغییرناپذیر کنند زیرا
296297

298+
1. JAX از *سبک برنامه‌نویسی تابعی* استفاده می‌کند و
299+
2. برنامه‌نویسی تابعی معمولاً از داده‌های قابل تغییر اجتناب می‌کند
300+
301+
این ایده‌ها را {ref}`در ادامه <jax_func>` بررسی می‌کنیم.
302+
303+
304+
(jax_at_workaround)=
297305
#### راه‌حل جایگزین
298306

299-
توجه می‌کنیم که JAX یک جایگزین برای تغییر درجای آرایه با استفاده از [متد `at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) فراهم می‌کند.
307+
JAX یک جایگزین مستقیم برای تغییر درجای آرایه از طریق [متد `at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) فراهم می‌کند.
300308

301309
```{code-cell} ipython3
302310
a = jnp.linspace(0, 1, 3)
@@ -318,6 +326,8 @@ a
318326

319327
(اگرچه در واقع می‌تواند داخل توابع کامپایل‌شده JIT کارآمد باشد -- اما بیایید این را فعلاً کنار بگذاریم.)
320328

329+
330+
(jax_func)=
321331
## برنامه‌نویسی تابعی
322332

323333
از مستندات JAX:

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,23 @@ translation:
3636

3737
# NumPy در مقابل Numba در مقابل JAX
3838

39-
در سخنرانی‌های قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
39+
در درس‌های قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
4040

4141
* [NumPy](numpy)
4242
* [Numba](numba)
4343
* [JAX](jax_intro)
4444

4545
کدام یک را باید در هر موقعیت استفاده کنیم؟
4646

47-
این سخنرانی به آن سؤال پاسخ می‌دهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
47+
این درس به آن سؤال پاسخ می‌دهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
4848

4949
قبل از شروع، توجه می‌کنیم که دو مورد اول یک جفت طبیعی هستند: NumPy و Numba به خوبی با هم کار می‌کنند.
5050

5151
JAX، از سوی دیگر، به تنهایی می‌ایستد.
5252

5353
هنگام بررسی هر رویکرد، نه تنها کارایی و رد پای حافظه، بلکه وضوح و سهولت استفاده را نیز در نظر خواهیم گرفت.
5454

55-
علاوه بر آنچه در Anaconda موجود است، این سخنرانی به کتابخانه‌های زیر نیاز دارد:
55+
علاوه بر آنچه در Anaconda موجود است، این درس به کتابخانه‌های زیر نیاز دارد:
5656

5757
```{code-cell} ipython3
5858
---
@@ -67,7 +67,6 @@ tags: [hide-output]
6767
ما از import های زیر استفاده خواهیم کرد.
6868

6969
```{code-cell} ipython3
70-
import random
7170
from functools import partial
7271
7372
import numpy as np
@@ -455,15 +454,60 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
455454

456455
### نسخه JAX
457456

458-
حالا بیایید یک نسخه JAX با استفاده از `lax.scan` ایجاد کنیم:
457+
حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همان‌طور که {ref}`در درس JAX بحث شد <jax_at_workaround>`، راه‌حلی برای آرایه‌های تغییرناپذیر فراهم می‌کند.
459458

460-
(ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.)
459+
ما از `lax.fori_loop` استفاده می‌کنیم که نسخه‌ای از حلقه for است که می‌تواند توسط XLA کامپایل شود.
461460

462461
```{code-cell} ipython3
463462
cpu = jax.devices("cpu")[0]
464463
465-
@partial(jax.jit, static_argnames=('n',), device=cpu)
466-
def qm_jax(x0, n, α=4.0):
464+
@partial(jax.jit, static_argnames=("n",), device=cpu)
465+
def qm_jax_fori(x0, n, α=4.0):
466+
467+
x = jnp.empty(n + 1).at[0].set(x0)
468+
469+
def update(t, x):
470+
return x.at[t + 1].set(α * x[t] * (1 - x[t]))
471+
472+
x = lax.fori_loop(0, n, update, x)
473+
return x
474+
475+
```
476+
477+
* ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
478+
* ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد.
479+
480+
اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد.
481+
482+
بیایید آن را با همان پارامترها زمان‌بندی کنیم:
483+
484+
```{code-cell} ipython3
485+
with qe.Timer():
486+
# First run
487+
x_jax = qm_jax_fori(0.1, n)
488+
# Hold interpreter
489+
x_jax.block_until_ready()
490+
```
491+
492+
بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود:
493+
494+
```{code-cell} ipython3
495+
with qe.Timer():
496+
# Second run
497+
x_jax = qm_jax_fori(0.1, n)
498+
# Hold interpreter
499+
x_jax.block_until_ready()
500+
```
501+
502+
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
503+
504+
روش دیگری برای پیاده‌سازی حلقه وجود دارد که از `lax.scan` استفاده می‌کند.
505+
506+
این روش جایگزین، به طور قابل بحث، بیشتر با رویکرد تابعی JAX همسو است --- اگرچه سینتکس آن به خاطر سپردن دشواری دارد.
507+
508+
```{code-cell} ipython3
509+
@partial(jax.jit, static_argnames=("n",), device=cpu)
510+
def qm_jax_scan(x0, n, α=4.0):
467511
def update(x, t):
468512
x_new = α * x * (1 - x)
469513
return x_new, x_new
@@ -474,16 +518,12 @@ def qm_jax(x0, n, α=4.0):
474518

475519
این کد خواندن آسانی ندارد اما، در اصل، `lax.scan` به طور مکرر `update` را فراخوانی می‌کند و بازگشت‌های `x_new` را در یک آرایه جمع می‌کند.
476520

477-
```{note}
478-
ما `device=cpu` را در decorator `jax.jit` مشخص می‌کنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهره‌برداری GPU از موازی‌سازی باقی می‌گذارد. در نتیجه، سربار راه‌اندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسب‌تر برای این بار کاری می‌کند.
479-
```
480-
481521
بیایید آن را با همان پارامترها زمان‌بندی کنیم:
482522

483523
```{code-cell} ipython3
484524
with qe.Timer():
485525
# First run
486-
x_jax = qm_jax(0.1, n)
526+
x_jax = qm_jax_scan(0.1, n)
487527
# Hold interpreter
488528
x_jax.block_until_ready()
489529
```
@@ -493,13 +533,11 @@ with qe.Timer():
493533
```{code-cell} ipython3
494534
with qe.Timer():
495535
# Second run
496-
x_jax = qm_jax(0.1, n)
536+
x_jax = qm_jax_scan(0.1, n)
497537
# Hold interpreter
498538
x_jax.block_until_ready()
499539
```
500540

501-
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
502-
503541
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند.
504542

505543
### خلاصه
@@ -510,9 +548,9 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
510548

511549
این دقیقاً نحوه تفکر اکثر برنامه‌نویسان در مورد الگوریتم است.
512550

513-
نسخه JAX، از سوی دیگر، نیاز به استفاده از `lax.scan` دارد که به طور قابل توجهی کمتر شهودی است.
551+
نسخه‌های JAX، از سوی دیگر، نیاز به استفاده از `lax.fori_loop` یا `lax.scan` دارند که هر دو کمتر شهودی از یک حلقه استاندارد Python هستند.
514552

515-
علاوه بر این، آرایه‌های تغییرناپذیر JAX به این معنی است که نمی‌توانیم به سادگی عناصر آرایه را در جا به‌روزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت می‌کند.
553+
در حالی که سینتکس `at[t].set` در JAX به‌روزرسانی عنصر به عنصر را ممکن می‌سازد، کد کلی همچنان سخت‌تر از معادل Numba برای خواندن است.
516554

517555
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است.
518556

@@ -532,11 +570,12 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
532570

533571
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
534572

535-
JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهره‌وری اضافی ناچیز است.
536-
537-
با این حال، `lax.scan` یک مزیت مهم دارد: از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست.
573+
JAX می‌تواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است.
538574

575+
```{note}
576+
یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کنند، که Numba قادر به انجام آن نیست.
539577
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است.
578+
```
540579

541580
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
542581

0 commit comments

Comments
 (0)