Skip to content

Commit c40fbf6

Browse files
committed
misc
1 parent 2034bb9 commit c40fbf6

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

lectures/jax_intro.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ eigvals
121121

122122
Let's now look at some differences between JAX and NumPy array operations.
123123

124+
(jax_speed)=
124125
#### Speed!
125126

126127
Let's say we want to evaluate the cosine function at many points.
@@ -631,15 +632,20 @@ The explicitness of JAX brings significant benefits:
631632
The last point is expanded on in the next section.
632633

633634

634-
## JIT compilation
635+
## JIT Compilation
635636

636637
The JAX just-in-time (JIT) compiler accelerates execution by generating
637638
efficient machine code that varies with both task size and hardware.
638639

640+
We saw the power of JAX's JIT compiler combined with parallel hardware when we
641+
{ref}`above <jax_speed>`, when we applied `cos` to a large array.
642+
643+
Let's try the same thing with a more complex function.
644+
639645

640646
### Evaluating a more complicated function
641647

642-
Let's try the same thing with a more complex function.
648+
Consider the function
643649

644650
```{code-cell}
645651
def f(x):
@@ -695,14 +701,14 @@ with qe.Timer():
695701

696702
The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
697703

698-
Moreover, with JAX, we have another trick up our sleeve --- we can JIT-compile
704+
However, with JAX, we have another trick up our sleeve --- we can JIT-compile
699705
the *entire* function, not just individual operations.
700706

701707

702708
### Compiling the whole function
703709

704-
The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing linear
705-
algebra operations into a single optimized kernel.
710+
The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array
711+
operations into a single optimized kernel.
706712

707713
Let's try this with the function `f`:
708714

@@ -756,9 +762,11 @@ compiled code and run at full speed.
756762

757763
### Compiling non-pure functions
758764

759-
Now that we've seen how powerful JIT compilation can be, it's important to understand its relationship with pure functions.
765+
Now that we've seen how powerful JIT compilation can be, it's important to
766+
understand its relationship with pure functions.
760767

761-
While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.
768+
While JAX will not usually throw errors when compiling impure functions,
769+
execution becomes unpredictable.
762770

763771
Here's an illustration of this fact, using global variables:
764772

0 commit comments

Comments
 (0)