6 Commits

Author SHA1 Message Date
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Peter Hawkins
2b9ad0d93e Move contents of jax.experimental.global_device_array to jax._src.global_device_array.
Make jax.experimental.global_device_array a shim around jax._src.global_device_array.

Change in preparation for deprecating global device arrays.

PiperOrigin-RevId: 510261140
2023-02-16 15:37:10 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Sharad Vikram
442aa028c2 Fix xmap staging rule to handle positional semantics
PiperOrigin-RevId: 509356614
2023-02-13 16:05:17 -08:00
Peter Hawkins
4a523e3d74 Minimize exported names from jax.experimental.maps.
Move implementation of maps to jax._src.maps.

PiperOrigin-RevId: 509309092
2023-02-13 12:57:54 -08:00