@@ -121,6 +121,7 @@ eigvals
121121
122122Let's now look at some differences between JAX and NumPy array operations.
123123
124+ (jax_speed)=
124125#### Speed!
125126
126127Let's say we want to evaluate the cosine function at many points.
@@ -631,15 +632,20 @@ The explicitness of JAX brings significant benefits:
631632The last point is expanded on in the next section.
632633
633634
634- ## JIT compilation
635+ ## JIT Compilation
635636
636637The JAX just-in-time (JIT) compiler accelerates execution by generating
637638efficient machine code that varies with both task size and hardware.
638639
640+ We saw the power of JAX's JIT compiler combined with parallel hardware when we
641+ {ref}` above <jax_speed> ` , when we applied ` cos ` to a large array.
642+
643+ Let's try the same thing with a more complex function.
644+
639645
640646### Evaluating a more complicated function
641647
642- Let's try the same thing with a more complex function.
648+ Consider the function
643649
644650``` {code-cell}
645651def f(x):
@@ -695,14 +701,14 @@ with qe.Timer():
695701
696702The outcome is similar to the ` cos ` example --- JAX is faster, especially on the second run after JIT compilation.
697703
698- Moreover , with JAX, we have another trick up our sleeve --- we can JIT-compile
704+ However , with JAX, we have another trick up our sleeve --- we can JIT-compile
699705the * entire* function, not just individual operations.
700706
701707
702708### Compiling the whole function
703709
704- The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing linear
705- algebra operations into a single optimized kernel.
710+ The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array
711+ operations into a single optimized kernel.
706712
707713Let's try this with the function ` f ` :
708714
@@ -756,9 +762,11 @@ compiled code and run at full speed.
756762
757763### Compiling non-pure functions
758764
759- Now that we've seen how powerful JIT compilation can be, it's important to understand its relationship with pure functions.
765+ Now that we've seen how powerful JIT compilation can be, it's important to
766+ understand its relationship with pure functions.
760767
761- While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.
768+ While JAX will not usually throw errors when compiling impure functions,
769+ execution becomes unpredictable.
762770
763771Here's an illustration of this fact, using global variables:
764772
0 commit comments