13 Commits

Author SHA1 Message Date
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
0d834c0c00 Use the standard jtu.create_global_mesh instead of creating a mesh from scratch.
PiperOrigin-RevId: 511648529
2023-02-22 18:11:48 -08:00
Matthew Johnson
1ddb3f6a92 [shard-map] add annotations and notes to shard_map_test.py 2023-02-17 10:54:29 -08:00
Matthew Johnson
ab881cb720 [shard-map] add systematic tests for eager, jit, autodiff 2023-02-16 20:40:09 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
Matthew Johnson
6fb3ace5d0 [shard-map] add vmap spmd_axis_name support, fix vmap rule bug 2023-02-08 23:54:28 -08:00
Matthew Johnson
1a03f34383 [shard-map] if check_rep=False, don't call rep rules in eager 2023-02-08 15:42:35 -08:00
Matthew Johnson
58d3f552d7 [shard-map] add remat support, very basic test 2023-02-08 11:15:38 -08:00
Matthew Johnson
6db3f48656 [shard_map] add rep rule for axis_index, trivial test 2023-02-06 16:59:22 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00