Skip to content

Commit 0d03a4f

Browse files
authored
🌐 [translation-sync] Improve NumPy vs Numba vs JAX lecture (#96)
* Update translation: lectures/numpy_vs_numba_vs_jax.md * Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml
1 parent 6da33da commit 0d03a4f

File tree

2 files changed

+53
-46
lines changed

2 files changed

+53
-46
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d
2-
synced-at: "2026-04-08"
1+
source-sha: 94dd7d22385ec46d740db1fc2cddf05c29377594
2+
synced-at: "2026-04-12"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 3
6-
tool-version: 0.13.1
6+
tool-version: 0.14.1

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ translation:
1717
Vectorized operations::Parallelized Numba: Numba موازی شده
1818
Vectorized operations::Vectorized code with JAX: کد برداری شده با JAX
1919
Vectorized operations::JAX plus vmap: JAX به علاوه vmap
20-
Vectorized operations::JAX plus vmap::Version 1: نسخه 1
21-
Vectorized operations::vmap version 2: نسخه 2 vmap
2220
Vectorized operations::Summary: خلاصه
2321
Sequential operations: عملیات ترتیبی
2422
Sequential operations::Numba Version: نسخه Numba
@@ -27,7 +25,7 @@ translation:
2725
Overall recommendations: توصیه‌های کلی
2826
---
2927

30-
(parallel)=
28+
(numpy_numba_jax)=
3129
```{raw} jupyter
3230
<div id="qe-notebook-header" align="right" style="text-align:right;">
3331
<a href="https://quantecon.org/" title="quantecon.org">
@@ -150,7 +148,7 @@ for x in grid:
150148

151149
در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند.
152150

153-
(این استراتژی به Matlab بازمی‌گردد.)
151+
(این استراتژی به MATLAB بازمی‌گردد.)
154152

155153
```{code-cell} ipython3
156154
grid = np.linspace(-3, 3, 3_000)
@@ -226,24 +224,44 @@ def compute_max_numba_parallel(grid):
226224
227225
```
228226

229-
معمولاً این نتیجه نادرستی برمی‌گرداند:
227+
این `-inf` برمی‌گرداند --- مقدار اولیه `m`، انگار که هرگز به‌روزرسانی نشده است:
230228

231229
```{code-cell} ipython3
232230
z_max_parallel_incorrect = compute_max_numba_parallel(grid)
233231
print(f"Numba result: {z_max_parallel_incorrect} 😱")
234232
```
235233

236-
دلیل این است که متغیر `m` بین نخ‌ها مشترک است و به درستی کنترل نمی‌شود.
234+
برای درک چرایی این موضوع، به یاد بیاورید که `prange` حلقه بیرونی را بین نخ‌ها تقسیم می‌کند.
237235

238-
وقتی چندین نخ سعی می‌کنند همزمان `m` را بخوانند و بنویسند، با یکدیگر تداخل می‌کنند.
236+
هر نخ یک نسخه خصوصی از `m` دارد که با مقدار `-np.inf` مقداردهی اولیه شده و آن را در بازه تکرارهای خود به درستی به‌روزرسانی می‌کند.
239237

240-
نخ‌ها مقادیر قدیمی `m` را می‌خوانند یا به‌روزرسانی‌های یکدیگر را بازنویسی می‌کنند --- یا `m` هرگز از مقدار اولیه خود به‌روزرسانی نمی‌شود.
238+
اما در پایان حلقه، Numba باید نسخه‌های هر نخ از `m` را در یک مقدار واحد ترکیب کند --- یک **تقلیل (reduction)**.
241239

242-
در اینجا یک نسخه با دقت بیشتری نوشته شده است.
240+
برای الگوهایی که تشخیص می‌دهد، مانند `m += z` (جمع) یا `m = max(m, z)` (max)، Numba عملگر ترکیب را می‌شناسد.
241+
242+
اما الگوی `if z > m: m = z` را به عنوان یک تقلیل max تشخیص نمی‌دهد، بنابراین نتایج هر نخ هرگز ترکیب نمی‌شوند و `m` مقدار اولیه خود را حفظ می‌کند.
243+
244+
ساده‌ترین راه‌حل جایگزینی شرط با `max` است که Numba آن را می‌شناسد:
243245

244246
```{code-cell} ipython3
245247
@numba.jit(parallel=True)
246248
def compute_max_numba_parallel(grid):
249+
n = len(grid)
250+
m = -np.inf
251+
for i in numba.prange(n):
252+
for j in range(n):
253+
x = grid[i]
254+
y = grid[j]
255+
z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
256+
m = max(m, z)
257+
return m
258+
```
259+
260+
یک روش جایگزین این است که بدنه حلقه را بین `i` ها کاملاً مستقل کنیم و تقلیل را خودمان انجام دهیم:
261+
262+
```{code-cell} ipython3
263+
@numba.jit(parallel=True)
264+
def compute_max_numba_parallel_v2(grid):
247265
n = len(grid)
248266
row_maxes = np.empty(n)
249267
for i in numba.prange(n):
@@ -258,9 +276,7 @@ def compute_max_numba_parallel(grid):
258276
return np.max(row_maxes)
259277
```
260278

261-
اکنون بلوک کدی که `for i in numba.prange(n)` روی آن عمل می‌کند بین `i` ها مستقل است.
262-
263-
هر نخ به یک عنصر جداگانه از آرایه `row_maxes` می‌نویسد و موازی‌سازی ایمن است.
279+
در اینجا هر نخ به یک عنصر جداگانه از `row_maxes` می‌نویسد، بنابراین تقلیل را خودمان از طریق `np.max` انجام می‌دهیم.
264280

265281
```{code-cell} ipython3
266282
z_max_parallel = compute_max_numba_parallel(grid)
@@ -321,7 +337,7 @@ with qe.Timer(precision=8):
321337

322338
### JAX به علاوه vmap
323339

324-
یک مشکل با کد NumPy و کد JAX وجود دارد:
340+
یک مشکل با کد NumPy و کد JAX فوق وجود دارد:
325341

326342
در حالی که آرایه‌های تخت حافظه کمی دارند
327343

@@ -339,9 +355,9 @@ x_mesh.nbytes + y_mesh.nbytes
339355

340356
خوشبختانه، JAX رویکرد متفاوتی را با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) می‌پذیرد.
341357

342-
#### نسخه 1
358+
ایده `vmap` این است که برداری‌سازی را به مراحل تقسیم کند و تابعی که روی مقادیر تکی عمل می‌کند را به تابعی تبدیل کند که روی آرایه‌ها عمل می‌کند.
343359

344-
در اینجا یک راه برای اعمال `vmap` آمده است.
360+
در اینجا نحوه اعمال آن به مسئله ما آمده است.
345361

346362
```{code-cell} ipython3
347363
# f را تنظیم کنید تا f(x, y) را در هر x برای هر y داده شده محاسبه کند
@@ -368,31 +384,19 @@ with qe.Timer(precision=8):
368384
z_max.block_until_ready()
369385
```
370386

371-
با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری استفاده می‌کند.
372-
373-
وقتی روی CPU اجرا می‌شود، زمان اجرای آن شبیه به نسخه meshgrid است.
374-
375-
وقتی روی GPU اجرا می‌شود، معمولاً به طور قابل توجهی سریعتر است.
376-
377-
در واقع، استفاده از `vmap` مزیت دیگری دارد: به ما اجازه می‌دهد برداری‌سازی را به مراحل تقسیم کنیم.
378-
379-
این منجر به کدی می‌شود که اغلب راحت‌تر از کد برداری شده سنتی قابل درک است.
380-
381-
ما این ایده‌ها را بیشتر هنگام حل مسائل بزرگتر بررسی خواهیم کرد.
387+
با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری با زمان اجرای مشابه استفاده می‌کند.
382388

383-
### نسخه 2 vmap
389+
اما هنوز برخی بهره‌های سرعت را از دست می‌دهیم.
384390

385-
می‌توانیم با استفاده از vmap همچنان کارآمدتر از نظر حافظه باشیم.
391+
کد فوق آرایه دوبعدی کامل `f(x,y)` را محاسبه می‌کند و سپس max را می‌گیرد.
386392

387-
در حالی که در نسخه قبلی از آرایه‌های ورودی بزرگ اجتناب می‌کنیم، هنوز آرایه خروجی بزرگ `f(x,y)` را قبل از محاسبه حداکثر ایجاد می‌کنیم.
393+
علاوه بر این، فراخوانی `jnp.max` خارج از تابع JIT-کامپایل شده `f` قرار دارد، بنابراین کامپایلر نمی‌تواند این عملیات را در یک kernel واحد ادغام کند.
388394

389-
بیایید یک رویکرد کمی متفاوت را امتحان کنیم که max را به داخل می‌برد.
390-
391-
به دلیل این تغییر، ما هرگز آرایه دوبعدی `f(x,y)` را محاسبه نمی‌کنیم.
395+
می‌توانیم هر دو مشکل را با انتقال max به داخل و پوشاندن همه چیز در یک `@jax.jit` واحد برطرف کنیم:
392396

393397
```{code-cell} ipython3
394398
@jax.jit
395-
def compute_max_vmap_v2(grid):
399+
def compute_max_vmap(grid):
396400
# یک تابع بسازید که حداکثر را در امتداد هر سطر بگیرد
397401
f_vec_x_max = lambda y: jnp.max(f(grid, y))
398402
# تابع را برداری کنید تا بتوانیم روی تمام سطرها همزمان فراخوانی کنیم
@@ -408,24 +412,26 @@ def compute_max_vmap_v2(grid):
408412

409413
ما این تابع را روی تمام سطرها اعمال می‌کنیم و سپس حداکثر max های سطر را می‌گیریم.
410414

415+
چون max را به داخل منتقل می‌کنیم، هرگز آرایه دوبعدی کامل `f(x,y)` را نمی‌سازیم و حافظه بیشتری صرفه‌جویی می‌شود.
416+
417+
و چون همه چیز زیر یک `@jax.jit` واحد قرار دارد، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند.
418+
411419
بیایید آن را امتحان کنیم.
412420

413421
```{code-cell} ipython3
414422
with qe.Timer(precision=8):
415-
z_max = compute_max_vmap_v2(grid).block_until_ready()
423+
z_max = compute_max_vmap(grid).block_until_ready()
416424
417-
print(f"JAX vmap v2 result: {z_max:.6f}")
425+
print(f"JAX vmap result: {z_max:.6f}")
418426
```
419427

420428
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود:
421429

422430
```{code-cell} ipython3
423431
with qe.Timer(precision=8):
424-
z_max = compute_max_vmap_v2(grid).block_until_ready()
432+
z_max = compute_max_vmap(grid).block_until_ready()
425433
```
426434

427-
اگر این را روی GPU اجرا می‌کنید، همانطور که ما این کار را می‌کنیم، باید افزایش سرعت قابل توجه دیگری را ببینید.
428-
429435
### خلاصه
430436

431437
به نظر ما، JAX برنده برای عملیات برداری شده است.
@@ -531,7 +537,7 @@ with qe.Timer(precision=8):
531537

532538
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
533539

534-
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند، با این که Numba معمولاً (اما نه همیشه) سرعت‌های کمی بهتری در عملیات کاملاً ترتیبی ارائه می‌دهد.
540+
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند.
535541

536542
### خلاصه
537543

@@ -545,7 +551,7 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
545551

546552
علاوه بر این، آرایه‌های تغییرناپذیر JAX به این معنی است که نمی‌توانیم به سادگی عناصر آرایه را در جا به‌روزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت می‌کند.
547553

548-
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی، و همچنین عملکرد بالا است.
554+
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است.
549555

550556
## توصیه‌های کلی
551557

@@ -563,11 +569,12 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
563569

564570
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
565571

566-
JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهره‌وری اضافی ناچیز است.
567-
568-
با این حال، `lax.scan` یک مزیت مهم دارد: از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست.
572+
JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است.
569573

574+
```{note}
575+
یک مزیت مهم `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست.
570576
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است.
577+
```
571578

572579
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
573580

0 commit comments

Comments
 (0)