Pretty-printing IR isn't free, so it is best avoided unless necessary. Moreover, returning an IR object leaves the user free to, say, print it with different options.
For example, one can now write things like:
```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
  func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
    %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
    %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
    %3 = mhlo.add %2, %1 : tensor<1000xf32>
    return %3 : tensor<1000xf32>
  }
}
```
Fixes https://github.com/google/jax/issues/9226
PiperOrigin-RevId: 422855649
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:
add jax.ensure_compile_time_eval to public api
aka jax.core.eval_context
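A minimal usage sketch (assuming the documented behavior: computations inside the context are evaluated eagerly at trace time rather than staged out):
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly while tracing; `table` is a concrete array,
        # not a tracer, so it can feed shape/static logic.
        table = jnp.arange(4) ** 2
    return x + table
```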
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
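A small sketch of this behavior (assuming the `jax.experimental.host_callback` API of this era):
```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint
from jax.experimental import host_callback as hcb

@checkpoint
def f(x):
    y = jnp.sin(x)
    return hcb.id_print(y, what="tapped") * 2.0

# "tapped" prints twice: once during the forward pass, and once when the
# checkpointed residuals are rematerialized for the backward pass.
jax.grad(f)(1.0)
```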
While HLO dumping is redundant with XLA's XLA_FLAGS=--xla_dump_to=... feature, MHLO dumping is useful since XLA only ever sees and dumps the IR after it has been canonicalized and converted to HLO. Some debugging tasks require easy access to the MHLO as well.
PiperOrigin-RevId: 416435598
* dropping support for special AD handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
environment variable, or the --jax_host_callback_ad_transforms flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs.
This allows us to make some significant cleanup in the internals.
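A hedged sketch of what such a custom-AD replacement can look like (illustrative names, not the exact recipe from the linked documentation):
```python
import jax
from jax.experimental import host_callback as hcb

@jax.custom_vjp
def tap_cotangent(x):
    return x  # identity on the forward pass

def _tap_fwd(x):
    return x, None  # no residuals needed

def _tap_bwd(_, ct):
    hcb.id_print(ct, what="cotangent")  # tap the backward-pass value
    return (ct,)

tap_cotangent.defvjp(_tap_fwd, _tap_bwd)
```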
Previously, jax.jit was ignored by jax2tf. This could result in the
converted code being much slower than the corresponding JAX code, unless the
user adds an explicit `tf.function(jit_compile=True)`. With this
change that wrapper is added automatically for all code fragments
under jax.jit. Note that most jax.numpy functions are annotated
with jax.jit, so with this change they will all be compiled.
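Roughly, the automatic wrapping corresponds to what users previously had to write by hand (a sketch; `jnp.sin` stands in for any converted function):
```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Before this change, XLA compilation had to be requested explicitly:
f_tf_manual = tf.function(jax2tf.convert(jnp.sin), jit_compile=True)

# After this change, jax2tf adds the equivalent wrapper itself for
# code fragments under jax.jit:
f_tf = jax2tf.convert(jnp.sin)
```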
When doing this I ran into problems with tf.custom_gradient and
tf.function. As documented in the
[tf.custom_gradient](https://www.tensorflow.org/api_docs/python/tf/custom_gradient)
documentation, you get a LookupError when trying to build the gradient
of a tf.function, even if it has a tf.custom_gradient defined. The
recommended solution is to add a tf.stop_gradient. This is safe, since
jax2tf will always wrap the converted functions with a tf.custom_gradient.
Use dtypes.issubdtype to test for subtyping; otherwise we mishandle bfloat16 dtypes (illustrated below).
Don't pass an empty list to concatenate() when converting a shape to a value.
Forbid empty lists as arguments to lax.concatenate().
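The bfloat16 point is easy to illustrate (a sketch; bfloat16 sits outside NumPy's own type lattice, so `np.issubdtype` gives the wrong answer):
```python
import numpy as np
import jax.numpy as jnp
from jax import dtypes

# NumPy's type hierarchy doesn't know about bfloat16:
np.issubdtype(jnp.bfloat16, np.floating)      # False
# JAX's dtypes.issubdtype extends the lattice to cover it:
dtypes.issubdtype(jnp.bfloat16, np.floating)  # True
```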
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
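For downstream code, the change is just an import-path update (the modules' contents are unchanged):
```python
# Before:
# from jax.experimental import stax, optimizers

# After:
from jax.example_libraries import stax, optimizers
```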
PiperOrigin-RevId: 404405186
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent is that we either check for hashability early, or make it clear that the value is deliberately not hashable.
* Delete argnames_partial, which appears unused.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as under `jax.jit`, will always be staged into the trace.
This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.
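A short sketch of the distinction (hypothetical function; the shape math must use plain NumPy so it stays concrete at trace time):
```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # OK: np.prod runs at trace time on concrete Python ints.
    n = int(np.prod(x.shape))
    # Not OK after this change: jnp.array(x.shape) is staged into the
    # trace, so its value would be a tracer, unusable as a static shape.
    return x.reshape((n,))
```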
PiperOrigin-RevId: 398008511