* right next to the pip installation instructions, mention they don't work for Windows;
* add a link to #5795 for an unofficial discussion of Windows native support
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we are starting to get jax2tf ready for this. As a temporary
step, we had introduced a TF op called XlaCallModule which carries
a serialized MHLO module and which e can use to wrap the JAX native
MHLO as a TF op. We still reuse parts of jax2tf, in particular
the gradient machinery.
This functionality can be enabled locally with a
`experimental_native_lowering` flag for `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
values >= 88.7229.
When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.
PiperOrigin-RevId: 461678140
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
* `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.
* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.
* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.
* Rest of the changes are to make all other functions play nicely with sharding instances.
PiperOrigin-RevId: 461260553
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.
Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:
>>> params = { .. some params .. }
>>> f = jax.jit(..).lower(params).compile()
>>> f(params) # fine
>>> params['some_new_thing'] = something
>>> f(params)
TypeError: function compiled for {treedef}, called with {treedef}.
PiperOrigin-RevId: 461190971
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module
PiperOrigin-RevId: 461146464
Added a section to README to explain the division errors
and to show a workaround. Changed the division errors
to include more detail as to what the error is,
and to include a link to the new section in the README
Issue: #11402
Due to a typo we were running no tests for convolutions with shape
polymorphism and enable_xla=False.
Added a few more tests from #11402 (Thanks @sdenton4).
The main issue was that in presence of shape polymorphism we cannot
just use `x.shape` for a TF value `x` because it will contain `None`
in the place of unknown dimensions. We must use instead the JAX
abstract values.
This does not fix all issues reported in #11402, there is still the
computation of padding or padding="SAME". Commented out the
corresponding test.