JAX kept an older name around (.block_host_until_ready()) in parallel with the new name (.block_until_ready()) to avoid breaking users. Deprecate it so we only have one name.
PiperOrigin-RevId: 433228545
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options.
For example, one can now write things like:
```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
%0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
%1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
%2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
%3 = mhlo.add %2, %1 : tensor<1000xf32>
return %3 : tensor<1000xf32>
}
}
```
Fixes https://github.com/google/jax/issues/9226
PiperOrigin-RevId: 422855649
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:
add jax.ensure_compile_time_eval to public api
aka jax.core.eval_context
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
While HLO dumping is redundant with XLA's XLA_FLAGS=--xla_dump_to=... feature, MHLO dumping is useful since XLA only ever sees and dumps the IR after it has been canonicalized and converted to HLO. Some debugging tasks require easy access to the MHLO as well.
PiperOrigin-RevId: 416435598
* dropping support for special AD handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
environment variale, or the --flax_host_callback_ad_transforms flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs.
This allows us to make some significant cleanup in the internals.