1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has.
2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray.
This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules.
* pbroadcast_p: Take the union of `aval.varying_manual_axes` and `axes` (the axes passed to pbroadcast) to compute the output vma. As a check, the intersection of `aval.varying_manual_axes` and `axes` must be empty.
* psum2_p: Remove the named axes from `aval.varying_manual_axes` to compute the output vma. As a check, the intersection of `aval.varying_manual_axes` and `axes` must NOT be empty.
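The two rules above can be sketched with plain Python sets. This is only an illustration, not the code in shard_map.py: the function names `pbroadcast_vma` and `psum2_vma` are hypothetical, and a `frozenset` of axis names stands in for `aval.varying_manual_axes`.

```python
def pbroadcast_vma(vma: frozenset, axes: tuple) -> frozenset:
    """Hypothetical sketch of the pbroadcast output-vma rule."""
    axes = frozenset(axes)
    # Check: the input must not already vary over any of the requested axes.
    if vma & axes:
        raise ValueError(
            f"pbroadcast axes {sorted(axes)} overlap with vma {sorted(vma)}")
    # Output varies over the existing axes plus the newly requested ones.
    return vma | axes


def psum2_vma(vma: frozenset, axes: tuple) -> frozenset:
    """Hypothetical sketch of the psum2 output-vma rule."""
    axes = frozenset(axes)
    # Check as stated in the text: the intersection must not be empty.
    if not (vma & axes):
        raise ValueError(
            f"psum2 axes {sorted(axes)} do not intersect vma {sorted(vma)}")
    # Summing over an axis removes it from the set of varying axes.
    return vma - axes
```

For example, `pbroadcast_vma(frozenset({'x'}), ('y',))` yields a vma of `{'x', 'y'}`, and a following `psum2_vma` over `('x',)` leaves `{'y'}`.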
The majority of primitives should use `standard_insert_pbroadcast` and `standard_vma_rule`; I'll add those to the other primitives in follow-up CLs.
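As a rough picture of how the two standard steps fit together, here is a plain-Python sketch. The names `axes_to_pbroadcast` and `check_and_merge_vma` are hypothetical stand-ins, not the real `standard_insert_pbroadcast` and `standard_vma_rule`, and the sketch assumes that "inserting pbroadcasts" means bringing every input up to the union of all input vmas.

```python
def axes_to_pbroadcast(input_vmas):
    """For each input, list the axes it would need to be pbroadcast over.

    Assumption: every input is brought up to the union of all input vmas.
    """
    union = frozenset().union(*input_vmas)
    return [tuple(sorted(union - vma)) for vma in input_vmas]


def check_and_merge_vma(prim_name, input_vmas):
    """After pbroadcasts are inserted, all inputs must agree on their vma.

    Expects at least one input; returns the vma to attach to the output.
    """
    first, *rest = input_vmas
    for vma in rest:
        if vma != first:
            raise ValueError(
                f"{prim_name}: inputs vary over different axes: "
                f"{sorted(first)} vs {sorted(vma)}")
    return first
```

With inputs varying over `{'x'}`, `{}` and `{'x', 'y'}`, the first helper reports that the inputs need pbroadcasts over `('y',)`, `('x', 'y')` and `()` respectively; once those are applied, the second helper returns `{'x', 'y'}` as the output vma.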
PiperOrigin-RevId: 739225392
This is so that we can catch this exception in backward_pass/vmap and add an extra message telling users that this is a potential JAX bug and that they should file an issue on the repo.
Currently, we only raise `ShardingTypeError` in one place, but we can expand it to all other places in follow-up changes. This change sets up the machinery.
Previous error:
```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
```
New error:
```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
This is a potential JAX bug. Please file an issue at https://github.com/jax-ml/jax/issues
```
The newly added `This is a potential JAX bug...` message is important because this error is raised in the backward pass, which is 100% a JAX bug given that the forward pass did not error.
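The catch-and-annotate pattern can be sketched as follows. This is not the code in backward_pass: the `ShardingTypeError` class below is a local stand-in for `jax._src.core.ShardingTypeError`, and `run_backward_rule` is a hypothetical wrapper name.

```python
class ShardingTypeError(Exception):
    """Local stand-in for jax._src.core.ShardingTypeError."""


BUG_NOTE = ("This is a potential JAX bug. Please file an issue at "
            "https://github.com/jax-ml/jax/issues")


def run_backward_rule(rule, *args):
    """Hypothetical wrapper: call a backward-pass rule and annotate
    any sharding type error with the bug-report note."""
    try:
        return rule(*args)
    except ShardingTypeError as e:
        # Re-raise the same exception type so callers can still catch it,
        # keeping the original message and appending the note.
        raise ShardingTypeError(f"{e}\n{BUG_NOTE}") from e
```

Having a dedicated exception type is what makes this possible: a plain `TypeError` could come from user code, so it could not be safely annotated this way.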
PiperOrigin-RevId: 739053305
Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
PiperOrigin-RevId: 736657644
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.
PiperOrigin-RevId: 729471970
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt

* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Computation follows data works in explicit sharding mode even if use_mesh is not used.
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* The check in partial_eval that compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if either aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474, we need to relax the typing and sharding rule checks because inserting `mesh_cast`s leads to the creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh), which is not good.
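The relaxed comparison can be pictured with a small plain-Python model. Everything here is hypothetical: `Aval` is a simplified stand-in for a JAX abstract value, an empty `mesh_axes` tuple models an empty mesh, and `avals_match` is not a JAX function.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Aval:
    """Hypothetical, simplified abstract value."""
    shape: tuple
    mesh_axes: tuple = ()  # () models an empty mesh (literals, numpy arrays)
    spec: tuple = ()


def avals_match(a: Aval, b: Aval) -> bool:
    """Compare two avals, skipping the sharding check on an empty mesh."""
    if a.shape != b.shape:
        return False
    # Relaxed rule: if either side has an empty mesh, do not compare shardings.
    if not a.mesh_axes or not b.mesh_axes:
        return True
    return a.mesh_axes == b.mesh_axes and a.spec == b.spec
```

Under this model a sharded `Aval((8,), ('x',), ('x',))` matches a literal-like `Aval((8,))` without any cast, while two avals on the same non-empty mesh must still agree on their spec.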
PiperOrigin-RevId: 726097292
* When current_mesh is Manual and aval mesh is Auto
* When current mesh is set and aval mesh is unset
* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.
* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.
This is required to make sure that after we hit the abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.
`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working, but `Explicit` and `Manual` are very strict.
PiperOrigin-RevId: 722868091
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.
* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.
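The preconditions of the two APIs can be sketched as simple checks. This is a hypothetical model, not the JAX implementation: `MeshDesc` is an illustrative record (not `jax.sharding.Mesh`), and `check_mesh_cast` and `check_reshard` are made-up names that encode only the mesh rules stated in the bullets above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MeshDesc:
    """Hypothetical mesh description used only for this sketch."""
    axis_names: tuple
    axis_sizes: tuple
    axis_types: tuple  # e.g. ("Auto",) or ("Explicit",)


def check_mesh_cast(src: MeshDesc, dst: MeshDesc) -> None:
    # mesh_cast: the axis types must differ between src and dst.
    if src.axis_types == dst.axis_types:
        raise ValueError("mesh_cast: axis types of src and dst must differ")


def check_reshard(src: MeshDesc, dst: MeshDesc) -> None:
    # reshard: names, sizes and types must all be identical.
    if src != dst:
        raise ValueError("reshard: src and dst mesh must be the same")
```

In this model, going from an `Auto` mesh to an `Explicit` mesh of the same shape passes `check_mesh_cast` but fails `check_reshard`, which mirrors the split between the two APIs.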
We might make `reshard` == `device_put`, which is why the API lives in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is that `device_put` does a lot of stuff right now (and is going to get even more powers in the near future, like cross-host transfers), and its semantics would be very confusing if we kept piling sharding-in-types stuff on it.
PiperOrigin-RevId: 717588253
Replace the `with set_mesh(mesh):` context manager with `with use_mesh(mesh):`.
Also expose `AxisTypes` and `use_mesh` in the public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.
Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules, but we may need a more central place for it, which we can create later on.
PiperOrigin-RevId: 707128096
Cases where we error:
* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* batch dimension and tensor dimension shardings match
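The error cases above can be sketched as a standalone check. This is not the actual dot_general sharding rule: `check_dot_general_sharding` is a hypothetical name, specs are modelled as tuples holding one mesh axis name or `None` per dimension, and two interpretations are assumed: "ignore None" means an unsharded dimension is compatible with anything, and "tensor dimension" means a dimension that is neither batch nor contracting.

```python
def check_dot_general_sharding(lhs_spec, rhs_spec, lhs_mesh, rhs_mesh,
                               contracting, batch):
    """Hypothetical sketch; raises TypeError on the listed error cases.

    contracting and batch are (lhs_dims, rhs_dims) pairs, as in
    dot_general's dimension_numbers.
    """
    if lhs_mesh != rhs_mesh:
        raise TypeError(f"lhs.mesh != rhs.mesh: {lhs_mesh} vs {rhs_mesh}")

    def consistent(a, b):
        # Assumption: None (unsharded) is compatible with any sharding.
        return a is None or b is None or a == b

    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = contracting, batch

    for kind, lhs_dims, rhs_dims in (("batch", lhs_batch, rhs_batch),
                                     ("contracting", lhs_contract, rhs_contract)):
        for l, r in zip(lhs_dims, rhs_dims):
            if not consistent(lhs_spec[l], rhs_spec[r]):
                raise TypeError(
                    f"{kind} dimensions have inconsistent sharding: "
                    f"{lhs_spec[l]} vs {rhs_spec[r]}")

    for side, spec, b_dims, c_dims in (
            ("lhs", lhs_spec, lhs_batch, lhs_contract),
            ("rhs", rhs_spec, rhs_batch, rhs_contract)):
        batch_axes = {spec[d] for d in b_dims} - {None}
        for d, axis in enumerate(spec):
            is_tensor_dim = d not in b_dims and d not in c_dims
            if is_tensor_dim and axis in batch_axes:
                raise TypeError(
                    f"{side} tensor dimension {d} is sharded over {axis!r}, "
                    "the same axis as a batch dimension")
```

A plain matmul with `lhs_spec=('x', None)`, `rhs_spec=(None, 'y')` and `contracting=((1,), (0,))` passes, whereas contracting over dimensions sharded on `'x'` and `'y'` respectively is rejected.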
PiperOrigin-RevId: 684983567
Change in preparation for removing HLO ops from the XLA Python bindings.
In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.
PiperOrigin-RevId: 684892102
More improvements and semantic clarifications will come in the future as we integrate it more into JAX.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
Since array_abstraction_level is a class property, it is also present on instances. We can avoid forming map(type, avals) and instead simply take the type(...) of the result. It's also shorter this way.
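The simplification can be illustrated with toy classes. The class names and level values below are made up for the example and are not JAX's abstract value classes; only the attribute name `array_abstraction_level` comes from the text above.

```python
class ToyUnshaped:
    array_abstraction_level = 4  # illustrative value


class ToyShaped(ToyUnshaped):
    array_abstraction_level = 2  # illustrative value


class ToyConcrete(ToyShaped):
    array_abstraction_level = 0  # illustrative value


avals = [ToyConcrete(), ToyShaped(), ToyConcrete()]

# Before: build the list of types first, then pick the most abstract one.
least_specialized_old = max(
    map(type, avals), key=lambda t: t.array_abstraction_level)

# After: class attributes are visible on instances, so pick the instance
# directly and take its type at the end.
least_specialized_new = type(
    max(avals, key=lambda a: a.array_abstraction_level))
```

Both expressions select the same class; the second one just skips the intermediate `map(type, avals)`.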
PiperOrigin-RevId: 606629740
fixes #14397
For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.
Instead of making a new primitive, we made the old one polymorphic and switched
its behavior on the element type of its second argument.
There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.
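To make the distinction concrete, here is a plain-Python sketch of a derivative with respect to the base that switches on the exponent's type. It is not JAX's pow JVP rule: `pow_grad_wrt_base` is a hypothetical name, and `isinstance(y, int)` stands in for dispatching on the element type of the second argument.

```python
import math


def pow_grad_wrt_base(x: float, y):
    """Hypothetical sketch: d/dx of x**y, switching on the exponent type."""
    if isinstance(y, int):
        # Integer exponent: x**0 is the constant 1, so the derivative is 0
        # everywhere, including at x == 0.
        if y == 0:
            return 0.0
        return y * x ** (y - 1)
    # Inexact exponent: the generic formula y * x**(y - 1) evaluates to
    # 0 * inf at (0., 0.) in IEEE arithmetic, so there is no well-defined
    # derivative there. (Other x == 0 cases are not handled by this sketch.)
    if x == 0.0 and y == 0.0:
        return math.nan
    return y * x ** (y - 1.0)
```

At `(0., 0)` the integer branch returns a clean `0.0`, while at `(0., 0.)` the inexact branch has nothing sensible to return, which is why a single primitive has to look at the exponent's element type.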
Co-authored-by: Roy Frostig <frostig@google.com>
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
At the moment this change does nothing, since standard_primitive already registers these same translation rules. The change is in preparation for removing standard_primitive's behavior of registering an XLA translation rule.
PiperOrigin-RevId: 442222533
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.
The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.
PiperOrigin-RevId: 414704700
To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.
PiperOrigin-RevId: 411921794