For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.
The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.
The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.
This change should be enough for old-style jax2tf, but will require more
work for native serialization.
We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.
PiperOrigin-RevId: 515137407
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
This changes the internals of JAX without affecting any public API.
Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
PiperOrigin-RevId: 508143501
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".
Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:
`jnp.ones(np.arange(5) * x.shape[0])`
Instead you should write
`jnp.ones([i * x.shape[0] for i in range(5)])`
The CallTfEffect was added recently as an internal workaround for
DCE removing instances of call_tf. Here we add a parameter to
`call_tf` to be able to declare if the called computation is
effectful and should not be removed by DCE.