You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lectures/jax_intro.md
+34-24Lines changed: 34 additions & 24 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -18,9 +18,10 @@ translation:
18
18
JAX as a NumPy Replacement::Differences::Speed!: سرعت!
19
19
JAX as a NumPy Replacement::Differences::Speed!::With NumPy: با NumPy
20
20
JAX as a NumPy Replacement::Differences::Speed!::With JAX: با JAX
21
+
JAX as a NumPy Replacement::Differences::Size Experiment: آزمایش اندازه
21
22
JAX as a NumPy Replacement::Differences::Precision: دقت
22
23
JAX as a NumPy Replacement::Differences::Immutability: تغییرناپذیری
23
-
JAX as a NumPy Replacement::Differences::A workaround: راهحل جایگزین
24
+
JAX as a NumPy Replacement::Differences::A Workaround: راهحل جایگزین
24
25
Functional Programming: برنامهنویسی تابعی
25
26
Functional Programming::Pure functions: توابع خالص
26
27
Functional Programming::Examples: مثالها
@@ -76,18 +77,18 @@ import numpy as np
76
77
import quantecon as qe
77
78
```
78
79
79
-
توجه کنید که `jax.numpy as jnp` را import میکنیم که یک رابط شبیه NumPy فراهم میکند.
80
-
81
80
## JAX به عنوان جایگزین NumPy
82
81
83
-
یکی از ویژگیهای جذاب JAX این است که، هر زمان که امکانپذیر باشد، عملیات پردازش آرایههای آن با API NumPy مطابقت دارد.
84
-
85
-
این بدان معناست که در بسیاری از موارد، میتوانیم از JAX به عنوان جایگزین مستقیم NumPy استفاده کنیم.
86
-
87
82
بیایید به شباهتها و تفاوتهای بین JAX و NumPy نگاه کنیم.
88
83
89
84
### شباهتها
90
85
86
+
در بالا `jax.numpy as jnp` را وارد کردیم که یک رابط شبیه به NumPy برای عملیات آرایه فراهم میکند.
87
+
88
+
یکی از ویژگیهای جذاب JAX این است که، هر زمان که امکانپذیر باشد، این رابط با API NumPy مطابقت دارد.
89
+
90
+
در نتیجه، اغلب میتوانیم از JAX به عنوان جایگزین مستقیم NumPy استفاده کنیم.
91
+
91
92
در اینجا برخی عملیات استاندارد آرایه با استفاده از `jnp` آمده است:
92
93
93
94
```{code-cell} ipython3
@@ -106,7 +107,7 @@ print(jnp.sum(a))
106
107
print(jnp.dot(a, a))
107
108
```
108
109
109
-
با این حال، شیء آرایه `a` یک آرایه NumPy نیست:
110
+
با این حال، باید به خاطر داشت که شیء آرایه `a` یک آرایه NumPy نیست:
110
111
111
112
```{code-cell} ipython3
112
113
a
@@ -129,11 +130,13 @@ jnp.sum(a)
129
130
(jax_speed)=
130
131
#### سرعت!
131
132
132
-
فرض کنیم میخواهیم تابع کسینوس را در نقاط بسیاری ارزیابی کنیم.
133
+
یکی از تفاوتهای عمده این است که JAX سریعتر است --- و گاهی بسیار سریعتر.
134
+
135
+
برای نشان دادن این موضوع، فرض کنیم میخواهیم تابع کسینوس را در نقاط بسیاری ارزیابی کنیم.
133
136
134
137
```{code-cell}
135
138
n = 50_000_000
136
-
x = np.linspace(0, 10, n)
139
+
x = np.linspace(0, 10, n) # NumPy array
137
140
```
138
141
139
142
##### با NumPy
@@ -174,27 +177,23 @@ with qe.Timer():
174
177
# First run
175
178
y = jnp.cos(x)
176
179
# Hold the interpreter until the array operation finishes
177
-
jax.block_until_ready(y);
180
+
y.block_until_ready()
178
181
```
179
182
180
183
```{note}
181
-
در اینجا، برای اندازهگیری سرعت واقعی، از متد `block_until_ready` استفاده میکنیم
182
-
تا مفسر را تا زمانی که نتایج محاسبات بازگردانده شوند نگه داریم.
183
-
184
-
این ضروری است زیرا JAX از ارسال ناهمزمان استفاده میکند که
184
+
در بالا، متد `block_until_ready` مفسر را تا زمانی که نتایج محاسبات بازگردانده شوند نگه میدارد.
185
+
این برای زمانبندی اجرا ضروری است زیرا JAX از ارسال ناهمزمان استفاده میکند که
185
186
به مفسر Python اجازه میدهد جلوتر از محاسبات عددی حرکت کند.
186
-
187
-
برای کدهایی که زمانبندی نمیشوند، میتوانید خط حاوی `block_until_ready` را حذف کنید.
188
187
```
189
188
190
-
و بیایید دوباره زمانبندی کنیم.
189
+
اکنون بیایید دوباره زمانبندی کنیم.
191
190
192
191
```{code-cell}
193
192
with qe.Timer():
194
193
# Second run
195
194
y = jnp.cos(x)
196
195
# Hold interpreter
197
-
jax.block_until_ready(y);
196
+
y.block_until_ready()
198
197
```
199
198
200
199
روی GPU، این کد بسیار سریعتر از معادل NumPy خود اجرا میشود.
@@ -209,6 +208,8 @@ with qe.Timer():
209
208
210
209
اندازه برای تولید کد بهینه اهمیت دارد زیرا موازیسازی کارآمد نیازمند تطابق اندازه کار با سختافزار موجود است.
211
210
211
+
#### آزمایش اندازه
212
+
212
213
میتوانیم ادعا که JAX بر اندازه آرایه تخصص پیدا میکند را با تغییر اندازه ورودی و مشاهده زمانهای اجرا تأیید کنیم.
213
214
214
215
```{code-cell}
@@ -220,15 +221,15 @@ with qe.Timer():
220
221
# First run
221
222
y = jnp.cos(x)
222
223
# Hold interpreter
223
-
jax.block_until_ready(y);
224
+
y.block_until_ready()
224
225
```
225
226
226
227
```{code-cell}
227
228
with qe.Timer():
228
229
# Second run
229
230
y = jnp.cos(x)
230
231
# Hold interpreter
231
-
jax.block_until_ready(y);
232
+
y.block_until_ready()
232
233
```
233
234
234
235
زمان اجرا افزایش مییابد و سپس دوباره کاهش مییابد (این روی GPU واضحتر خواهد بود).
@@ -277,7 +278,7 @@ a[0] = 1
277
278
a
278
279
```
279
280
280
-
در JAX این کار شکست میخورد!
281
+
در JAX این کار شکست میخورد 😱.
281
282
282
283
```{code-cell} ipython3
283
284
a = jnp.linspace(0, 1, 3)
@@ -292,11 +293,18 @@ except Exception as e:
292
293
293
294
```
294
295
295
-
طراحان JAX تصمیم گرفتند آرایهها را تغییرناپذیر کنند زیرا JAX از سبک برنامهنویسی تابعی استفاده میکند که در ادامه آن را بررسی میکنیم.
296
+
طراحان JAX تصمیم گرفتند آرایهها را تغییرناپذیر کنند زیرا
296
297
298
+
1. JAX از *سبک برنامهنویسی تابعی* استفاده میکند و
299
+
2. برنامهنویسی تابعی معمولاً از دادههای قابل تغییر اجتناب میکند
300
+
301
+
این ایدهها را {ref}`در ادامه <jax_func>` بررسی میکنیم.
302
+
303
+
304
+
(jax_at_workaround)=
297
305
#### راهحل جایگزین
298
306
299
-
توجه میکنیم که JAX یک جایگزین برای تغییر درجای آرایه با استفاده از[متد `at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) فراهم میکند.
307
+
JAX یک جایگزین مستقیم برای تغییر درجای آرایه از طریق[متد `at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) فراهم میکند.
300
308
301
309
```{code-cell} ipython3
302
310
a = jnp.linspace(0, 1, 3)
@@ -318,6 +326,8 @@ a
318
326
319
327
(اگرچه در واقع میتواند داخل توابع کامپایلشده JIT کارآمد باشد -- اما بیایید این را فعلاً کنار بگذاریم.)
Copy file name to clipboardExpand all lines: lectures/numpy_vs_numba_vs_jax.md
+60-21Lines changed: 60 additions & 21 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -36,23 +36,23 @@ translation:
36
36
37
37
# NumPy در مقابل Numba در مقابل JAX
38
38
39
-
در سخنرانیهای قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
39
+
در درسهای قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
40
40
41
41
*[NumPy](numpy)
42
42
*[Numba](numba)
43
43
*[JAX](jax_intro)
44
44
45
45
کدام یک را باید در هر موقعیت استفاده کنیم؟
46
46
47
-
این سخنرانی به آن سؤال پاسخ میدهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
47
+
این درس به آن سؤال پاسخ میدهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
48
48
49
49
قبل از شروع، توجه میکنیم که دو مورد اول یک جفت طبیعی هستند: NumPy و Numba به خوبی با هم کار میکنند.
50
50
51
51
JAX، از سوی دیگر، به تنهایی میایستد.
52
52
53
53
هنگام بررسی هر رویکرد، نه تنها کارایی و رد پای حافظه، بلکه وضوح و سهولت استفاده را نیز در نظر خواهیم گرفت.
54
54
55
-
علاوه بر آنچه در Anaconda موجود است، این سخنرانی به کتابخانههای زیر نیاز دارد:
55
+
علاوه بر آنچه در Anaconda موجود است، این درس به کتابخانههای زیر نیاز دارد:
56
56
57
57
```{code-cell} ipython3
58
58
---
@@ -67,7 +67,6 @@ tags: [hide-output]
67
67
ما از import های زیر استفاده خواهیم کرد.
68
68
69
69
```{code-cell} ipython3
70
-
import random
71
70
from functools import partial
72
71
73
72
import numpy as np
@@ -455,15 +454,60 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
455
454
456
455
### نسخه JAX
457
456
458
-
حالا بیایید یک نسخه JAX با استفاده از `lax.scan` ایجاد کنیم:
457
+
حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همانطور که {ref}`در درس JAX بحث شد <jax_at_workaround>`، راهحلی برای آرایههای تغییرناپذیر فراهم میکند.
459
458
460
-
(ما `n` را ایستا نگه میداریم زیرا بر اندازه آرایه تأثیر میگذارد و از این رو JAX میخواهد روی مقدار آن در کد کامپایل شده تخصصی شود.)
459
+
ما از `lax.fori_loop` استفاده میکنیم که نسخهای از حلقه for است که میتواند توسط XLA کامپایل شود.
* ما `n` را ایستا نگه میداریم زیرا بر اندازه آرایه تأثیر میگذارد و از این رو JAX میخواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
478
+
* ما به CPU از طریق `device=cpu` متصل میمانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازیسازی GPU باقی میگذارد.
479
+
480
+
اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد میکند، در داخل یک تابع کامپایلشده با JIT، کامپایلر تشخیص میدهد که آرایه قدیمی دیگر مورد نیاز نیست و بهروزرسانی را در جا انجام میدهد.
481
+
482
+
بیایید آن را با همان پارامترها زمانبندی کنیم:
483
+
484
+
```{code-cell} ipython3
485
+
with qe.Timer():
486
+
# First run
487
+
x_jax = qm_jax_fori(0.1, n)
488
+
# Hold interpreter
489
+
x_jax.block_until_ready()
490
+
```
491
+
492
+
بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود:
493
+
494
+
```{code-cell} ipython3
495
+
with qe.Timer():
496
+
# Second run
497
+
x_jax = qm_jax_fori(0.1, n)
498
+
# Hold interpreter
499
+
x_jax.block_until_ready()
500
+
```
501
+
502
+
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
503
+
504
+
روش دیگری برای پیادهسازی حلقه وجود دارد که از `lax.scan` استفاده میکند.
505
+
506
+
این روش جایگزین، به طور قابل بحث، بیشتر با رویکرد تابعی JAX همسو است --- اگرچه سینتکس آن به خاطر سپردن دشواری دارد.
این کد خواندن آسانی ندارد اما، در اصل، `lax.scan` به طور مکرر `update` را فراخوانی میکند و بازگشتهای `x_new` را در یک آرایه جمع میکند.
476
520
477
-
```{note}
478
-
ما `device=cpu` را در decorator `jax.jit` مشخص میکنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهرهبرداری GPU از موازیسازی باقی میگذارد. در نتیجه، سربار راهاندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسبتر برای این بار کاری میکند.
479
-
```
480
-
481
521
بیایید آن را با همان پارامترها زمانبندی کنیم:
482
522
483
523
```{code-cell} ipython3
484
524
with qe.Timer():
485
525
# First run
486
-
x_jax = qm_jax(0.1, n)
526
+
x_jax = qm_jax_scan(0.1, n)
487
527
# Hold interpreter
488
528
x_jax.block_until_ready()
489
529
```
@@ -493,13 +533,11 @@ with qe.Timer():
493
533
```{code-cell} ipython3
494
534
with qe.Timer():
495
535
# Second run
496
-
x_jax = qm_jax(0.1, n)
536
+
x_jax = qm_jax_scan(0.1, n)
497
537
# Hold interpreter
498
538
x_jax.block_until_ready()
499
539
```
500
540
501
-
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
502
-
503
541
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه میدهند.
504
542
505
543
### خلاصه
@@ -510,9 +548,9 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
510
548
511
549
این دقیقاً نحوه تفکر اکثر برنامهنویسان در مورد الگوریتم است.
512
550
513
-
نسخه JAX، از سوی دیگر، نیاز به استفاده از `lax.scan`دارد که به طور قابل توجهی کمتر شهودی است.
551
+
نسخههای JAX، از سوی دیگر، نیاز به استفاده از `lax.fori_loop` یا `lax.scan`دارند که هر دو کمتر شهودی از یک حلقه استاندارد Python هستند.
514
552
515
-
علاوه بر این، آرایههای تغییرناپذیر JAX به این معنی است که نمیتوانیم به سادگی عناصر آرایه را در جا بهروزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت میکند.
553
+
در حالی که سینتکس `at[t].set` در JAX بهروزرسانی عنصر به عنصر را ممکن میسازد، کد کلی همچنان سختتر از معادل Numba برای خواندن است.
516
554
517
555
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیادهسازی است.
518
556
@@ -532,11 +570,12 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
532
570
533
571
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
534
572
535
-
JAX میتواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهرهوری اضافی ناچیز است.
536
-
537
-
با این حال، `lax.scan` یک مزیت مهم دارد: از مشتقگیری خودکار در طول حلقه پشتیبانی میکند، که Numba قادر به انجام آن نیست.
573
+
JAX میتواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است.
538
574
575
+
```{note}
576
+
یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتقگیری خودکار در طول حلقه پشتیبانی میکنند، که Numba قادر به انجام آن نیست.
539
577
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیتهای یک مسیر نسبت به پارامترهای مدل)، JAX علیرغم نحو کمتر طبیعیاش، انتخاب بهتری است.
578
+
```
540
579
541
580
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
0 commit comments