Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.
Extend the custom_partitioner tests in pjit_test.py to cover the Shardy sharding rule.
PiperOrigin-RevId: 713399604
This is necessary to avoid a circular dependency
jax -> fused_attention_stablehlo -> experimental -> jax
in google/jax#21371.
PiperOrigin-RevId: 650201550
```
jax/_src/sharding_impls.py:570: error: Unused "type: ignore" comment [unused-ignore]
jax/_src/sharding_impls.py:589: error: Unused "type: ignore" comment [unused-ignore]
jax/_src/sharding_impls.py:903: error: Unused "type: ignore" comment [unused-ignore]
```
Also add a # type: ignore to suppress an error caused by an incorrect type
stub. The stub is already fixed in jaxlib, but the fix is not released yet.
- This extension has one C API which registers a custom partitioner with callbacks from the input.
- Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input.
- Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered.
PiperOrigin-RevId: 620357554
This PR is a follow-up to #18881.
The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
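
To illustrate the kind of rewrite this produces, here is a minimal sketch. It is not taken from the actual diff, and the function `first_or_default` is hypothetical. With `from __future__ import annotations` in place, annotations are stored as strings rather than evaluated at definition time, so the `list[int]` and `int | None` spellings are accepted on Python 3.9.

```python
from __future__ import annotations

# Before the rewrite, the signature would have been spelled with typing
# aliases, for example:
#   def first_or_default(xs: List[int],
#                        default: Optional[int] = None) -> Optional[int]:


def first_or_default(xs: list[int], default: int | None = None) -> int | None:
  """Return the first element of xs, or default if xs is empty."""
  return xs[0] if xs else default
```

Because the annotations are never evaluated, `int | None` does not raise a TypeError on interpreters older than 3.10.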
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.
PiperOrigin-RevId: 576575376
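
For the ModuleContext.platforms change above, the following sketch shows what a lowering rule might look like once it can no longer assume a single platform. The `ModuleContext` class here is a hypothetical stand-in that models only the `platforms` field, and `lower_op` with its implementation table is an illustrative assumption, not JAX's actual lowering code.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class ModuleContext:
  # Hypothetical stand-in: only the `platforms` field is modeled here.
  platforms: Sequence[str]


def lower_op(ctx: ModuleContext) -> dict[str, str]:
  """Hypothetical lowering rule: pick one implementation per target platform."""
  impls = {
      "cpu": "generic_impl",
      "tpu": "generic_impl",
      "cuda": "gpu_impl",
      "rocm": "gpu_impl",
  }
  # Check every requested platform instead of assuming there is exactly one.
  missing = [p for p in ctx.platforms if p not in impls]
  if missing:
    raise NotImplementedError(f"no lowering for platforms: {missing}")
  return {p: impls[p] for p in ctx.platforms}
```

A rule written this way behaves the same for single-platform lowering, where `platforms` has one element, and for multi-platform lowering.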
Because this is best effort, users writing code to handle GSPMDSharding
should be able to deal only with the GSPMDSharding type.
PiperOrigin-RevId: 534612265
MHLO-to-HLO conversion now knows how to introduce tuples to custom calls if needed, so we can remove explicit tupling from JAX.
PiperOrigin-RevId: 528485268
Add a test to demonstrate how to force XLA to choose
a different sharding.
It is also possible to return the wrong shape from a partition function.
We should raise an error in this case.
PiperOrigin-RevId: 525606690
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking.
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.
PiperOrigin-RevId: 516631428
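
The first bullet of the change above (making the decorators no-ops for type checking) can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual JAX implementation: the runtime branch simply substitutes the given class, and `FastPoint` and `Point` are hypothetical names.

```python
from typing import TYPE_CHECKING, Callable, TypeVar

T = TypeVar("T")

if TYPE_CHECKING:
  # Type checkers see an identity decorator, so the decorated class keeps
  # the type declared in Python.
  def use_cpp_class(cpp_cls: type) -> Callable[[T], T]:
    def wrapper(cls: T) -> T:
      return cls
    return wrapper
else:
  # Simplified runtime stand-in: swap in the given class when there is one.
  def use_cpp_class(cpp_cls):
    def wrapper(cls):
      return cls if cpp_cls is None else cpp_cls
    return wrapper


class FastPoint:
  """Hypothetical stand-in for a class implemented in C++."""


@use_cpp_class(FastPoint)
class Point:
  """Python declaration that type checkers analyze."""
```

At runtime `Point` is bound to `FastPoint`, while a type checker only ever analyzes the Python declaration.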
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
Arguments specified by static_argnums cannot contain
any JAX tracers because these arguments are passed to the XLA compiler,
where the lowering information for the tracers has already been lost.