58 Commits

Author SHA1 Message Date
Bixia Zheng
c4ac0dd6bd Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.

Extend custom_partitioner tests in  pjit_test.py for Shardy sharding rule.

PiperOrigin-RevId: 713399604
2025-01-08 13:34:47 -08:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Sergei Lebedev
740945a724 Moved the implementation of `custom_partitioning` into jax/_src
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in google/jax#21371.

PiperOrigin-RevId: 650201550
2024-07-08 04:31:44 -07:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
jax authors
47420a3825 Merge pull request #20884 from superbobry:main
PiperOrigin-RevId: 636176070
2024-05-22 08:29:45 -07:00
Peter Hawkins
a9a567523f Fix mypy errors:
```
jax/_src/sharding_impls.py:570: error: Unused "type: ignore" comment  [unused-ignore]
jax/_src/sharding_impls.py:589: error: Unused "type: ignore" comment  [unused-ignore]
jax/_src/sharding_impls.py:903: error: Unused "type: ignore" comment  [unused-ignore]
```

Also add a # type: ignore to suppress an incorrect type stub already
fixed in jaxlib but not released yet.
2024-05-22 14:31:52 +00:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Parker Schuh
e652d62b85 Cleanup second registration of custom_partitioning callbacks now that
the jaxlib version has been bumped.

PiperOrigin-RevId: 631852273
2024-05-08 10:45:39 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Sergei Lebedev
602d4bd4d2 Propagate the available axes to the partitioning function
See #20864 for more context and the added test for a reproducer.
2024-04-23 14:24:28 +01:00
Jieying Luo
68c674d106 [PJRT C API] Add a PJRT extension to register custom partitioner.
- This extension has one C API which registers a custom partitioner with callbacks from the input.
- Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input.
- Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered.

PiperOrigin-RevId: 620357554
2024-03-29 15:40:26 -07:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Parker Schuh
899765edd0 Return mlir modules instead of XlaComputation from custom_partitioning.
This will help with exporting this call to the c-api.

PiperOrigin-RevId: 599921028
2024-01-19 13:23:42 -08:00
Yash Katariya
4c9241ecda Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache.
PiperOrigin-RevId: 593023314
2023-12-21 22:15:52 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Yash Katariya
10f6a35f83 Add a registry for primitives that require device_assignment during lowering
PiperOrigin-RevId: 589272990
2023-12-08 16:31:41 -08:00
Yash Katariya
5fb8ceca73 Make lowering oblivious to real physical devices. Instead cache lowering on HloSharding only (which is based on logical device numbers)
Make an exception for callbacks and custom_partitioning because they need access to device_assignment during lowering.

PiperOrigin-RevId: 589244695
2023-12-08 14:36:09 -08:00
Chase Riley Roberts
abf2c5cc3e Copybara import of the project:
--
118c9e18c72757b4497b035a8628125e63feb435 by Thenerdstation <chaserileyroberts@gmail.com>:

Added bytes usage

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/18349 from chaserileyroberts:chase/forward_fix/custom_part_bytes 118c9e18c72757b4497b035a8628125e63feb435
PiperOrigin-RevId: 580000167
2023-11-06 17:01:02 -08:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
Parker Schuh
03575c4b33 Pad generated sharding specs with None up to ndims to simplify comparing dims
across different partitioned arguments.

PiperOrigin-RevId: 555712119
2023-08-10 17:02:31 -07:00
Parker Schuh
74bcd65bbd Make mesh available to custom_partitioning lowering rules.
PiperOrigin-RevId: 555319896
2023-08-09 17:08:57 -07:00
Yash Katariya
4d698c30b9 Return PositionalSharding instead of GSPMDSharding in custom_partitioning when mesh is not defined
PiperOrigin-RevId: 539719517
2023-06-12 11:52:28 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Parker Schuh
5c2070c204 custom_parititioning: in lower sharding, Sharding should be XLACompatibleSharding.
PiperOrigin-RevId: 537077304
2023-06-01 11:16:08 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Parker Schuh
016eae4141 Allow disabling the parsing of GSPMDSharding -> NamedSharding.
Because this is best effort, users writing code to handle GPSMDSharding
should be able to deal only with the GSPMDSharding type.

PiperOrigin-RevId: 534612265
2023-05-23 17:16:56 -07:00
Parker Schuh
56ca8af9bb Make custom_partitioning support multiple return values.
PiperOrigin-RevId: 533584581
2023-05-19 16:58:54 -07:00
Parker Schuh
08169291a4 Simplify custom_partitioning to use jax.ShapeDtypeStruct instead of passing separate
arguments for shape and sharding.

PiperOrigin-RevId: 533257532
2023-05-18 14:48:07 -07:00
Peter Hawkins
ba11b9dcba Remove tupling of custom call results.
MHLO-to-HLO conversion now knows how to introduce tuples to custom calls if needed, so we can remove explicit tupling from JAX.

PiperOrigin-RevId: 528485268
2023-05-01 09:02:14 -07:00
Parker Schuh
87c328864b Improve testing for custom_partitioning.
Add a test to demonstrate how to force XLA to choose
a different sharding.

Also it is possible to return the wrong
shape from a partition function. We should error in this case.

PiperOrigin-RevId: 525606690
2023-04-19 18:26:51 -07:00
Yash Katariya
393e5931d1 Move parse_flatten_op_sharding to sharding_impls.py to remove local import of pjit using that function from pxla.py
PiperOrigin-RevId: 523573375
2023-04-11 19:26:25 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Peter Hawkins
a0121d9b9b Improve pytype inference for Sharding type.
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking.
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.

PiperOrigin-RevId: 516631428
2023-03-14 14:20:17 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Matthew Johnson
b05975b964 add result info to mhlo, fixes #14780
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Peter Hawkins
0af9fff5ca Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Parker Schuh
7526d0ea1f Add static_argnums to custom_partitioning.
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00
Leopold Cambier
df89c77b06 Fix trailing whiteshape + failing doc test + removing first section title 2023-01-24 15:55:17 -08:00
Leopold Cambier
3c3a2eea50 Removing HLO dump from docstring, using assert(re.search(...)) 2023-01-23 14:38:34 -08:00