49 Commits

Author SHA1 Message Date
Yash Katariya
5950e722e2 Make sure vma on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though.
PiperOrigin-RevId: 741554623
2025-03-28 09:44:03 -07:00
Yash Katariya
f1a9241187 Add standard_insert_broadcasts to all traceables in lax.py and checks in abstract_eval rules of those primitives.
PiperOrigin-RevId: 740536031
2025-03-25 17:03:18 -07:00
Yash Katariya
3163fbaac4 Add varying manual axes rules to mul_p and convert_element_type_p. There are 2 things that need to be added:
1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has.

2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray.

This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules.

* pbroadcast_p: Union the existing aval.varying_manual_axes + axes (passed to pbroadcast) to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is empty.

* psum2_p: Remove the named axes from aval.varying_manual_axes to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is NOT empty.

Majority of the primitives should use the standard_insert_pbroadcast and standard_vma_rule and I'll add those in the follow up CLs to other primitives

PiperOrigin-RevId: 739225392
2025-03-21 10:26:18 -07:00
Yash Katariya
40ce44d143 Add ShardingTypeError to all sharding rules in JAX
PiperOrigin-RevId: 739205830
2025-03-21 09:26:08 -07:00
Yash Katariya
c7d6b653ce [sharding_in_types] Add core.ShardingTypeError as a new Exception that are sharding-in-types specific errors should raise.
This is so that we can catch this exception in backward_pass/vmap and add extra message to inform users that this is a potential JAX bug. They should file an issue on the repo.

Currently, we only raise `ShardingTypeError` in one place, but we can expand to all other places in follow up changes. This change sets the machinery up.

Previous error:

```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
```

New error:

```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
This is a potential JAX bug. Please file an issue at https://github.com/jax-ml/jax/issues
```

The new added message of `This is a potential JAX bug...` is important because this error is raised in the backward pass which is 100% a JAX bug given that forward pass did not error.

PiperOrigin-RevId: 739053305
2025-03-20 22:19:08 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Dan Foreman-Mackey
a981e1c4b9 Start adding primitive registration helper functions to lax.linalg.
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.

PiperOrigin-RevId: 729471970
2025-02-21 04:05:34 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If the axes in shmap(..., auto=...) is an explicit axes in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Yash Katariya
229aa65a3e Split NamedSharding into a separate file called named_sharding.py so that we can import it in core.py and break the cyclic dependency.
PiperOrigin-RevId: 726566863
2025-02-13 11:22:54 -08:00
Yash Katariya
2062e986a6 Fix the error message to say out_sharding instead of sharding in lax.reshape sharding rule
PiperOrigin-RevId: 726484167
2025-02-13 07:49:54 -08:00
Yash Katariya
d58c3a4722 [sharding_in_types] Fix some properties that assumed axis_types always existed.
PiperOrigin-RevId: 726187278
2025-02-12 13:57:19 -08:00
Yash Katariya
2d01df760b [sharding_in_types] Make the typing checks and sharding rule checks a little bit less strict when the current or aval mesh is empty/unset. Also some more changes as listed below:
* get_aval is not context dependent

* canonicalization does not happen for avals on an empty mesh

* jax.jit does not set abstract mesh context anymore before tracing

* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode

* Even if use_mesh is not used in explicit sharding mode, computation follows data works!

* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)

* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.

As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.

PiperOrigin-RevId: 726097292
2025-02-12 10:03:01 -08:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
Yash Katariya
9107ee4a22 Do automatic casting from auto -> manual when the context mesh is manual and avals are in auto mode. This happens when values are being closed over in a shard_map. The casting is happening at lax level but we can move this to a different place later on.
PiperOrigin-RevId: 721495804
2025-01-30 13:14:04 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can atleast upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Yash Katariya
d28c3fa409 Replace Hidden/Visible/Collective AxisTypes names with Auto/Explicit/Manual.
PiperOrigin-RevId: 719561729
2025-01-24 23:21:13 -08:00
Yash Katariya
704b2e5fba [sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
2025-01-22 16:48:32 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by registering the sharding on the aval properly created by SDS in it's pytype_aval_mapping.
Also If we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Yash Katariya
af63e443ef [sharding_in_types] Check out_avals with mesh context too. This is because users can pass their own shardings to functions like einsum, reshape, broadcast`, etc
PiperOrigin-RevId: 707672801
2024-12-18 14:42:40 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
Yash Katariya
21f8885a9e [sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding reduce_p sharding rule
PiperOrigin-RevId: 699244204
2024-11-22 12:00:22 -08:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Yash Katariya
18bc354305 [sharding_in_types] Add dot_general sharding rule. We only handle the simple cases and rely on xla to insert the collectives.
Cases where we error

* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* if batch dimension and tensor dimension sharding match -> Error

PiperOrigin-RevId: 684983567
2024-10-11 16:05:13 -07:00
Peter Hawkins
46f0a3eee7 Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
Change in preparation for removing HLO ops from the XLA Python bindings.

In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.

PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
b91a97d163 Avoid a couple more uses of operator.attrgetter(), since it's unnecessarily slow compared to writing a function directly.
PiperOrigin-RevId: 609749059
2024-02-23 09:20:56 -08:00
Peter Hawkins
bab8ae81a3 Micro optimization to standard_abstract_eval:
Since array_abstraction_level is a class property, it is also present on instances. We can avoid forming map(type, avals) and instead simply take the type(...) of the result. It's also shorter this way.

PiperOrigin-RevId: 606629740
2024-02-13 08:26:35 -08:00
jax authors
af42359433 Merge pull request #16419 from mattjj:pow-jvp
PiperOrigin-RevId: 559266945
2023-08-22 17:15:04 -07:00
Peter Hawkins
47651c6a59 Remove uses of XLA translation rules.
Remove translation_rule argument to standard_primitive.

PiperOrigin-RevId: 557220350
2023-08-15 12:53:36 -07:00
Matthew Johnson
69ad4df9a5 fix pow_p jvp rule at x=0. y=0
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.

Co-authored-by: Roy Frostig <frostig@google.com>
2023-07-28 17:14:47 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Jake VanderPlas
45836824ce KeyArray: improve errors for unimplemented primitives 2023-04-24 16:54:25 -07:00
Peter Hawkins
31eeaed913 Split mlir.py and xla.py into separate Bazel targets.
PiperOrigin-RevId: 520737811
2023-03-30 14:06:16 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
c4ba450867 [MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings that rely on standard_primitive registering a translation rule.
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.

PiperOrigin-RevId: 442222533
2022-04-16 07:01:19 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Matthew Johnson
4db899007b add staging logic for polymorphic shapes in jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-01-05 14:11:12 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Peter Hawkins
83d8c6c238 Split slice/update_slice/gather/scatter out of jax._src.lax.lax into jax._src.lax.slicing.
To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.

PiperOrigin-RevId: 411921794
2021-11-23 16:35:18 -08:00