# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict, namedtuple
import contextlib
import re
from functools import partial
import logging
import math
import textwrap
import threading
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import concurrent.futures

import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import config
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
from jax import lax
from jax._src.lax import lax as lax_internal
from jax.lax import with_sharding_constraint
from jax._src import prng
from jax.sharding import PartitionSpec as P, Mesh
from jax.experimental import multihost_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding, common_devices_indices_map
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src.sharding_impls import (
    AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
    SingleDeviceSharding, parse_flatten_op_sharding)
import jax._src.pjit as pjit_lib
from jax._src.pjit import pjit
from jax._src import mesh as mesh_lib
from jax._src.interpreters import pxla
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension
from jax._src.util import curry, unzip2

config.parse_flags_with_absl()

# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()

def setUpModule():
  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
  _exit_stack.close()

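# Note: setUpModule above forces the CPU platform to expose 8 devices (via
# XLA's host-platform device-count override) so the multi-device sharding
# tests in this file can run on a single host.
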
def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
                 dtype=np.float32):
  if global_data is None:
    global_data = np.arange(
        math.prod(global_shape), dtype=dtype).reshape(global_shape)

  if isinstance(mesh_axes, Sharding):
    sharding = mesh_axes
  else:
    sharding = NamedSharding(global_mesh, mesh_axes)

  return array.make_array_from_callback(
      global_shape, sharding, lambda idx: global_data[idx]), global_data

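# A minimal usage sketch of the create_array helper above (added for
# illustration; the shapes and axis names here are assumptions, not taken
# from the original tests): build a NamedSharding-backed jax.Array together
# with the NumPy data that backs it.
def _example_create_array_usage():
  mesh = jtu.create_mesh((2, 1), ('x', 'y'))
  arr, global_data = create_array((8, 2), mesh, P('x', 'y'))
  assert arr.shape == global_data.shape
  return arr
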
def _check_instance(self, x):
  self.assertIsInstance(x, array.ArrayImpl)

@curry
def check_1d_2d_mesh(f, set_mesh):
  return parameterized.named_parameters(
      {"testcase_name": "_" + name, "mesh": mesh, "resources": resources}
      for name, mesh, resources in (
          ("2", (("x", 2),), "x"),
          ("2x1", (("x", 2), ("y", 1)), ("x", "y")),
          ("2x2", (("x", 2), ("y", 2)), ("x", "y")),
      ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)

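# check_1d_2d_mesh is used further down (e.g. by testAutodiff): decorating a
# test method expands it into one named case per mesh configuration above,
# passing `mesh` and `resources` as keyword arguments, and set_mesh=True
# additionally wraps the method with jtu.with_mesh_from_kwargs so the mesh is
# active while the case runs.
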
# TODO(skye): make the buffer donation utils part of JaxTestCase
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitTest(jtu.BufferDonationTestCase):

  @jtu.with_mesh([('x', 1)])
  def testDeviceBufferAval(self):

    @partial(pjit, in_shardings=None, out_shardings=P('x'))
    def f(x):
      return x

    shape = (2, 2)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x)
    expected = x
    self.assertAllClose(actual, expected, check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 1)
    self.assertAllClose(
        np.asarray(actual.addressable_shards[0].data), expected,
        check_dtypes=False)
    # Repro for a bug on addressable_shards aval
    _ = repr(actual.addressable_shards)

  @jtu.with_mesh([('x', 2)])
  def testBasic1D(self):
    @partial(pjit,
             in_shardings=(P('x'), P('x')),
             out_shardings=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 2)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected,
                        check_dtypes=False)

  @jtu.with_mesh([('x', 2)])
  def testJitOfPjitDisallowed(self):
    @partial(pjit,
             in_shardings=(P('x'), P('x')),
             out_shardings=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    out = jax.jit(f)(x, x + 1)
    self.assertArraysEqual(out, x + x + 1)

  @jtu.with_mesh([('x', 2)])
  def testUnevenShardingConstraint(self):
    @partial(pjit,
             in_shardings=(P('x'), P('x')),
             out_shardings=None)
    def f(x, y):
      x = x[:3]
      y = y[:3]
      x = with_sharding_constraint(x, P('x'))
      y = with_sharding_constraint(y, P('x'))
      out = x + y
      return jnp.pad(out, [[0, 1]])

    shape = (4,)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 2)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data)[:3],
                        expected[:3], check_dtypes=False)

  def testBasic1DWithMeshContextManager(self):
    @partial(pjit,
             in_shardings=(P('x'), P('x')),
             out_shardings=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    with jtu.create_mesh((2,), ('x')) as mesh:
      actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertEqual(mesh, jtu.create_mesh((2,), ('x')))
    self.assertAllClose(actual, expected, check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 2)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected,
                        check_dtypes=False)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testBasic2D(self):
    @partial(pjit,
             in_shardings=(P(None, 'x', 'y'), P('y')),
             out_shardings=P('x'))
    def f(x, y):
      return x @ y

    x_shape = (8, 6, 4)
    y_shape = (4, 2)
    x = jnp.arange(math.prod(x_shape)).reshape(x_shape)
    y = jnp.arange(math.prod(y_shape)).reshape(y_shape)
    actual = f(x, y)
    expected = x @ y
    self.assertAllClose(actual, expected, check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 4)

    split0, split1 = np.split(expected, 2)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data), split0,
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[1].data), split0,
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[2].data), split1,
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[3].data), split1,
                        check_dtypes=False)

  def testDifferentNestedMesh(self):
    with jtu.create_mesh((2, 1), ("x", "y")) as m1:
      with jtu.create_mesh((2, 2), ("a", "b")) as m2:
        self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m2)
      self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m1)
    self.assertEqual(mesh_lib.thread_resources.env.physical_mesh,
                     mesh_lib.EMPTY_ENV.physical_mesh)

  def testSameNestedMesh(self):
    mesh = jtu.create_mesh((2, 1), ("a", "b"))
    thread_resources = mesh_lib.thread_resources
    with mesh as m1:
      with mesh as m2:
        self.assertEqual(thread_resources.env.physical_mesh, m2)
      self.assertEqual(thread_resources.env.physical_mesh, m1)
    self.assertEqual(thread_resources.env.physical_mesh,
                     mesh_lib.EMPTY_ENV.physical_mesh)

  def testMeshDecorator(self):
    x = jnp.arange(8)
    mesh_shape = (2, 2)
    size = math.prod(mesh_shape)
    if len(jax.devices()) < size:
      raise unittest.SkipTest(f"Test requires {size} global devices.")
    mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)

    @jax.sharding.Mesh(mesh_devices, ('x', 'y'))
    def dec():
      return pjit(lambda x: x, in_shardings=P('x'), out_shardings=None)(x)
    out = dec()
    self.assertArraysEqual(out, x)

  def testMeshHashRace(self):
    mesh = jtu.create_mesh((2, 1), ('a', 'testMeshHashRace'))
    self.assertFalse(hasattr(mesh, '_hash'))
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
      fs = []
      for _ in range(5):
        fs.append(pool.submit(lambda: hash(mesh)))
      for f in concurrent.futures.as_completed(fs):
        f.result()
    self.assertTrue(hasattr(mesh, '_hash'))

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testTwoMeshAxisSharding(self):
    @partial(pjit,
             in_shardings=P(('x', 'y'),),
             out_shardings=jax.sharding.PartitionSpec(('x', 'y'),))
    def f(x, y):
      return x @ y

    shape = (8, 8)
    x = jnp.arange(math.prod(shape)).reshape(shape)
    actual = f(x, x + 1)
    expected = x @ (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 4)

    splits = np.split(expected, 4)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data), splits[0],
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[1].data), splits[1],
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[2].data), splits[2],
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[3].data), splits[3],
                        check_dtypes=False)

  @jtu.with_mesh([('x', 2)])
  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  def testBufferDonation(self):
    @partial(pjit, in_shardings=P('x'), out_shardings=P('x'), donate_argnums=0)
    def f(x, y):
      return x + y

    shard = pjit(lambda x: x, in_shardings=P('x'), out_shardings=P('x'))
    x = shard(jnp.ones((2, 5)) * 4)
    y = shard(jnp.ones((2, 5)) * 2)
    expected = x + y
    self.assertAllClose(f(x, y), expected)
    self.assertNotDeleted(y)
    self.assertDeleted(x)

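  # The donation tests below exercise name-based donation (donate_argnames),
  # in contrast to the positional donate_argnums used in testBufferDonation
  # above; both mark the corresponding input buffers as donatable to the
  # computation.
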
  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  def testBufferDonationWithNames(self):
    mesh = jtu.create_mesh((2,), ('x'))
    s = NamedSharding(mesh, P('x'))

    @partial(pjit, out_shardings=s, donate_argnames='inp2')
    def f(inp1, inp2):
      return inp1 + inp2

    x = jax.device_put(np.ones((2, 5)) * 4, s)
    y = jax.device_put(np.ones((2, 5)) * 2, s)
    expected = x + y
    self.assertAllClose(f(x, y), expected)
    self.assertNotDeleted(x)
    self.assertDeleted(y)

  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  def testBufferDonationWithKwargs(self):
    mesh = jtu.create_mesh((2,), ('x'))
    s = NamedSharding(mesh, P('x'))

    @partial(pjit, out_shardings=s, donate_argnames=('inp2', 'inp3'))
    def f(inp1, inp2, inp3):
      return inp1 + inp2 + inp3, inp3

    x = jax.device_put(np.ones((2, 5)) * 4, s)
    y = jax.device_put(np.ones((2, 5)) * 2, s)
    z = jax.device_put(np.ones((2, 5)), s)

    expected = x + y + z
    self.assertAllClose(f(x, inp2=y, inp3=z)[0], expected)
    self.assertNotDeleted(x)
    self.assertDeleted(y)
    self.assertDeleted(z)

  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  def testBufferDonationWithPyTreeKwargs(self):
    mesh = jtu.create_mesh((2,), ('x'))
    s = NamedSharding(mesh, P('x'))

    @partial(pjit, out_shardings=s, donate_argnames='inp2')
    def f(inp1, inp2, inp3):
      return jax.tree.map(lambda x, y, z: x + y + z, inp1, inp2, inp3)

    x = np.ones((2, 5)) * 4
    x_tree = jax.device_put({"a": {"b": x}, "c": x}, s)

    y = np.ones((2, 5)) * 2
    y_tree = jax.device_put({"a": {"b": y}, "c": y}, s)

    z = np.ones((2, 5))
    z_tree = jax.device_put({"a": {"b": z}, "c": z}, s)

    expected = x + y + z
    out = f(x_tree, inp2=y_tree, inp3=z_tree)
    jax.tree.map(lambda o: self.assertAllClose(o, expected), out)
    jax.tree.map(self.assertNotDeleted, x_tree)
    jax.tree.map(self.assertDeleted, y_tree)
    jax.tree.map(self.assertNotDeleted, z_tree)

  @jtu.run_on_devices('tpu', 'cpu', 'gpu')
  def testBufferDonationWithOutputShardingInference(self):
    mesh = jtu.create_mesh((2,), 'x')
    s = NamedSharding(mesh, P('x'))
    rs = NamedSharding(mesh, P())

    @partial(pjit, donate_argnames=('inp2', 'inp3'))
    def f(inp1, inp2, inp3):
      return (
          jax.lax.with_sharding_constraint(inp1, rs),
          inp1,
          jax.lax.with_sharding_constraint(inp2, rs),
          inp2,
          jax.lax.with_sharding_constraint(inp3, rs),
          inp3,
      )

    x = np.ones((2, 5)) * 4
    x_tree = jax.device_put({'a': {'b': x}, 'c': x}, s)

    y = np.ones((2, 7)) * 2
    y_tree = jax.device_put({'a': {'b': y}, 'c': y}, s)

    z = np.ones((2, 11))
    z_tree = jax.device_put({'a': {'b': z}, 'c': z}, s)

    out = f(x_tree, y_tree, z_tree)
    jax.tree.map(self.assertNotDeleted, x_tree)
    jax.tree.map(self.assertDeleted, y_tree)
    jax.tree.map(self.assertDeleted, z_tree)

  @jtu.run_on_devices('tpu')
  def testBufferDonationWithOutputShardingInferenceAndTokens(self):
    if config.use_shardy_partitioner.value:
      self.skipTest('b/355263220: Shardy does not support callbacks yet.')
    mesh = jtu.create_mesh((2,), 'x')
    s = NamedSharding(mesh, P('x'))

    def _callback(x):
      self.assertIsInstance(x, jax.Array)

    @partial(pjit, donate_argnames=('x'))
    def f(x):
      # Just to get tokens.
      jax.experimental.io_callback(_callback, None, x, ordered=True)
      jax.experimental.io_callback(_callback, None, x, ordered=True)
      return x * x

    x = np.ones((2, 5)) * 4
    x = jax.device_put(x, s)
    f(x)
    jax.effects_barrier()
    self.assertDeleted(x)

  @jtu.run_on_devices('tpu', 'cpu', 'gpu')
  def testBufferDonationNotDonated(self):
    mesh = jtu.create_mesh((2,), 'x')
    s = NamedSharding(mesh, P('x'))

    @partial(pjit, donate_argnames=('x'))
    def f(x):
      return x @ x.T

    x = jax.device_put(np.arange(16).reshape(8, 2), s)
    f(x)
    self.assertNotDeleted(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testShardingConstraintStablehlo(self):
    @partial(pjit, in_shardings=None, out_shardings=None)
    def f(x):
      y = x + 1
      y = with_sharding_constraint(y, P('x', 'y'))
      return y * 2

    shape = (8, 8)
    x = np.arange(math.prod(shape)).reshape(shape)
    expected = (x + 1) * 2
    actual = f(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    _check_instance(self, actual)
    self.assertLen(actual.addressable_shards, 2)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected,
                        check_dtypes=False)

    hlo = f.lower(np.ones(shape)).compiler_ir()
    if config.use_shardy_partitioner.value:
      # Annotation from with_sharding_constraint
      self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo))
      # Annotation from pjit
      self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo))
    else:
      # Annotation from with_sharding_constraint
      self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo))
      # Annotation from pjit
      self.assertIn('sharding = "{replicated}"', str(hlo))

  def testShardingConstraintWithArray(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    s = NamedSharding(mesh, P(None))

    @partial(pjit, in_shardings=s, out_shardings=s)
    def f(x):
      y = x + 1
      y = with_sharding_constraint(y, NamedSharding(mesh, P('x', 'y')))
      return y * 2

    shape = (8, 8)
    x = np.arange(math.prod(shape)).reshape(shape)
    expected = (x + 1) * 2
    actual = f(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, array.ArrayImpl)
    self.assertLen(actual.addressable_shards, 2)
    self.assertAllClose(actual, expected, check_dtypes=False)

    hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
    # Annotation from with_sharding_constraint
    self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
    # Annotation from pjit
    self.assertIn("sharding={replicated}", hlo.as_hlo_text())

  def testShardingConstraintWithArrayOpSharding(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("Shardy doesn't support PositionalSharding")
    shape = (8, 8)
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    s = NamedSharding(mesh, P(None))
    ops = pxla.to_gspmd_sharding(
        NamedSharding(mesh, P('x', 'y')), len(shape))

    @partial(pjit, in_shardings=s, out_shardings=s)
    def f(x):
      y = x + 1
      y = with_sharding_constraint(y, ops)
      return y * 2

    x = np.arange(math.prod(shape)).reshape(shape)
    expected = (x + 1) * 2
    actual = f(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, array.ArrayImpl)
    self.assertLen(actual.addressable_shards, 2)
    self.assertAllClose(actual, expected, check_dtypes=False)

    hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
    # Annotation from with_sharding_constraint
    self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
    # Annotation from pjit
    self.assertIn("sharding={replicated}", hlo.as_hlo_text())

  def testShardingConstraintPyTreeWithArray(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))

    @jax.jit
    def f(x):
      return with_sharding_constraint(x, NamedSharding(mesh, P('x', 'y')))

    shape = (8, 8)
    v = np.arange(math.prod(shape)).reshape(shape)
    x = [v, v * 2]
    out = f(x)

    self.assertArraysEqual(out[0], v)
    self.assertArraysEqual(out[1], v * 2)
    self.assertLen(out[0].addressable_shards, 2)
    self.assertLen(out[1].addressable_shards, 2)

    hlo = f.lower(x).compiler_ir(dialect="hlo")
    # Annotations from with_sharding_constraint
    self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
    self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())

  def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self):

    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    @jax.jit
    def f(x):
      x = with_sharding_constraint(
          x, [NamedSharding(mesh, P(P.UNCONSTRAINED, 'y', None)),
              NamedSharding(mesh, P('x', P.UNCONSTRAINED, None))])
      x = x.copy()
      x[0]['a'] *= 2
      return x

    shape = (2, 8, 8)
    v = np.arange(math.prod(shape)).reshape(shape)
    x = [{'a': v, 'b': v * 2}, v * 3]
    actual = f(x)

    expected = x.copy()
    expected[0]['a'] *= 2
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual[0]['a'].addressable_shards, 4)

    mlir_str = str(f.lower(x).compiler_ir())
    if config.use_shardy_partitioner.value:
      self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str)
      self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str)
    else:
      self.assertIn("unspecified_dims=[0]", mlir_str)
      self.assertIn("unspecified_dims=[1]", mlir_str)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self):

    @partial(pjit, in_shardings=None, out_shardings=None)
    def f(x):
      x = jax.vmap(lambda x: with_sharding_constraint(
          x, [P(P.UNCONSTRAINED, 'y'),
              P('x', P.UNCONSTRAINED)]))(x)
      x = x.copy()
      x[0]['a'] *= 2
      return x

    shape = (2, 8, 8)
    v = np.arange(math.prod(shape)).reshape(shape)
    x = [{'a': v, 'b': v * 2}, v * 3]

    mlir_str = str(f.lower(x).compiler_ir())
    if config.use_shardy_partitioner.value:
      self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str)
      self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str)
    else:
      self.assertIn("unspecified_dims=[0,1]", mlir_str)
      self.assertIn("unspecified_dims=[0,2]", mlir_str)

  def testCaching(self):
    def f(x):
      assert should_be_tracing
      return jnp.sin(x) * 2

    x = np.arange(16).reshape(4, 4)
    devices = np.array(list(jax.local_devices())[:4])
    if devices.size < 4:
      raise unittest.SkipTest("Test requires 4 devices")
    devices = devices.reshape((2, 2))
    with jax.sharding.Mesh(devices, ('x', 'y')):
      should_be_tracing = True
      pjit(f, in_shardings=P(('x', 'y')), out_shardings=None)(x)
      should_be_tracing = False
      pjit(f, in_shardings=P(('x', 'y')), out_shardings=None)(x)
    # Re-create the mesh to make sure that has no influence on caching
    with jax.sharding.Mesh(devices, ('x', 'y')):
      should_be_tracing = False
      pjit(f, in_shardings=P(('x', 'y')), out_shardings=None)(x)

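  # testCaching above relies on pjit's tracing cache keying on the function
  # and the mesh contents rather than on the Mesh object's identity, so
  # re-creating an identical mesh must not trigger a retrace.
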
  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testNested(self):
    # Add a constant captured by the nested pjit to make things more complicated
    h = jnp.arange(4.)
    f = pjit(
        lambda x: x.sum() + h.sum(),
        in_shardings=P('x', 'y'),
        out_shardings=None,
    )
    g = pjit(
        lambda x: f(jnp.sin(x)), in_shardings=P('x', None), out_shardings=None
    )
    x = jnp.arange(16.).reshape((4, 4))
    y = g(x)
    self.assertAllClose(y, jnp.sin(x).sum() + h.sum())
    _check_instance(self, y)

  @check_1d_2d_mesh(set_mesh=True)
  def testAutodiff(self, mesh, resources):
    if len(mesh) != 2: return
    assert resources == ('x', 'y')
    # Add a constant captured by the nested pjit to make things more complicated
    h = jnp.arange(4.)
    f = pjit(
        lambda x: x.sum(1) * h.sum(),
        in_shardings=P('x', 'y'),
        out_shardings=P(('x', 'y')),
    )
    g = pjit(
        lambda x: f(jnp.sin(x * 4 + 2)),
        in_shardings=P('x', None),
        out_shardings=P(('x', 'y')),
    )
    jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testAutodiffCache(self):
    f = pjit(
        lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None
    )
    x = jnp.arange(16, dtype=jnp.float32)
    jax.grad(f)(x)  # Warm up the cache.
    before = pjit_lib._pjit_lower_cached.cache_info()
    jax.grad(f)(x)
    after = pjit_lib._pjit_lower_cached.cache_info()

    # One hit for the forward pass, one hit for backward.
    self.assertEqual(after.hits, before.hits + 2)
    self.assertEqual(after.misses, before.misses)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testEvalJaxpr(self):
    x, y = jnp.arange(4.), jnp.arange(5.)
    f = pjit(
        lambda x, y: x.sum() + jnp.sin(y),
        in_shardings=(P('x'), P('y')),
        out_shardings=P('y'),
    )
    f_jaxpr = jax.make_jaxpr(f)(x, y)
    f_eval = core.jaxpr_as_fun(f_jaxpr)
    r, = f_eval(x, y)
    self.assertAllClose(r, x.sum() + jnp.sin(y))

  @jtu.with_mesh([('x', 2)])
  def testNonArrayArg(self):
    self.assertEqual(
        pjit(lambda x: x + 2, in_shardings=None, out_shardings=None)(1), 3
    )

  @jtu.with_mesh([('x', 2)])
  def testNonHashableAxisResources(self):
    x = jnp.arange(4)
    y = pjit(
        lambda x: {'b': x['a'] + 2},
        in_shardings=({'a': P('x')},),
        out_shardings={'b': P('x')},
    )({'a': x})
    self.assertAllClose(y, {'b': x + 2})

  @jtu.with_mesh([('x', 2)])
  def testGradOfConstraint(self):
    # Make sure that we can compute grads through sharding constraints
    h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum()
    f = pjit(lambda x: jax.grad(h)(x), in_shardings=None, out_shardings=None)
    x = jnp.arange(8, dtype=jnp.float32)
    out = f(x)
    self.assertAllClose(out, jnp.cos(x))
    self.assertLen(out.devices(), 2)

  @jtu.with_mesh([('x', 2)])
  def testNoopPartitionSpecs(self):
    noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
    x = jnp.arange(8).reshape((2, 2, 2))
    for spec in noops:
      y = pjit(lambda x: x * 2, in_shardings=spec, out_shardings=spec)(x)
      self.assertAllClose(y, x * 2)

  @jtu.with_mesh([('x', 2)])
  def testVMap(self):
    f = pjit(lambda x, y: (x + y, x), in_shardings=P('x'), out_shardings=P('x'))
    x = jnp.arange(4)
    y = jnp.arange(5 * 4).reshape((5, 4))
    z, w = jax.vmap(f, in_axes=(None, 0), out_axes=(0, None))(x, y)
    self.assertAllClose(z, x[jnp.newaxis] + y)
    self.assertAllClose(w, x)
    self.assertEqual(
        z.sharding._to_xla_hlo_sharding(z.ndim).tile_assignment_dimensions(),
        [1, 2])
    self.assertEqual(
        w.sharding._to_xla_hlo_sharding(w.ndim).tile_assignment_dimensions(), [2])

  @jtu.with_mesh([('x', 2)])
  def testVMapShardingConstraint(self):
    f = pjit(
        lambda x: with_sharding_constraint(x, P('x')),
        in_shardings=P(),
        out_shardings=P('x'),
    )
    x = jnp.arange(5 * 4).reshape((5, 4))
    jaxpr = jax.make_jaxpr(jax.vmap(f))(x)
    pjit_eqn, = jaxpr.eqns
    constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
    op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim)
    self.assertTrue(op.is_tiled())
    self.assertListEqual(op.tile_assignment_dimensions(), [1, 2])
    self.assertListEqual(op.tile_assignment_devices(), [0, 1])
    self.assertFalse(op_shardings.is_op_sharding_replicated(op))

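  # The next test passes spmd_axis_name to jax.vmap, which asks vmap to shard
  # the mapped (batch) dimension over the named mesh axis; that is why the
  # constraint's tile assignment below is [2, 1], i.e. tiled over dim 0.
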
  @jtu.with_mesh([('x', 2)])
  def testVMapShardingConstraintWithSpmdAxis(self):
    f = pjit(
        jax.vmap(
            lambda x: with_sharding_constraint(x, P(None)),
            spmd_axis_name='x',
        ),
        in_shardings=P('x'),
        out_shardings=P('x'),
    )
    x = jnp.arange(16 * 4).reshape((16, 4))
    jaxpr = jax.make_jaxpr(f)(x)
    pjit_eqn, = jaxpr.eqns
    constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
    op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim)
    self.assertTrue(op.is_tiled())
    self.assertListEqual(op.tile_assignment_dimensions(), [2, 1])
    self.assertListEqual(op.tile_assignment_devices(), [0, 1])
    self.assertFalse(op_shardings.is_op_sharding_replicated(op))

  @jtu.with_mesh([('x', 2)])
  def testLowerWithDuckTyping(self):
    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
    # Make sure this doesn't crash
    pjit(lambda x: x + 4, in_shardings=P('x'), out_shardings=P('x')).lower(x)

  @jtu.with_mesh([('x', 2)])
  def testLowerDonateArgnumsAvailable(self):
    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
    def f(*args):
      x, *_ = args
      return x
    f_low = pjit(f, donate_argnums=(0,),
                 in_shardings=P('x'), out_shardings=P('x')).lower(x)
    f_com = f_low.compile()
    f_low.donate_argnums == f_com.donate_argnums == (0,)

  @jtu.with_mesh([('x', 2)])
  def testLowerDonateArgnumsAvailableWithNames(self):
    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
    def f(inp1):
      return inp1

    f_low = pjit(f, in_shardings=P('x'), out_shardings=P('x'),
                 donate_argnames=('inp1',)).lower(x)
    f_com = f_low.compile()
    f_low.donate_argnums == f_com.donate_argnums == (0,)

  @unittest.skip('Fails in OSS builds on GPU with jax at HEAD and latest '
                 'jaxlib on pypi.')
  def testInfeed(self):
    devices = np.array(jax.local_devices())
    nr_devices = len(devices)
    shape = (nr_devices * 3, nr_devices * 5)

    def f_for_jit(x):
      token = lax.create_token(x)
      (y,), token = lax.infeed(
          token, shape=(core.ShapedArray(x.shape, np.float32),))
      (z,), token = lax.infeed(
          token, shape=(core.ShapedArray(x.shape, np.float32),))
      (w,), token = lax.infeed(
          token, shape=(core.ShapedArray(x.shape, np.float32),))

      return x + y + z + w

    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    y = x * 2.
    z = x * 3.
    w = x * 4.

    # Transfer data to infeed before executing the function. For GPUs, the
    # execution of the compiled function is blocking, so transferring data
    # to infeed before executing ensures that the execution does not deadlock
    # waiting for the infeed data.
    logging.info('Transferring to infeed for the jit call')
    d = devices[0]
    d.transfer_to_infeed((y,))
    d.transfer_to_infeed((z,))
    d.transfer_to_infeed((w,))

    # JIT
    logging.info('Making jit call')
    res0 = jax.jit(f_for_jit)(x)
    self.assertAllClose(res0, x + y + z + w, check_dtypes=True)

    # PJIT
    def f_for_pjit(x):
      token = lax.create_token(x)
      # A replicated infeed
      (y,), token = lax.infeed(
          token,
          shape=(core.ShapedArray(x.shape, np.float32),),
          partitions=(None,))
      # An infeed sharded on first axis
      (z,), token = lax.infeed(
          token,
          shape=(core.ShapedArray(x.shape, np.float32),),
          partitions=(P(nr_devices, 1),))
      # An infeed sharded on second axis
      (w,), token = lax.infeed(
          token,
          shape=(core.ShapedArray(x.shape, np.float32),),
          partitions=(P(1, nr_devices),))
      return x + y + z + w

    logging.info('Transferring to infeed for the pjit call')
    for didx, d in enumerate(devices):
      # Transfer the whole array to all devices for replicated.
      d.transfer_to_infeed((y,))
      # For sharded infeed, transfer only the needed slices to each device.
      d.transfer_to_infeed(z[3 * didx:3 * didx + 3, :])
      d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5],))

    with jax.sharding.Mesh(devices, ['d']):
      logging.info('Making pjit call')
      res = pjit(f_for_pjit, in_shardings=(P('d'),), out_shardings=P('d'))(x)

    self.assertAllClose(res0, res, check_dtypes=True)

def testOutfeed(self):
|
2023-03-23 18:38:26 -07:00
|
|
|
if xla_bridge.using_pjrt_c_api():
|
|
|
|
raise unittest.SkipTest('outfeed not implemented in PJRT C API')
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest(
|
|
|
|
'b/355263220: outfeed lowering not supported by Shardy')
|
2023-03-23 18:38:26 -07:00
|
|
|
|
2021-07-01 11:59:13 -07:00
|
|
|
devices = np.array(jax.local_devices())
|
|
|
|
nr_devices = len(devices)
|
|
|
|
shape = (nr_devices * 3, nr_devices * 5)
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
token = lax.create_token(x)
|
|
|
|
token = lax.outfeed(token, x, partitions=(None,))
|
|
|
|
token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
|
|
|
|
token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
|
|
|
|
return x
|
|
|
|
|
2023-04-13 11:48:11 -07:00
|
|
|
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
2021-07-01 11:59:13 -07:00
|
|
|
|
2023-03-15 20:06:42 -07:00
|
|
|
def _dispatch():
|
2023-02-03 14:28:07 -08:00
|
|
|
with jax.sharding.Mesh(devices, ['d']):
|
2021-07-01 11:59:13 -07:00
|
|
|
logging.info('Making pjit call')
|
2023-02-18 09:59:58 -08:00
|
|
|
pjit(f, in_shardings=(P('d'),), out_shardings=P('d'))(x)
|
2023-03-15 20:06:42 -07:00
|
|
|
execution = threading.Thread(target=_dispatch)
|
2021-07-01 11:59:13 -07:00
|
|
|
execution.start()
|
|
|
|
|
2023-05-03 06:50:48 -07:00
|
|
|
# Check the expected outfeed for all devices.
|
|
|
|
def check_outfeed(x_fn):
|
|
|
|
for didx, d in enumerate(devices):
|
|
|
|
x = x_fn(didx)
|
|
|
|
y, = d.transfer_from_outfeed(
|
|
|
|
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
|
|
|
|
self.assertAllClose(x, y, check_dtypes=True)
|
2021-07-01 11:59:13 -07:00
|
|
|
|
2022-12-16 09:02:28 -08:00
|
|
|
logging.info('Transferring from outfeed for the pjit call')
|
2023-05-03 06:50:48 -07:00
|
|
|
|
|
|
|
# Note, when checking results of multiple outfeeds, the loop structure
|
|
|
|
# should be such that we check a given outfeed for all devices before
|
|
|
|
# moving on to the next outfeed. If there are any collectives generated
|
|
|
|
# by pjit, a loop structutre like:
|
|
|
|
# for each device:
|
|
|
|
# check outfeed#0;
|
|
|
|
# check outfeed#1;
|
|
|
|
#
|
|
|
|
# Could cause a deadlock if there is a collective scheduled between the
|
|
|
|
# 2 outfeeds, as device #0, after processing outfeed#0 will execute the
|
|
|
|
# collective, waiting for other devices to join, but other devices won't
|
|
|
|
# execute their collective until their outfeed#0 is executed. This is
|
|
|
|
# because, for GPU for example, execution of an outfeed on GPU is blocked
|
|
|
|
# till the corresponding `transfer_from_outfeed` is executed on the host.
|
|
|
|
|
|
|
|
# Transfer the whole array from all devices for replicated.
|
|
|
|
check_outfeed(lambda didx: x)
|
|
|
|
# For sharded outfeed, the results are sliced.
|
|
|
|
check_outfeed(lambda didx: x[3 * didx:3 * didx + 3, :])
|
|
|
|
check_outfeed(lambda didx: x[:, 5 * didx:5 * didx + 5])
|
2021-07-01 11:59:13 -07:00
|
|
|
|
|
|
|
execution.join()

  @jtu.with_mesh([('x', 2)])
  def testWithCustomPRNGKey(self):
    if not config.enable_custom_prng.value:
      raise unittest.SkipTest("test requires jax_enable_custom_prng")
    key = prng.random_seed(87, impl=prng.rbg_prng_impl)
    # Make sure this doesn't crash
    pjit(lambda x: x, in_shardings=None, out_shardings=None)(key)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testLowerCompile(self):
    @partial(pjit,
             in_shardings=P(('x', 'y'),),
             out_shardings=P(('x', 'y'),))
    def f(x, y):
      return x @ y

    shape = (8, 8)
    x = jnp.arange(math.prod(shape)).reshape(shape)
    expected = x @ (x + 1)

    lowered = f.lower(x, x + 1)
    compiled = lowered.compile()
    actual = compiled(x, x + 1)

    self.assertEqual(lowered.in_avals, compiled.in_avals)
    self.assertEqual(
        lowered.in_avals,
        ((core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))

    splits = np.split(expected, 4)
    self.assertAllClose(np.asarray(actual.addressable_shards[0].data), splits[0],
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[1].data), splits[1],
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[2].data), splits[2],
                        check_dtypes=False)
    self.assertAllClose(np.asarray(actual.addressable_shards[3].data), splits[3],
                        check_dtypes=False)

    for obj in [lowered, compiled]:
      self.assertFalse(obj._no_kwargs)
      self.assertEqual(obj.in_tree, jax.tree.flatten(((0, 0), {}))[1])
|
2022-03-07 02:36:09 -08:00
|
|
|
|
2021-10-08 21:19:37 -07:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileWithKwargs(self):
|
2023-01-18 17:13:39 -08:00
|
|
|
@pjit
|
2021-10-08 21:19:37 -07:00
|
|
|
def f(x, y, **kwargs):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2023-01-18 17:13:39 -08:00
|
|
|
exe = f.lower(x, x + 1, a=1, b=2).compile()
|
|
|
|
out = exe(x, x + 1, a=1, b=2)
|
|
|
|
self.assertArraysEqual(out, x @ (x + 1))
|
2021-10-08 21:19:37 -07:00
|
|
|
|
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileInTreeMismatch(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2021-10-08 21:19:37 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2021-10-08 21:19:37 -07:00
|
|
|
exe = f.lower(x, x + 1).compile()
|
|
|
|
|
|
|
|
self.assertRaisesRegex(
|
2024-03-04 13:14:47 -08:00
|
|
|
TypeError,
|
|
|
|
'Function compiled with input pytree does not match the input pytree it'
|
|
|
|
' was called with',
|
|
|
|
lambda: exe([x], [x + 1]),
|
|
|
|
)
|
2021-10-08 21:19:37 -07:00
|
|
|
|
2021-10-13 10:45:11 -07:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileArgTypeMismatch(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2021-10-13 10:45:11 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2021-10-13 10:45:11 -07:00
|
|
|
x_f32 = x.astype(jnp.float32)
|
|
|
|
x_i32 = x.astype(jnp.int32)
|
|
|
|
exe = f.lower(x_f32, x_f32).compile()
|
2023-04-19 15:08:21 -07:00
|
|
|
with self.assertRaisesRegex(
|
2021-10-13 10:45:11 -07:00
|
|
|
TypeError,
|
2023-07-10 18:28:50 -07:00
|
|
|
r"Argument types differ .*"
|
|
|
|
r"The mismatches are:\n"
|
|
|
|
r"Argument 'x' compiled with.*float32.*and called with.*int32.*\n"
|
|
|
|
r"Argument 'y' compiled with.*float32.*and called with.*int32.*"):
|
2023-04-19 15:08:21 -07:00
|
|
|
exe(x_i32, x_i32)
|
2021-10-13 10:45:11 -07:00
|
|
|
|
2022-07-01 17:35:17 -07:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerAsText(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-07-01 17:35:17 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-07-01 17:35:17 -07:00
|
|
|
f = f.lower(x, x + 1)
|
|
|
|
self.assertIsInstance(f.as_text(), str)
|
|
|
|
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
2022-12-27 08:52:39 -08:00
|
|
|
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
2022-07-01 17:35:17 -07:00
|
|
|
|
2022-01-13 15:42:17 -08:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompilerIR(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-01-13 15:42:17 -08:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-01-13 15:42:17 -08:00
|
|
|
f = f.lower(x, x + 1)
|
|
|
|
self.assertIsNotNone(f.compiler_ir())
|
|
|
|
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
2022-12-27 08:52:39 -08:00
|
|
|
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
2022-01-13 15:42:17 -08:00
|
|
|
|
2023-04-13 08:55:01 -07:00
|
|
|
@jtu.with_mesh([('x', 2)])
|
|
|
|
def testLowerPartitionsAttribute(self):
|
|
|
|
@partial(pjit,
|
|
|
|
in_shardings=(P('x'), P('x')),
|
|
|
|
out_shardings=None)
|
|
|
|
def f(x, y):
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
|
|
|
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
|
|
|
hlo = f.lower(x, x + 1).as_text("stablehlo")
|
|
|
|
self.assertIn("mhlo.num_replicas = 1", hlo)
|
|
|
|
self.assertIn("mhlo.num_partitions = 2", hlo)
|
|
|
|
|
2022-01-13 15:42:17 -08:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileCompilerIR(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-01-13 15:42:17 -08:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-01-13 15:42:17 -08:00
|
|
|
f = f.lower(x, x + 1).compile()
|
2023-07-11 09:24:08 -07:00
|
|
|
self.assertIsNotNone(f.runtime_executable())
|
2022-01-13 15:42:17 -08:00
|
|
|
|
2022-07-01 17:35:17 -07:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileAsText(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-07-01 17:35:17 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-07-01 17:35:17 -07:00
|
|
|
f = f.lower(x, x + 1).compile()
|
|
|
|
self.assertIsInstance(f.as_text(), (str, type(None)))
|
|
|
|
|
2023-02-06 12:57:30 -08:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCostAnalysis(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2023-02-06 12:57:30 -08:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2023-02-06 12:57:30 -08:00
|
|
|
f = f.lower(x, x + 1)
|
|
|
|
f.cost_analysis() # doesn't raise
|
|
|
|
|
2022-07-01 17:35:17 -07:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileCostAnalysis(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-07-01 17:35:17 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-07-01 17:35:17 -07:00
|
|
|
f = f.lower(x, x + 1).compile()
|
|
|
|
f.cost_analysis() # doesn't raise
|
|
|
|
|
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileMemoryAnalysis(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-07-01 17:35:17 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-07-01 17:35:17 -07:00
|
|
|
f = f.lower(x, x + 1).compile()
|
|
|
|
f.memory_analysis() # doesn't raise
|
|
|
|
|
2022-01-13 15:42:17 -08:00
|
|
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
|
|
|
def testLowerCompileExecutable(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-01-13 15:42:17 -08:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2022-01-13 15:42:17 -08:00
|
|
|
|
|
|
|
f = f.lower(x, x + 1).compile()
|
|
|
|
self.assertIsNotNone(f.runtime_executable())
|
|
|
|
|
2022-01-19 18:44:31 -08:00
|
|
|
@jtu.with_mesh([('x', 2)])
|
|
|
|
def test_static_argnums(self):
|
2023-02-28 14:28:32 -08:00
|
|
|
@partial(pjit, in_shardings=None, out_shardings=None,
|
2022-01-19 18:44:31 -08:00
|
|
|
static_argnums=(1,))
|
|
|
|
def f(x, y):
|
|
|
|
return x + (3 if y == 'hi' else 4)
|
|
|
|
|
|
|
|
self.assertEqual(f(1, 'hi' ), 4)
|
|
|
|
self.assertEqual(f(1, 'bye'), 5)
|
2021-11-11 19:15:41 -08:00
|
|
|
|
2022-03-24 15:00:10 -07:00
|
|
|
@jtu.with_mesh([('x', 4), ('y', 2)])
|
|
|
|
def testLowerCompileWithAvals(self):
|
|
|
|
@partial(pjit,
|
2023-02-28 14:28:32 -08:00
|
|
|
in_shardings=P(('x', 'y'),),
|
|
|
|
out_shardings=P(('x', 'y'),))
|
2022-03-24 15:00:10 -07:00
|
|
|
def f(x, y):
|
|
|
|
return x @ y
|
|
|
|
|
|
|
|
shape = (8, 8)
|
2023-02-14 23:00:40 -08:00
|
|
|
aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
|
2023-04-13 11:48:11 -07:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2023-01-05 14:38:58 -08:00
|
|
|
exe = f.lower(aval, x).compile()
|
2022-03-24 15:00:10 -07:00
|
|
|
self.assertIsInstance(exe, stages.Compiled)
|
|
|
|
self.assertArraysEqual(exe(x, x), x @ x)
|
|
|
|
|
2022-08-25 17:13:33 -07:00
|
|
|
def test_local_sharded_key_array_sda(self):
|
|
|
|
input_shape = (8, 4)
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-08-25 17:13:33 -07:00
|
|
|
seeds = jnp.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(input_shape), dtype=np.uint32).reshape(input_shape)
|
2022-08-25 17:13:33 -07:00
|
|
|
|
|
|
|
with mesh:
|
|
|
|
def make_keys(seeds):
|
2023-10-17 13:18:08 -07:00
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
2022-08-25 17:13:33 -07:00
|
|
|
return make_key(seeds)
|
|
|
|
|
2023-02-18 09:59:58 -08:00
|
|
|
f = pjit(make_keys, in_shardings=P(None), out_shardings=P(None))
|
2022-08-25 17:13:33 -07:00
|
|
|
|
|
|
|
out = f(seeds)
|
2023-09-13 11:37:43 -07:00
|
|
|
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
2022-08-25 17:13:33 -07:00
|
|
|
self.assertEqual(out.shape, input_shape)
|
2023-09-13 16:33:21 -07:00
|
|
|
jax.random.key_data(out) # doesn't crash
|
2022-08-25 17:13:33 -07:00
|
|
|
|
2022-09-09 09:13:10 -07:00
|
|
|
def test_with_sharding_constraint_is_compatible_error(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
|
2022-09-09 09:13:10 -07:00
|
|
|
|
|
|
|
with mesh:
|
|
|
|
def f(x):
|
|
|
|
y = with_sharding_constraint(x, P(None, ('mdl',), None, None))
|
|
|
|
z = y + 2
|
|
|
|
return z
|
2023-02-18 09:59:58 -08:00
|
|
|
pjit_f = pjit(f, in_shardings=P(None), out_shardings=P(None))
|
2022-09-09 09:13:10 -07:00
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
r"One of with_sharding_constraint.*Sharding "
|
2023-08-01 10:16:42 -07:00
|
|
|
r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
|
2023-09-11 11:54:29 -07:00
|
|
|
r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
|
2022-09-09 09:13:10 -07:00
|
|
|
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
|
|
|
pjit_f(jnp.array([1, 2, 3]))
|
|
|
|
|
2023-12-07 15:56:56 +00:00
|
|
|
def test_pretty_print(self):
|
2023-12-08 20:10:08 +00:00
|
|
|
f = pjit(lambda x: x**2)
|
|
|
|
g = pjit(lambda x: f(x) + f(x))
|
2023-12-07 15:56:56 +00:00
|
|
|
x = jnp.array([4.2], dtype=jnp.float32)
|
2023-12-08 20:10:08 +00:00
|
|
|
jaxpr = jax.make_jaxpr(g)(x)
|
|
|
|
self.assertEqual(
|
|
|
|
jaxpr.pretty_print(),
|
|
|
|
textwrap.dedent("""
|
|
|
|
let lambda = { lambda ; a:f32[1]. let b:f32[1] = integer_pow[y=2] a in (b,) } in
|
|
|
|
{ lambda ; c:f32[1]. let
|
|
|
|
d:f32[1] = pjit[
|
|
|
|
name=<lambda>
|
|
|
|
jaxpr={ lambda ; e:f32[1]. let
|
|
|
|
f:f32[1] = pjit[name=<lambda> jaxpr=lambda] e
|
|
|
|
g:f32[1] = pjit[name=<lambda> jaxpr=lambda] e
|
|
|
|
h:f32[1] = add f g
|
|
|
|
in (h,) }
|
|
|
|
] c
|
|
|
|
in (d,) }
|
|
|
|
""").strip(),
|
|
|
|
)
|
|
|
|
|
|
|
|
def test_pretty_print_with_closure(self):
|
|
|
|
@pjit
|
|
|
|
def g(x, y):
|
|
|
|
@pjit
|
|
|
|
def f(x):
|
|
|
|
return x * y
|
|
|
|
return f(x) + f(y)
|
|
|
|
|
|
|
|
x = jnp.array([4.2], dtype=jnp.float32)
|
|
|
|
jaxpr = jax.make_jaxpr(g)(x, x)
|
2023-12-07 15:56:56 +00:00
|
|
|
self.assertEqual(
|
|
|
|
jaxpr.pretty_print(),
|
|
|
|
textwrap.dedent("""
|
2023-12-08 20:10:08 +00:00
|
|
|
let f = { lambda ; a:f32[1] b:f32[1]. let c:f32[1] = mul b a in (c,) } in
|
|
|
|
{ lambda ; d:f32[1] e:f32[1]. let
|
|
|
|
g:f32[1] = pjit[
|
|
|
|
name=g
|
|
|
|
jaxpr={ lambda ; h:f32[1] i:f32[1]. let
|
|
|
|
j:f32[1] = pjit[name=f jaxpr=f] i h
|
|
|
|
k:f32[1] = pjit[name=f jaxpr=f] i i
|
|
|
|
l:f32[1] = add j k
|
|
|
|
in (l,) }
|
|
|
|
] d e
|
|
|
|
in (g,) }
|
2023-12-07 15:56:56 +00:00
|
|
|
""").strip(),
|
|
|
|
)
|
|
|
|
|
2023-12-08 20:10:08 +00:00
|
|
|
def test_pretty_print_with_name_clash(self):
|
|
|
|
@pjit
|
|
|
|
def g(x, y):
|
|
|
|
@pjit
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
return f(x)*f(x) + f(y)*f(y)
|
|
|
|
|
|
|
|
x = jnp.array([4.2], dtype=jnp.float32)
|
|
|
|
y = jnp.array([4.2, 2.4], dtype=jnp.float32)
|
|
|
|
jaxpr = jax.make_jaxpr(g)(x, y)
|
|
|
|
self.assertEqual(
|
2024-05-24 01:14:16 +00:00
|
|
|
jaxpr.pretty_print(use_color=False),
|
2023-12-08 20:10:08 +00:00
|
|
|
textwrap.dedent("""
|
2024-05-24 01:14:16 +00:00
|
|
|
let f = { lambda ; a:f32[1]. let in () } in
|
|
|
|
let f1 = { lambda ; b:f32[2]. let in () } in
|
2023-12-08 20:10:08 +00:00
|
|
|
{ lambda ; c:f32[1] d:f32[2]. let
|
|
|
|
e:f32[2] = pjit[
|
|
|
|
name=g
|
|
|
|
jaxpr={ lambda ; g:f32[1] h:f32[2]. let
|
2024-05-24 01:14:16 +00:00
|
|
|
pjit[name=f jaxpr=f] g
|
|
|
|
pjit[name=f jaxpr=f] g
|
|
|
|
i:f32[1] = mul g g
|
|
|
|
pjit[name=f jaxpr=f1] h
|
|
|
|
pjit[name=f jaxpr=f1] h
|
|
|
|
j:f32[2] = mul h h
|
|
|
|
k:f32[2] = add i j
|
|
|
|
in (k,) }
|
2023-12-08 20:10:08 +00:00
|
|
|
] c d
|
|
|
|
in (e,) }
|
|
|
|
""").strip(),
|
|
|
|
)
|
|
|
|
|
2024-05-04 03:27:31 +00:00
|
|
|
def test_with_sharding_constraint_vmap_spmd_axis_name_error(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
2024-05-04 03:27:31 +00:00
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x')))
|
|
|
|
|
|
|
|
xs = jnp.arange(4 * 16.).reshape(4, 16)
|
|
|
|
with self.assertRaisesRegex(ValueError, "spmd_axis_name"):
|
|
|
|
jax.vmap(f, spmd_axis_name='x')(xs)


@jtu.pytest_mark_if_available('multiaccelerator')
class CustomPartitionerTest(jtu.JaxTestCase):

  def skip_if_custom_partitioning_not_supported(self):
    if jtu.is_cloud_tpu():
      raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
    if config.use_shardy_partitioner.value:
      self.skipTest(
          'Custom partitioning is not supported with Shardy yet.')
|
2022-10-25 19:29:32 +00:00
|
|
|
|
2023-04-19 18:26:21 -07:00
|
|
|
@jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU.
|
|
|
|
@jtu.with_mesh([('x', 4), ('y', 2)])
|
|
|
|
def test_custom_partitioner(self):
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
|
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def partition(precision, mesh, arg_shapes, result_shape):
|
2024-02-22 11:35:39 -08:00
|
|
|
arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
|
2023-05-19 16:58:21 -07:00
|
|
|
result_sharding = result_shape[0].sharding
|
2022-10-13 18:35:24 -07:00
|
|
|
self.assertEqual(arg_shardings[0], result_sharding)
|
2023-08-10 17:01:57 -07:00
|
|
|
self.assertEqual(P('x', None), result_sharding.spec)
|
|
|
|
self.assertEqual(P('y', None), arg_shardings[1].spec)
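
      # lower_fn is the per-device body of the sharded matmul: each device
      # slices the columns of its local x block that line up with its y block
      # (selected by its index along the 'y' axis), multiplies, and psums the
      # partial products over 'y' to form its output block.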
def lower_fn(x, y):
|
|
|
|
axis_name = arg_shardings[1].spec[0][0]
|
|
|
|
i = jax.lax.axis_index(axis_name)
|
2023-05-19 16:58:21 -07:00
|
|
|
z = jax.lax.psum(
|
|
|
|
jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name)
|
|
|
|
)
|
|
|
|
return z, z * z
|
2022-10-13 18:35:24 -07:00
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
return mesh, lower_fn, (result_sharding, result_sharding), arg_shardings
|
2022-10-13 18:35:24 -07:00
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def infer_sharding_from_operands(precision, mesh, arg_shapes, result_shape):
|
2024-02-22 11:35:39 -08:00
|
|
|
arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
|
2022-10-13 18:35:24 -07:00
|
|
|
x_shard, y_shard = arg_shardings
|
|
|
|
x_shape, y_shape = arg_shapes
|
|
|
|
x_names = tuple(x_shard.spec) + tuple(
|
|
|
|
None for _ in range(len(x_shape.shape) - len(x_shard.spec)))
|
|
|
|
y_names = tuple(y_shard.spec) + tuple(
|
|
|
|
None for _ in range(len(y_shape.shape) - len(y_shard.spec)))
|
2023-05-19 16:58:21 -07:00
|
|
|
z_shard = NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
|
|
|
|
return z_shard, z_shard
|
2022-10-13 18:35:24 -07:00
|
|
|
|
2023-02-03 11:30:31 -08:00
|
|
|
@partial(custom_partitioning, static_argnums=(2,))
|
|
|
|
def f(x, y, precision=None):
|
2023-05-19 16:58:21 -07:00
|
|
|
z = jnp.matmul(x, y, precision=precision)
|
|
|
|
return z, z * z
|
2022-10-13 18:35:24 -07:00
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
infer_sharding_from_operands=infer_sharding_from_operands,
|
|
|
|
partition=partition)
|
|
|
|
|
2023-02-18 09:59:58 -08:00
|
|
|
pjit_f = pjit(f, in_shardings=(P('x'), P('y')), out_shardings=P('x'))
|
2022-10-13 18:35:24 -07:00
|
|
|
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
|
|
|
|
y = np.asarray(np.random.randint(0, 20, (16, 32)), dtype=np.float32)
|
|
|
|
result1 = jax.jit(f)(x, y)
|
|
|
|
result2 = f(x, y)
|
|
|
|
result0 = pjit_f(x, y)
|
|
|
|
self.assertArraysEqual(result0, result1)
|
|
|
|
self.assertArraysEqual(result1, result2)
|
2022-01-31 08:44:11 -08:00
|
|
|
|
2023-05-18 14:47:34 -07:00
|
|
|
@jtu.with_mesh([('x', 4), ('y', 2)])
|
|
|
|
def test_custom_partitioner_propagate_user_sharding(self):
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
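
    # propagate_user_sharding lets the custom-partitioned op adopt the
    # sharding its consumer asks for, rather than only inferring one from its
    # operands.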
def partition(mesh, arg_shapes, result_shape):
|
2023-05-18 14:47:34 -07:00
|
|
|
def lower_fn(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
return (
|
2023-08-09 17:08:27 -07:00
|
|
|
mesh,
|
2023-05-18 14:47:34 -07:00
|
|
|
lower_fn,
|
|
|
|
arg_shapes[0].sharding,
|
2023-05-19 16:58:21 -07:00
|
|
|
(arg_shapes[0].sharding,),
|
2023-05-18 14:47:34 -07:00
|
|
|
)
|
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
2023-05-18 14:47:34 -07:00
|
|
|
return arg_shapes[0].sharding
|
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def propagate_user_sharding(mesh, user_shape):
|
2023-05-18 14:47:34 -07:00
|
|
|
return user_shape.sharding
|
|
|
|
|
|
|
|
@custom_partitioning
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
infer_sharding_from_operands=infer_sharding_from_operands,
|
|
|
|
partition=partition,
|
|
|
|
propagate_user_sharding=propagate_user_sharding,
|
|
|
|
)
|
|
|
|
|
|
|
|
def f2(a):
|
|
|
|
return a + f(a)
|
|
|
|
|
|
|
|
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
|
|
|
|
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
|
|
|
|
self.assertArraysEqual(x + x, pjit_f(x))
|
|
|
|
|
2023-04-19 18:26:21 -07:00
|
|
|
@jtu.with_mesh([('x', 4), ('y', 2)])
|
|
|
|
def test_custom_partitioner_sharding_override(self):
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
|
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def partition(mesh, arg_shapes, result_shape):
|
2023-04-19 18:26:21 -07:00
|
|
|
def lower_fn(x):
|
|
|
|
return x
|
|
|
|
|
2023-05-18 14:47:34 -07:00
|
|
|
y_shard = arg_shapes[0].sharding
|
2023-04-19 18:26:21 -07:00
|
|
|
return (
|
2023-08-09 17:08:27 -07:00
|
|
|
mesh,
|
2023-04-19 18:26:21 -07:00
|
|
|
lower_fn,
|
|
|
|
NamedSharding(y_shard.mesh, P(None)),
|
2023-05-19 16:58:21 -07:00
|
|
|
(NamedSharding(y_shard.mesh, P(None)),),
|
2023-04-19 18:26:21 -07:00
|
|
|
)
|
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
2023-05-18 14:47:34 -07:00
|
|
|
y_shard = arg_shapes[0].sharding
|
2023-04-19 18:26:21 -07:00
|
|
|
return NamedSharding(y_shard.mesh, P('x'))
|
|
|
|
|
|
|
|
@custom_partitioning
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
infer_sharding_from_operands=infer_sharding_from_operands,
|
|
|
|
partition=partition,
|
|
|
|
)
|
|
|
|
|
|
|
|
pjit_f = pjit(f, in_shardings=(P(None, 'x')), out_shardings=P('x'))
|
|
|
|
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
|
|
|
|
self.assertArraysEqual(x, pjit_f(x))
|
|
|
|
|
|
|
|
@jtu.with_mesh([('x', 4), ('y', 2)])
|
|
|
|
def test_custom_partitioner_invalid_sharding(self):
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
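    # The partition rule below declares the argument sharded as P(None, 'x')
    # but a replicated result, while lower_fn is the identity, so the
    # per-shard result shape disagrees with the declared result sharding and
    # partitioning should fail with a result-shape mismatch.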
def partition(mesh, arg_shapes, result_shape):
|
2023-04-19 18:26:21 -07:00
|
|
|
def lower_fn(x):
|
|
|
|
return x
|
|
|
|
|
2023-05-18 14:47:34 -07:00
|
|
|
y_shard = arg_shapes[0].sharding
|
2023-04-19 18:26:21 -07:00
|
|
|
return (
|
2023-08-09 17:08:27 -07:00
|
|
|
mesh,
|
2023-04-19 18:26:21 -07:00
|
|
|
lower_fn,
|
|
|
|
NamedSharding(y_shard.mesh, P(None)),
|
2023-05-19 16:58:21 -07:00
|
|
|
(NamedSharding(y_shard.mesh, P(None, 'x')),),
|
2023-04-19 18:26:21 -07:00
|
|
|
)
|
|
|
|
|
2023-08-09 17:08:27 -07:00
|
|
|
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
2023-05-18 14:47:34 -07:00
|
|
|
y_shard = arg_shapes[0].sharding
|
2023-04-19 18:26:21 -07:00
|
|
|
return NamedSharding(y_shard.mesh, P('x'))
|
|
|
|
|
|
|
|
@custom_partitioning
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
infer_sharding_from_operands=infer_sharding_from_operands,
|
|
|
|
partition=partition,
|
|
|
|
)
|
|
|
|
|
|
|
|
pjit_f = pjit(f, in_shardings=(P(None, 'x')), out_shardings=P('x'))
|
|
|
|
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(Exception, 'Mismatch in result shapes.'):
|
|
|
|
pjit_f(x).block_until_ready()
|
|
|
|
|
2023-08-18 11:12:18 -07:00
|
|
|
@jtu.with_mesh([('x', 4)])
|
|
|
|
def test_custom_partitioner_jit_annotated_function(self):
|
|
|
|
"""Test correct lowering of function with a @jax.jit annotated callee.
|
|
|
|
|
|
|
|
Annotating a callee with @jax.jit results in a module with a HLO CallOp.
|
|
|
|
This test is makes sure that the custom partitioner lowering supports
|
|
|
|
CallOps.
|
|
|
|
"""
|
|
|
|
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
|
|
|
|
|
|
|
|
@custom_partitioning
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
def partition(mesh, arg_shapes, result_shape):
|
|
|
|
def lower_fn(x):
|
|
|
|
@jax.jit
|
|
|
|
def g(y):
|
|
|
|
return y
|
|
|
|
|
|
|
|
return g(x)
|
|
|
|
|
|
|
|
x_shard = arg_shapes[0].sharding
|
|
|
|
return (
|
|
|
|
mesh,
|
|
|
|
lower_fn,
|
|
|
|
NamedSharding(x_shard.mesh, P('x')),
|
|
|
|
(NamedSharding(x_shard.mesh, P('x')),),
|
|
|
|
)
|
|
|
|
|
|
|
|
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
|
|
|
x_shard = arg_shapes[0].sharding
|
|
|
|
return NamedSharding(x_shard.mesh, P('x'))
|
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
infer_sharding_from_operands=infer_sharding_from_operands,
|
|
|
|
partition=partition,
|
|
|
|
)
|
|
|
|
|
|
|
|
jit_f = jax.jit(f)
|
|
|
|
x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
|
|
|
|
pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x'))
|
|
|
|
self.assertArraysEqual(x, pjit_f(x))
|
|
|
|
|
2024-04-23 14:24:28 +01:00
|
|
|
@jtu.with_mesh([('x', 4)])
|
|
|
|
def test_custom_partitioner_with_scan(self):
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
|
|
|
|
|
|
|
|
# This is a reproducer from https://github.com/google/jax/issues/20864.
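    # The partition rule below lowers the sum to a scan whose body psums
    # across 'x', so this checks that scan plus collectives inside a
    # custom-partitioned lower_fn lower and run correctly.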
|
|
|
|
|
|
|
|
@custom_partitioning
|
|
|
|
def f(x):
|
|
|
|
return jnp.sum(x)
|
|
|
|
|
|
|
|
def partition(mesh, arg_shapes, result_shape):
|
|
|
|
def lower_fn(xs):
|
|
|
|
def f(carry, x):
|
|
|
|
return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None
|
|
|
|
|
|
|
|
carry, _ = jax.lax.scan(f, 0, xs)
|
|
|
|
return carry
|
|
|
|
|
|
|
|
result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
|
|
|
|
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
|
|
|
|
return mesh, lower_fn, result_shardings, arg_shardings
|
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
partition,
|
|
|
|
infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()),
|
|
|
|
propagate_user_sharding=lambda _, user_shape: user_shape.sharding)
|
|
|
|
|
|
|
|
pjit_f = pjit(f, in_shardings=P(None, 'x'))
|
|
|
|
xs = jnp.ones([32, 16])
|
|
|
|
self.assertEqual(pjit_f(xs), xs.sum())
|
|
|
|
|
2024-08-12 10:39:58 -07:00
|
|
|
def test_custom_partitioning_no_mesh_context(self):
|
|
|
|
self.skip_if_custom_partitioning_not_supported()
|
|
|
|
|
|
|
|
@custom_partitioning
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
def partition(mesh, arg_shapes, result_shape):
|
|
|
|
def lower_fn(x):
|
|
|
|
@jax.jit
|
|
|
|
def g(y):
|
|
|
|
return y
|
|
|
|
|
|
|
|
return g(x)
|
|
|
|
|
|
|
|
x_shard = arg_shapes[0].sharding
|
|
|
|
return (
|
|
|
|
mesh,
|
|
|
|
lower_fn,
|
|
|
|
NamedSharding(x_shard.mesh, P('x')),
|
|
|
|
(NamedSharding(x_shard.mesh, P('x')),),
|
|
|
|
)
|
|
|
|
|
|
|
|
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
|
|
|
x_shard = arg_shapes[0].sharding
|
|
|
|
return NamedSharding(x_shard.mesh, P('x'))
|
|
|
|
|
|
|
|
f.def_partition(
|
|
|
|
infer_sharding_from_operands=infer_sharding_from_operands,
|
|
|
|
partition=partition,
|
|
|
|
)
|
|
|
|
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4,), ('x',))
|
2024-08-12 10:39:58 -07:00
|
|
|
x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
|
|
|
|
s = NamedSharding(mesh, P('x'))
|
|
|
|
|
|
|
|
jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
|
|
|
|
self.assertArraysEqual(x, jit_f(x))
|
|
|
|
|
2022-08-31 22:53:32 -07:00
|
|
|
|
2023-01-12 22:42:06 +00:00
|
|
|
@jtu.pytest_mark_if_available('multiaccelerator')
|
2022-03-31 18:21:02 -07:00
|
|
|
class AutoShardingPjitTest(jtu.JaxTestCase):
|
|
|
|
|
2022-09-16 11:15:56 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
('2d_array', (4, 2), (4, 2), ('x', 'y')),
|
2022-07-08 09:44:48 -07:00
|
|
|
# TODO(b/226977360): Support 3D mesh shape for example (2, 2, 2).
|
2022-09-16 11:15:56 -07:00
|
|
|
('3d_array', (1, 4, 2), (2, 4, 8, 4), ('x', 'y', 'z')),
|
|
|
|
('1d_array', (8,), (8, 2), ('x')),
|
2022-03-31 18:21:02 -07:00
|
|
|
)
|
2022-09-16 11:15:56 -07:00
|
|
|
def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
|
|
|
|
mesh_axis_names):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest('Must register auto partitioner for Shardy')
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh(mesh_shape, mesh_axis_names)
|
2022-04-13 10:17:41 -07:00
|
|
|
input_data = np.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
2022-03-31 18:21:02 -07:00
|
|
|
|
2023-05-20 22:59:52 -07:00
|
|
|
f = jax.jit(lambda x: x, in_shardings=AUTO(global_mesh),
|
|
|
|
out_shardings=AUTO(global_mesh))
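
    # With AUTO, the compiler chooses the shardings: lower with abstract
    # inputs, compile, then read the chosen input shardings back off the
    # compiled object and build matching concrete arrays before calling it.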
|
2022-03-31 18:21:02 -07:00
|
|
|
|
2023-05-20 22:59:52 -07:00
|
|
|
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
|
|
|
compiled = f.lower(inp).compile()
|
|
|
|
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
|
|
|
for ip in compiled.input_shardings[0]]
|
|
|
|
out = compiled(*inputs)
|
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
|
|
|
self.assertArraysEqual(out._value, input_data)
|
2022-03-31 18:21:02 -07:00
|
|
|
|
2023-03-15 11:28:25 -07:00
|
|
|
def test_xla_arr_sharding_mismatch(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest('Must register auto partitioner for Shardy')
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
2024-06-10 09:37:54 -07:00
|
|
|
global_input_shape = (6, 2)
|
2022-04-13 10:17:41 -07:00
|
|
|
input_data = np.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
2022-03-31 18:21:02 -07:00
|
|
|
|
2023-03-15 11:28:25 -07:00
|
|
|
with global_mesh:
|
2023-05-20 22:59:52 -07:00
|
|
|
f = pjit(lambda x: x, in_shardings=AUTO(global_mesh),
|
|
|
|
out_shardings=AUTO(global_mesh))
|
2023-03-15 11:28:25 -07:00
|
|
|
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
|
|
|
compiled = f.lower(inp).compile()
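
      # Build an array laid out with the opposite spec from whatever
      # auto-sharding chose, and check that calling the compiled object with
      # it is rejected.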
different_pspec = (
|
|
|
|
P('y', 'x')
|
|
|
|
if compiled.input_shardings[0][0].is_equivalent_to(
|
|
|
|
NamedSharding(global_mesh, P('x', 'y')), len(global_input_shape)
|
|
|
|
)
|
|
|
|
else P('x', 'y')
|
|
|
|
)
|
2023-03-15 11:28:25 -07:00
|
|
|
arr, _ = create_array(global_input_shape, global_mesh, different_pspec,
|
2022-07-08 09:44:48 -07:00
|
|
|
input_data)
|
2023-03-15 11:28:25 -07:00
|
|
|
with self.assertRaisesRegex(
|
2023-03-15 17:08:21 -07:00
|
|
|
ValueError,
|
2023-10-25 15:47:17 -07:00
|
|
|
r"Compiled object called with input sharding\(s\) does not match the "
|
|
|
|
r"sharding\(s\) the computation was compiled with.*\n.*for arg x"):
|
2023-03-15 11:28:25 -07:00
|
|
|
compiled(arr)
|
2022-04-04 14:33:17 -07:00
|
|
|
|
2022-05-23 15:01:58 -07:00
|
|
|
def test_gda_auto_shardings_len(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest('Must register auto partitioner for Shardy')
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
2022-05-23 15:01:58 -07:00
|
|
|
global_input_shape = (4, 2)
|
|
|
|
input_data = np.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
2022-05-23 15:01:58 -07:00
|
|
|
|
2022-07-08 09:44:48 -07:00
|
|
|
with global_mesh:
|
2023-05-20 22:59:52 -07:00
|
|
|
f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO(global_mesh),
|
|
|
|
out_shardings=AUTO(global_mesh))
|
2023-02-14 23:00:40 -08:00
|
|
|
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
2023-01-05 14:38:58 -08:00
|
|
|
compiled = f.lower(inp, inp, inp).compile()
|
2022-07-08 09:44:48 -07:00
|
|
|
self.assertLen(compiled.output_shardings, 3)
|
2022-10-07 14:28:51 -07:00
|
|
|
self.assertLen(compiled.input_shardings[0], 3)
|
2022-05-23 15:01:58 -07:00
|
|
|
|
2022-09-16 11:15:56 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
('3d_array', (1, 1, 2), ('x', 'y', 'z'), P(('x', 'y', 'z'))),
|
|
|
|
('2d_array', (4, 2), ('x', 'y'), P('y', 'x')),
|
|
|
|
('1d_array', (8,), ('x'), P('x')),
|
|
|
|
)
|
2023-05-20 22:59:52 -07:00
|
|
|
def test_jit_arr_partial_auto_sharding_array(
|
2022-09-16 11:15:56 -07:00
|
|
|
self, mesh_shape, mesh_axis_names, pspec):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest('Must register auto partitioner for Shardy')
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh(mesh_shape, mesh_axis_names)
|
2022-09-16 11:15:56 -07:00
|
|
|
global_input_shape = (8, 4)
|
|
|
|
input_data = np.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
2023-05-20 22:59:52 -07:00
|
|
|
inp_s = NamedSharding(mesh, pspec)
|
|
|
|
f = jax.jit(
|
|
|
|
lambda x, y: (x, y),
|
|
|
|
in_shardings=(inp_s, AUTO(mesh)),
|
|
|
|
out_shardings=AUTO(mesh))
|
|
|
|
|
|
|
|
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
|
|
|
compiled = f.lower(inp, inp).compile()
|
|
|
|
inputs = [create_array(global_input_shape, mesh, ip, input_data)[0]
|
|
|
|
for ip in compiled.input_shardings[0]]
|
|
|
|
self.assertEqual(compiled.input_shardings[0][0], inp_s)
|
|
|
|
out1, out2 = compiled(*inputs)
|
|
|
|
for o in [out1, out2]:
|
|
|
|
self.assertIsInstance(o, array.ArrayImpl)
|
|
|
|
self.assertArraysEqual(o._value, input_data)
|
|
|
|
|
|
|
|
def test_jit_different_mesh_in_auto(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh1 = jtu.create_mesh((4,), ('x',))
|
2023-05-20 22:59:52 -07:00
|
|
|
dev = jax.devices()
|
|
|
|
mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x')
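    # mesh2 contains the same devices as mesh1 but in a different order, so
    # mixing a sharding on mesh2 with AUTO(mesh1) must be rejected as an
    # incompatible device assignment.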
|
|
|
|
f = jax.jit(lambda x, y: (x, y),
|
|
|
|
in_shardings=(NamedSharding(mesh2, P('x')), AUTO(mesh1)))
|
|
|
|
inp = core.ShapedArray((8, 2), np.float32)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
"Received incompatible devices for jitted computation"):
|
|
|
|
f.lower(inp, inp).compile()
|
2022-07-08 09:44:48 -07:00
|
|
|
|
2024-03-25 17:45:59 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
('2d_array', (4, 2), ('x', 'y')),
|
|
|
|
('1d_array', (8,), ('x')),
|
|
|
|
)
|
|
|
|
def test_jit_auto_sharding_partial_tuple_input_shardings(
|
|
|
|
self, mesh_shape, mesh_axis_names):
|
|
|
|
if not jtu.test_device_matches(["tpu"]):
|
|
|
|
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest('Must register auto partitioner for Shardy')
|
2024-03-25 17:45:59 -07:00
|
|
|
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh(mesh_shape, mesh_axis_names)
|
2024-03-25 17:45:59 -07:00
|
|
|
global_input_shape = (8, 4)
|
|
|
|
input_data = np.arange(
|
|
|
|
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
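
    # 2001 parameters forces XLA's parameter tupling on TPU; every other input
    # and output gets an explicit sharding while the rest stay AUTO, so the
    # test can check that the explicit shardings survive auto-sharding.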
|
2024-03-27 12:03:18 -07:00
|
|
|
input_sharding = NamedSharding(mesh, P(mesh_axis_names)) # sharded
|
2024-03-25 17:45:59 -07:00
|
|
|
input_sharding_annotations = [AUTO(mesh)] * 2001
|
2024-03-27 12:03:18 -07:00
|
|
|
output_sharding = NamedSharding(mesh, P()) # replicated
|
|
|
|
output_sharding_annotations = [AUTO(mesh)] * 2001
|
2024-03-25 17:45:59 -07:00
|
|
|
for i in range(1000):
|
|
|
|
input_sharding_annotations[2*i] = input_sharding
|
2024-03-27 12:03:18 -07:00
|
|
|
output_sharding_annotations[2*i] = output_sharding
|
2024-03-25 17:45:59 -07:00
|
|
|
|
|
|
|
jit_tuple_identity_fn = jax.jit(
|
|
|
|
lambda *x: x,
|
|
|
|
in_shardings=input_sharding_annotations,
|
2024-03-27 12:03:18 -07:00
|
|
|
out_shardings=tuple(output_sharding_annotations))
|
2024-03-25 17:45:59 -07:00
|
|
|
|
|
|
|
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
|
|
|
compiled = jit_tuple_identity_fn.lower(*([inp] * 2001)).compile()
|
|
|
|
|
|
|
|
|
|
|
|
# Check sharding preservation for even numbered inputs.
|
|
|
|
for i in range(1000):
|
|
|
|
self.assertEqual(compiled.input_shardings[0][2*i], input_sharding)
|
2024-03-27 12:03:18 -07:00
|
|
|
self.assertEqual(compiled.output_shardings[2*i], output_sharding)
|
2024-03-25 17:45:59 -07:00
|
|
|
|
2022-08-19 07:44:16 -07:00
|
|
|
@unittest.skip('The error is not raised yet. Enable this back once we raise '
|
|
|
|
'the error in pjit again.')
|
|
|
|
def test_pjit_array_error(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-08-19 07:44:16 -07:00
|
|
|
global_input_shape = (8, 2)
|
|
|
|
input_data = np.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
2022-08-19 07:44:16 -07:00
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
with global_mesh:
|
2023-05-20 22:59:52 -07:00
|
|
|
f = pjit(lambda x: x, in_shardings=AUTO(global_mesh),
|
|
|
|
out_shardings=AUTO(global_mesh))
|
2022-08-19 07:44:16 -07:00
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
|
|
|
compiled = f.lower(inp).compile()
|
|
|
|
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
|
|
|
for ip in compiled.input_shardings[0]]
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
('Passing sharding on pjit and on args while using the '
|
|
|
|
'auto spmd partitioner is not allowed. Please call the '
|
|
|
|
'compiled object on the inputs.')):
|
|
|
|
f(*inputs)
|
2022-08-19 07:44:16 -07:00
|
|
|
|
2022-03-31 18:21:02 -07:00
|
|
|
|
2023-01-12 22:42:06 +00:00
|
|
|
@jtu.pytest_mark_if_available('multiaccelerator')
|
2022-06-10 07:31:43 -07:00
|
|
|
class ArrayPjitTest(jtu.JaxTestCase):
|
|
|
|
|
2022-06-16 19:51:56 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
('fully_sharded_output', P('x', 'y'), (2, 4)),
|
|
|
|
('fully_replicated_output', P(None), (8, 8)),
|
|
|
|
)
|
|
|
|
def test_pjit_array_single_output(self, out_axis_resources, shard_shape):
|
2022-06-10 07:31:43 -07:00
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-06-10 07:31:43 -07:00
|
|
|
mesh_axes = P('x', 'y')
|
|
|
|
|
|
|
|
input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
|
|
|
|
|
2023-02-28 14:28:32 -08:00
|
|
|
f = pjit(lambda x: x @ x.T, out_shardings=NamedSharding(
|
2022-08-05 22:24:46 -07:00
|
|
|
global_mesh, out_axis_resources))
|
|
|
|
expected_matrix_mul = input_data @ input_data.T
|
2022-06-10 07:31:43 -07:00
|
|
|
|
2022-08-05 22:24:46 -07:00
|
|
|
out = f(input_array)
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
2022-10-10 22:08:06 -07:00
|
|
|
self.assertTrue(out._committed)
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertEqual(out.shape, (8, 8))
|
|
|
|
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
|
|
|
|
for s in out.addressable_shards:
|
2023-01-30 20:01:58 -08:00
|
|
|
self.assertLen(s.data.devices(), 1)
|
|
|
|
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertArraysEqual(out._value, expected_matrix_mul)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
('fully_sharded_output', P('x', 'y'), (2, 4)),
|
|
|
|
('fully_replicated_output', P(None), (8, 8)),
|
|
|
|
)
|
|
|
|
def test_pjit_array_single_output_with_mesh_context_manager(
|
|
|
|
self, out_axis_resources, shard_shape):
|
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-08-05 22:24:46 -07:00
|
|
|
mesh_axes = P('x', 'y')
|
|
|
|
|
|
|
|
input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
|
|
|
|
|
|
|
|
with global_mesh:
|
2023-02-28 14:28:32 -08:00
|
|
|
f = pjit(lambda x: x @ x.T, out_shardings=NamedSharding(
|
2022-08-05 22:24:46 -07:00
|
|
|
global_mesh, out_axis_resources))
|
|
|
|
expected_matrix_mul = input_data @ input_data.T
|
|
|
|
|
|
|
|
out = f(input_array)
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertEqual(out.shape, (8, 8))
|
|
|
|
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
|
|
|
|
for s in out.addressable_shards:
|
2023-01-30 20:01:58 -08:00
|
|
|
self.assertLen(s.data.devices(), 1)
|
|
|
|
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertArraysEqual(out._value, expected_matrix_mul)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
2022-08-17 11:27:33 -07:00
|
|
|
def test_numpy_array_input_assume_fully_replicated(self):
|
2022-06-10 07:31:43 -07:00
|
|
|
input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-06-10 07:31:43 -07:00
|
|
|
input_data = np.arange(
|
2024-02-27 09:06:21 -08:00
|
|
|
math.prod(input_shape)).reshape(input_shape)
|
|
|
|
|
|
|
|
f = pjit(lambda x: x,
|
|
|
|
out_shardings=NamedSharding(global_mesh, P('x', 'y')))
|
|
|
|
out = f(input_data)
|
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
|
|
|
self.assertArraysEqual(out, input_data)
|
|
|
|
for s in out.addressable_shards:
|
|
|
|
self.assertEqual(s.data.shape, (2, 1))
|
|
|
|
self.assertArraysEqual(s.data, input_data[s.index])
|
2022-06-10 07:31:43 -07:00
|
|
|
|
2022-07-11 16:26:39 -07:00
|
|
|
def test_numpy_array_input(self):
|
|
|
|
input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-07-11 16:26:39 -07:00
|
|
|
input_data = np.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
|
2023-03-15 17:08:21 -07:00
|
|
|
with global_mesh:
|
|
|
|
f = pjit(
|
|
|
|
lambda x: x,
|
|
|
|
in_shardings=NamedSharding(global_mesh, P(None)),
|
|
|
|
out_shardings=NamedSharding(global_mesh, P('x', 'y')),
|
|
|
|
)
|
|
|
|
out = f(input_data)
|
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
|
|
|
for s in out.addressable_shards:
|
|
|
|
self.assertEqual(s.data.shape, (2, 1))
|
|
|
|
self.assertArraysEqual(s.data, input_data[s.index])
|
|
|
|
self.assertArraysEqual(out._value, input_data)
|
|
|
|
|
2022-06-10 19:11:59 -07:00
|
|
|
def test_unspecified_out_axis_resources(self):
|
2022-08-05 13:54:33 -07:00
|
|
|
|
|
|
|
def _checks(out, input_data):
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
2023-04-09 15:41:32 -07:00
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
2022-08-05 13:54:33 -07:00
|
|
|
self.assertEqual(out.shape, (8, 2))
|
|
|
|
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
|
|
|
|
for s in out.addressable_shards:
|
2023-01-30 20:01:58 -08:00
|
|
|
self.assertLen(s.data.devices(), 1)
|
|
|
|
self.assertArraysEqual(s.data, input_data[s.index])
|
2022-08-05 13:54:33 -07:00
|
|
|
self.assertArraysEqual(out._value, input_data)
|
|
|
|
|
2022-06-10 19:11:59 -07:00
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-06-10 19:11:59 -07:00
|
|
|
mesh_axes = P('x', 'y')
|
|
|
|
|
|
|
|
input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
|
|
|
|
|
2022-10-21 16:53:14 -07:00
|
|
|
f = pjit(lambda x: x * 2)
|
2022-06-10 19:11:59 -07:00
|
|
|
|
2022-08-05 22:24:46 -07:00
|
|
|
out = f(input_array)
|
2022-10-21 16:53:14 -07:00
|
|
|
_checks(out, input_data * 2)
|
2022-08-05 13:54:33 -07:00
|
|
|
|
2022-08-05 22:24:46 -07:00
|
|
|
out2 = f(out)
|
2022-10-21 16:53:14 -07:00
|
|
|
_checks(out2, input_data * 4)
|
2022-06-10 19:11:59 -07:00
|
|
|
|
|
|
|
@parameterized.named_parameters(
|
2022-10-21 16:53:14 -07:00
|
|
|
('mesh1', (4, 2), (2, 8), (2, 2), (1, 2), (8, 2)),
|
|
|
|
('mesh2', (2, 2), (4, 8), (4, 2), (2, 2), (8, 2)),
|
|
|
|
('mesh3', (2, 1), (4, 8), (4, 2), (4, 2), (8, 2)),
|
2022-06-10 19:11:59 -07:00
|
|
|
)
|
|
|
|
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
|
|
|
|
s2_shape, s3_shape, s4_shape):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest(
|
|
|
|
'TODO(b/355263220) Shardy conflict resolution is not complete. Issue '
|
|
|
|
'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while '
|
|
|
|
'Shardy gives it fully replicated.')
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh(mesh_shape, ('x', 'y'))
|
2022-06-10 19:11:59 -07:00
|
|
|
global_input_shape = (8, 2)
|
|
|
|
|
|
|
|
spec1 = P('x', 'y')
|
|
|
|
a1, input_data = create_array(global_input_shape, global_mesh, spec1)
|
|
|
|
spec2 = P('x')
|
|
|
|
a2, _ = create_array(global_input_shape, global_mesh, spec2)
|
|
|
|
spec3 = P(('x', 'y'))
|
|
|
|
a3, _ = create_array(global_input_shape, global_mesh, spec3)
|
|
|
|
spec4 = P(None)
|
|
|
|
a4, _ = create_array(global_input_shape, global_mesh, spec4)
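
    # The tuple entries are computed eagerly before the pjit call, so their
    # shardings are whatever sharding propagation picks for those ops on this
    # mesh; the identity pjit preserves them, which is what the expected
    # per-shard shapes s1..s4 encode for each mesh shape.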
|
|
|
|
|
2022-08-05 22:24:46 -07:00
|
|
|
@pjit
|
|
|
|
def f(tree):
|
|
|
|
return tree
|
2022-10-21 16:53:14 -07:00
|
|
|
out_tree = f((a1 @ a1.T, (a2, (a3 * 2, a4))))
|
2024-02-26 14:17:18 -08:00
|
|
|
(out1, out2, out3, out4), _ = jax.tree.flatten(out_tree)
|
2022-08-05 22:24:46 -07:00
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out1, array.ArrayImpl)
|
2022-10-21 16:53:14 -07:00
|
|
|
self.assertEqual(out1.shape, (8, 8))
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
|
|
|
|
for s in out1.addressable_shards:
|
2022-10-21 16:53:14 -07:00
|
|
|
self.assertArraysEqual(
|
2023-01-30 20:01:58 -08:00
|
|
|
s.data, (input_data @ input_data.T)[s.index])
|
2022-08-05 22:24:46 -07:00
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out2, array.ArrayImpl)
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertEqual(out2.shape, (8, 2))
|
|
|
|
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
|
|
|
|
for s in out2.addressable_shards:
|
2023-01-30 20:01:58 -08:00
|
|
|
self.assertArraysEqual(s.data, input_data[s.index])
|
2022-08-05 22:24:46 -07:00
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out3, array.ArrayImpl)
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertEqual(out3.shape, (8, 2))
|
|
|
|
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
|
|
|
|
for s in out3.addressable_shards:
|
2023-01-30 20:01:58 -08:00
|
|
|
self.assertArraysEqual(s.data, (input_data * 2)[s.index])
|
2022-08-05 22:24:46 -07:00
|
|
|
|
2022-09-26 16:17:26 -07:00
|
|
|
self.assertIsInstance(out4, array.ArrayImpl)
|
2022-08-05 22:24:46 -07:00
|
|
|
self.assertEqual(out4.shape, (8, 2))
|
|
|
|
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
|
|
|
|
for s in out4.addressable_shards:
|
2023-01-30 20:01:58 -08:00
|
|
|
self.assertArraysEqual(s.data, input_data)

  def test_sds_full_like(self):
    # https://github.com/google/jax/issues/20390
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s)
    y = jnp.zeros_like(x)
    z = jnp.zeros_like(x, device=y.sharding)

    self.assertEqual(x.sharding, s)
    self.assertEqual(y.sharding, s)
    self.assertEqual(z.sharding, s)
|
|
|
|
|
2022-07-08 09:44:48 -07:00
|
|
|
def test_in_axis_resources_mismatch_error(self):
|
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-07-08 09:44:48 -07:00
|
|
|
mesh_axes = P('x', 'y')
|
|
|
|
|
|
|
|
input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes)
|
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
with global_mesh:
|
|
|
|
f = pjit(lambda x: x,
|
|
|
|
in_shardings=NamedSharding(global_mesh, P('x')))
|
|
|
|
err_msg = re.compile(
|
|
|
|
"Sharding passed to pjit does not match the sharding on the "
|
2023-08-23 13:24:08 -07:00
|
|
|
r"respective arg.*arg shape.*\[8,2\]", re.M | re.S)
|
2023-03-15 17:08:21 -07:00
|
|
|
with self.assertRaisesRegex(ValueError, err_msg):
|
|
|
|
f(input_array)
|
2022-07-08 09:44:48 -07:00
|
|
|
|
|
|
|
def test_in_axis_resources_same_as_array_sharding(self):
|
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-07-08 09:44:48 -07:00
|
|
|
mesh_axes = P('x', 'y')
|
|
|
|
|
|
|
|
input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes)
|
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
with global_mesh:
|
|
|
|
out = pjit(
|
|
|
|
lambda x: x,
|
|
|
|
in_shardings=NamedSharding(global_mesh, P('x' ,'y')))(input_array)
|
|
|
|
self.assertIsInstance(out, array.ArrayImpl)

  def test_no_input_output(self):
    def f():
      pass
    pjit(f)
|
2022-06-10 07:31:43 -07:00
|
|
|
|
2022-07-08 09:44:48 -07:00
|
|
|
def test_array_device_assignment_mismatch_with_mesh(self):
|
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-07-08 09:44:48 -07:00
|
|
|
mesh_axes = P('x', 'y')
|
|
|
|
|
|
|
|
input_array, _ = create_array(
|
2024-09-03 16:22:23 -07:00
|
|
|
global_input_shape, jtu.create_mesh((2, 2), ('x', 'y')),
|
2022-07-08 09:44:48 -07:00
|
|
|
mesh_axes)
|
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
with global_mesh:
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Received incompatible devices for pjitted computation"):
|
|
|
|
pjit(lambda x: x)(input_array)
|
2022-07-08 09:44:48 -07:00
|
|
|
|
|
|
|
def test_array_lower_compile(self):
|
|
|
|
global_input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-07-08 09:44:48 -07:00
|
|
|
|
|
|
|
a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y'))
|
|
|
|
a2, _ = create_array(global_input_shape, global_mesh, P('x'))
|
|
|
|
|
2023-02-14 23:00:40 -08:00
|
|
|
aval = core.ShapedArray(global_input_shape, np.float32)
|
2022-07-08 09:44:48 -07:00
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
with global_mesh:
|
|
|
|
f = pjit(
|
2023-04-19 12:35:15 -07:00
|
|
|
lambda x, y, z, a, b, c: (x @ y.T, y, z, a, b, c),
|
2023-03-15 17:08:21 -07:00
|
|
|
in_shardings=NamedSharding(global_mesh, P('x' ,'y')))
|
2023-04-19 12:35:15 -07:00
|
|
|
compiled = f.lower(aval, aval, aval, aval, aval, aval).compile()
|
|
|
|
out, *_ = compiled(a1, a1, a1, a1, a1, a1)
|
2023-03-15 17:08:21 -07:00
|
|
|
self.assertIsInstance(out, array.ArrayImpl)
|
|
|
|
self.assertArraysEqual(out._value, input_data @ input_data.T)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
2023-04-19 12:35:15 -07:00
|
|
|
ValueError,
|
2024-03-26 09:01:25 -07:00
|
|
|
r"Compiled object called with input sharding.*does not match the "
|
|
|
|
r"sharding.*the computation was compiled with. "
|
|
|
|
"Here are.*mismatches.*"):
|
2023-04-19 12:35:15 -07:00
|
|
|
compiled(a2, a2, a2, a2, a2, a2)
|
|
|
|
|
|
|
|
with global_mesh:
|
|
|
|
f = pjit(lambda a: a, in_shardings=NamedSharding(global_mesh, P('x' ,'y')))
|
|
|
|
abstract_inp = {'x': aval, 'y': {'y1': aval}}
|
|
|
|
inp1 = {'x': a1, 'y': {'y1': a1}}
|
|
|
|
compiled = f.lower(abstract_inp).compile()
|
|
|
|
compiled(inp1)
|
|
|
|
inp2 = {'x': a2, 'y': {'y1': a2}}
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
2024-03-26 09:01:25 -07:00
|
|
|
r"Compiled object called with input sharding.*does not match the "
|
|
|
|
r"sharding.*the computation was compiled with. "
|
|
|
|
"Here are the.*mismatches"):
|
2023-04-19 12:35:15 -07:00
|
|
|
compiled(inp2)
|
2023-03-15 17:08:21 -07:00
|
|
|
|
2022-08-25 12:22:42 -07:00
|
|
|
def test_globally_sharded_key_array_result_8x4_single_device(self):
|
|
|
|
input_shape = (8, 4)
|
|
|
|
seeds = jnp.arange(
|
2023-02-28 12:40:30 -08:00
|
|
|
math.prod(input_shape), dtype=np.uint32).reshape(input_shape)
|
2022-08-25 12:22:42 -07:00
|
|
|
|
|
|
|
@pjit
|
|
|
|
def make_keys(seeds):
|
2023-10-17 13:18:08 -07:00
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
2022-08-25 12:22:42 -07:00
|
|
|
return make_key(seeds)
|
|
|
|
|
|
|
|
out = make_keys(seeds)
|
2023-09-13 11:37:43 -07:00
|
|
|
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
2022-08-25 12:22:42 -07:00
|
|
|
self.assertEqual(out.shape, input_shape)
|
2023-09-13 16:33:21 -07:00
|
|
|
jax.random.key_data(out) # doesn't crash
|
2022-08-25 12:22:42 -07:00
|
|
|
|
2022-08-31 22:53:32 -07:00
|
|
|
def test_globally_sharded_key_array_8x4_multi_device_with_out_sharding(self):
|
|
|
|
input_shape = (8, 4)
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-08-31 22:53:32 -07:00
|
|
|
spec = P('x', 'y')
|
|
|
|
|
|
|
|
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
|
|
|
|
|
2023-02-28 14:28:32 -08:00
|
|
|
@partial(pjit, out_shardings=NamedSharding(mesh, P('x', 'y')))
|
2022-08-31 22:53:32 -07:00
|
|
|
def make_keys(seeds):
|
2023-10-17 13:18:08 -07:00
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
2022-08-31 22:53:32 -07:00
|
|
|
return make_key(seeds)
|
|
|
|
|
|
|
|
out = make_keys(seeds)
|
2023-09-13 11:37:43 -07:00
|
|
|
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
2022-08-31 22:53:32 -07:00
|
|
|
self.assertEqual(out.shape, input_shape)
|
2023-09-13 16:33:21 -07:00
|
|
|
jax.random.key_data(out) # doesn't crash
|
2022-08-31 22:53:32 -07:00
|
|
|
|
|
|
|
def test_globally_sharded_key_array_8x4_multi_device(self):
|
2022-08-25 12:22:42 -07:00
|
|
|
input_shape = (8, 4)
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2022-08-25 12:22:42 -07:00
|
|
|
spec = P('x', 'y')
|
|
|
|
|
|
|
|
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
|
|
|
|
|
|
|
|
@pjit
|
|
|
|
def make_keys(seeds):
|
2023-10-17 13:18:08 -07:00
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
2022-08-25 12:22:42 -07:00
|
|
|
return make_key(seeds)
|
|
|
|
|
|
|
|
out = make_keys(seeds)
|
2023-09-13 11:37:43 -07:00
|
|
|
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
2022-08-31 22:53:32 -07:00
|
|
|
self.assertEqual(out.shape, input_shape)
|
2023-09-13 16:33:21 -07:00
|
|
|
jax.random.key_data(out) # doesn't crash
|
2022-08-25 12:22:42 -07:00
|
|
|
|
2022-07-08 12:13:18 -07:00
|
|
|
def test_array_device_assignment_mismatch_out_shardings(self):
|
|
|
|
input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
m1 = jtu.create_mesh((4, 2), ('x', 'y'))
|
|
|
|
m2 = jtu.create_mesh((2, 2), ('x', 'y'))
|
2022-07-08 12:13:18 -07:00
|
|
|
spec = P('x', 'y')
|
|
|
|
|
2023-02-28 12:40:30 -08:00
|
|
|
a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape)
|
2022-07-08 12:13:18 -07:00
|
|
|
|
2023-03-15 17:08:21 -07:00
|
|
|
with m1:
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Received incompatible devices for pjitted computation"):
|
|
|
|
pjit(lambda x, y: (x, y),
|
|
|
|
out_shardings=(NamedSharding(m1, spec),
|
|
|
|
NamedSharding(m2, spec)))(a1, a1)
|
2022-07-08 12:13:18 -07:00
|
|
|
|
2022-07-11 16:26:39 -07:00
|
|
|
  def test_array_device_assignment_mismatch_in_and_out_shardings(self):
    input_shape = (8, 2)
    m1 = jtu.create_mesh((4, 2), ('x', 'y'))
    m2 = jtu.create_mesh((2, 2), ('x', 'y'))
    spec = P('x', 'y')
    a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape)
    with m1:
      with self.assertRaisesRegex(
          ValueError, "Received incompatible devices for pjitted computation"):
        pjit(
            lambda x, y: (x, y),
            in_shardings=NamedSharding(m2, spec),
            out_shardings=NamedSharding(m1, spec),
        )(a1, a1)

  def test_mixed_inputs(self):
    input_shape = (8, 2)
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    spec = P('x', 'y')
    a1, input_data = create_array(input_shape, global_mesh, spec)
    with global_mesh:
      f = pjit(lambda x, y: (x, y),
               in_shardings=NamedSharding(global_mesh, P(None)))
      with self.assertRaisesRegex(
          ValueError,
          ('Sharding passed to pjit does not match the sharding on the '
           'respective arg')):
        f(input_data, a1)
  def test_pjit_array_same_sharding_aot(self):
    global_mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a1, _ = create_array(input_shape, global_mesh, P(None,))
    with global_mesh:
      f = pjit(lambda x: x, in_shardings=NamedSharding(global_mesh, P(None,)))
      compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
      compiled(a1)  # no error

  def test_pjit_single_device_sharding_add(self):
    a = np.array([1, 2, 3], dtype=jnp.float32)
    b = np.array([4, 5, 6], dtype=jnp.float32)

    @pjit
    def add(x, y):
      return x + y

    out = add(a, b)
    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertArraysEqual(out, a + b)
    self.assertFalse(out._committed)

    out2 = add(out, out)
    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
    self.assertIsInstance(out2, array.ArrayImpl)
    self.assertArraysEqual(out2, 2 * (a + b))
    self.assertFalse(out2._committed)

    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)

    c = jax.device_put(a, jax.devices()[0])
    out3 = add(c, c)
    cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
    self.assertArraysEqual(out3, 2 * c)
    self.assertTrue(out3._committed)

    self.assertEqual(cache_info3.hits, cache_info2.hits)
    self.assertEqual(cache_info3.misses, cache_info2.misses + 1)

    out4 = add(out3, out3)
    self.assertArraysEqual(out4, 4 * c)
    self.assertTrue(out4._committed)
  def test_pjit_single_device_sharding_mul(self):
    a = jnp.arange(16).reshape((8, 2))

    @pjit
    def mul(x):
      return x @ x.T

    out = mul(a)
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertArraysEqual(out, a @ a.T)

  def test_pjit_single_device_sharding_cache(self):
    a = jnp.arange(16).reshape((8, 2))
    f = pjit(lambda x: x)

    with jtu.count_pjit_cpp_cache_miss() as count:
      out = f(a)
      _ = f(out)
    self.assertEqual(count[0], 1)
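  # Note: dispatching the same pjit-wrapped function on arrays committed to
  # different devices is expected to miss the lowering cache and recompile,
  # which is what the cache_info checks below verify.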
  def test_pjit_different_device_recompilation(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest('Requires 2 or more devices.')

    val1 = jnp.array([1, 2, 3], dtype=jnp.float32)
    a = jax.device_put(val1, jax.devices()[0])

    val2 = jnp.array([4, 5, 6], dtype=jnp.float32)
    b = jax.device_put(val2, jax.devices()[1])

    f = pjit(lambda x: x)

    out1 = f(a)
    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()

    out2 = f(b)
    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()

    self.assertEqual(cache_info2.hits, cache_info1.hits)
    self.assertEqual(cache_info2.misses, cache_info1.misses + 1)
    self.assertArraysEqual(out1, val1)
    self.assertArraysEqual(out2, val2)

  def test_grad_of_pjit_single_device_sharding(self):
    a = jnp.array(16, dtype=jnp.float32)
    f = lambda x: x * 3
    out = jax.grad(pjit(f))(a)
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertArraysEqual(out, jax.grad(f)(a))

  def test_autodiff_with_single_device_sharding(self):
    # Add a constant captured by the nested pjit to make things more complicated
    h = jnp.arange(4.)
    f = pjit(lambda x: x.sum(1) * h.sum())
    g = pjit(lambda x: f(jnp.sin(x * 4 + 2)))
    jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
  def test_fast_path_array(self):
    devices = jax.devices()
    if len(devices) < 8:
      raise unittest.SkipTest("Test requires 8 global devices.")
    mesh_devices = np.array([[devices[0], devices[2]],
                             [devices[3], devices[1]],
                             [devices[4], devices[6]],
                             [devices[7], devices[5]]])
    shape = (8, 2)
    mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

    # Explicitly put on the ordering of devices which does not match the mesh
    # ordering to make sure we reorder them in the constructor and the output
    # is correct.
    local_devices = jax.local_devices()[:8]  # Take 8 local devices
    di_map = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[di_map[d]], d) for d in local_devices]
    arr = array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)

    f = pjit(lambda x: x, out_shardings=s)
    out = f(arr)
    self.assertTrue(out.sharding.is_equivalent_to(arr.sharding, arr.ndim))
    self.assertArraysEqual(out, inp_data)

    out2 = f(out)
    self.assertTrue(out2.sharding.is_equivalent_to(out.sharding, out.ndim))
    self.assertArraysEqual(out2, inp_data)

  def test_array_enabled_non_empty_mesh_with_pspec(self):
    arr = jnp.array([1, 2, 3])
    with self.assertRaisesRegex(
        RuntimeError,
        r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
        r' `None` to in_shardings.*'):
      pjit(lambda x: x, in_shardings=P('x'))(arr)

    with self.assertRaisesRegex(
        TypeError,
        "in_shardings leaf specifications are expected to be PartitionSpec "
        "instances or None, but got x"):
      pjit(lambda x: x, in_shardings='x')
  def test_pjit_uncommitted_array_reshard(self):
    arr = jnp.array([[1, 2, 3]])
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    with mesh:
      out = pjit(lambda x: x)(arr)
      self.assertArraysEqual(out, arr)
      self.assertLen(out.addressable_shards, 8)

  def test_pjit_uncommitted_array_in_axis_resources_reshard(self):
    arr = jnp.arange(16).reshape(8, 2)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    with mesh:
      out = pjit(lambda x: x, in_shardings=P('x', 'y'))(arr)
      self.assertArraysEqual(out, arr)
      self.assertLen(out.addressable_shards, 8)
      for s in out.addressable_shards:
        self.assertArraysEqual(s.data, arr[s.index])
        self.assertEqual(s.replica_id, 0)
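  # Note: when an uncommitted array is mixed with a committed one, the
  # uncommitted input follows the committed input's device assignment, so the
  # outputs below end up sharded over all 8 devices of the mesh.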
  def test_pjit_uncommitted_array_and_committed_array(self):
    shape = (8, 2)
    uarr = jnp.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    carr, inp_data = create_array(shape, mesh, P('x', 'y'))
    with mesh:
      out1, out2 = pjit(lambda x, y: (x, y))(uarr, carr)
      self.assertArraysEqual(out1, inp_data)
      self.assertArraysEqual(out2, inp_data)
      self.assertLen(out1.addressable_shards, 8)
      self.assertLen(out2.addressable_shards, 8)

      mul_out = pjit(lambda x, y: x @ y.T)(uarr, carr)
      self.assertEqual(mul_out.shape, (8, 8))
      self.assertLen(mul_out.addressable_shards, 8)

    with jtu.create_mesh((2, 2), ('x', 'y')):
      with self.assertRaisesRegex(
          ValueError,
          "Received incompatible devices for pjitted computation"):
        pjit(lambda x, y: (x, y))(uarr, carr)

  def test_pjit_uncommitted_array_multi_devices(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
    arr = array.ArrayImpl(
        core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
        [jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
    with self.assertRaisesRegex(
        NotImplementedError,
        "Having uncommitted Array sharded on multiple devices is not supported."):
      pjit(lambda x: x)(arr)
  def test_pjit_committed_array_different_devices(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices')
    a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
    b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
    with self.assertRaisesRegex(
        ValueError,
        "Received incompatible devices for pjitted computation. Got argument "
        r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
        r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
      pjit(lambda x, y: (x, y))(a, b)

  def test_pjit_committed_array_different_devices_variadic_args(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices')
    a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
    b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
    with self.assertRaisesRegex(
        ValueError,
        "Received incompatible devices for pjitted computation. Got argument "
        r"x\[0\] of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
        r"argument x\[1\] of.*\<lambda\> with shape int.*\[3\] and device ids "
        r"\[1\].*"):
      pjit(lambda *x: x)(a, b)

  def test_pjit_pytree_inp_device_assignment_mismatch(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
    b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
    c = jax.device_put(np.arange(16).reshape(8, 2),
                       NamedSharding(mesh, P('x', 'y')))

    msg = ("Received incompatible devices for pjitted computation. Got "
           r"argument {} of.*<lambda> with shape int.*\[3\] and device ids "
           r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and "
           r"device ids.*")

    with self.assertRaisesRegex(
        ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')):
      pjit(lambda tuple_inp: tuple_inp)((a, (c, (b))))

    with self.assertRaisesRegex(
        ValueError, msg.format(r"dict_inp\['a'\]\['b'\]\['c'\]",
                               r"dict_inp\['a'\]\['b'\]\['g'\]")):
      inp = {'d': a, 'z': a, 'a': {'f': a, 'y': b, 'b': {'g': c, 'c': a}}}
      pjit(lambda dict_inp: dict_inp)(inp)
  def test_same_out_sharding_id(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    arr, inp_data = create_array(shape, mesh, P('x', 'y'))

    f = pjit(lambda x: x)
    out1 = f(arr)
    self.assertArraysEqual(out1, inp_data)
    out1_sharding_id = id(out1.sharding)

    out2 = f(out1)
    self.assertArraysEqual(out2, inp_data)
    out2_sharding_id = id(out2.sharding)

    out3 = f(out2)
    self.assertArraysEqual(out3, inp_data)
    out3_sharding_id = id(out3.sharding)

    self.assertEqual(out1_sharding_id, out2_sharding_id)
    self.assertEqual(out1_sharding_id, out3_sharding_id)
    self.assertEqual(out2_sharding_id, out3_sharding_id)

  def test_out_sharding_indices_id_cache_hit(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    arr, _ = create_array(shape, mesh, P('x', 'y'))

    f = pjit(lambda x: x)
    out1 = f(arr)
    self.assertIsInstance(out1.sharding, NamedSharding)
    out1.sharding.devices_indices_map(shape)
    cache_info1 = common_devices_indices_map.cache_info()

    out2 = f(out1)
    self.assertIsInstance(out2.sharding, NamedSharding)
    out2.sharding.devices_indices_map(shape)
    cache_info2 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)

    out3 = f(out2)
    self.assertIsInstance(out3.sharding, NamedSharding)
    out3.sharding.devices_indices_map(shape)
    cache_info3 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info3.hits, cache_info2.hits + 1)

  def test_aot_compile_in_tree_mismatch(self):
    @jax.jit
    def f(tree):
      return tree

    tree1 = {'a': {'c': 5, 'd': 6}}
    tree2 = {'a': 1, 'c': {'b': 5, 'e': 7}}
    with self.assertRaisesRegex(
        TypeError,
        'Function compiled with input pytree does not match the input pytree it'
        ' was called with'):
      f.lower(tree1).compile()(tree2)
  @jax.enable_custom_prng()
  def test_device_put_sharding_prng(self):
    mesh = jtu.create_mesh((8,), ('x',))
    s = NamedSharding(mesh, P('x'))

    x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
    y = jax.device_put(x, s)
    self.assertTrue(jax.dtypes.issubdtype(y.dtype, jax.dtypes.prng_key))
    self.assertEqual(y.sharding, s)

    s1 = SingleDeviceSharding(jax.devices()[1])
    z = jax.device_put(x, s1)
    self.assertTrue(jax.dtypes.issubdtype(z.dtype, jax.dtypes.prng_key))
    self.assertEqual(z.sharding, s1)

    out_p = jax.pmap(lambda x: x)(np.arange(jax.device_count()))
    a = jax.device_put(x, out_p.sharding)
    self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key))
    self.assertEqual(a.sharding, out_p.sharding)

    if config.use_shardy_partitioner.value:
      # OpSharding is not supported in shardy.
      return

    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [8]
    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
    gs = GSPMDSharding(tuple(mesh.devices.flat), op)
    b = jax.device_put(x, gs)
    self.assertTrue(jax.dtypes.issubdtype(b.dtype, jax.dtypes.prng_key))
    self.assertEqual(b.sharding, gs)

  def test_device_put_on_different_sharding(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))

    x = jnp.arange(8).reshape(4, 2)
    s1 = NamedSharding(mesh, P('x'))
    a = jax.device_put(x, s1)
    self.assertEqual(a.sharding, s1)

    s2 = NamedSharding(mesh, P('x', 'y'))
    b = jax.device_put(a, s2)
    self.assertEqual(b.sharding, s2)
  def test_with_sharding_constraint_jit(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(jax.jit, static_argnums=(0, 1))
    def sharded_zeros(shape, pspec):
      out = jnp.zeros(shape, jnp.bfloat16)
      return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

    out = sharded_zeros((4096, 3072), P('x', 'y'))
    out_s = NamedSharding(mesh, P('x', 'y'))
    self.assertTrue(op_shardings.are_op_shardings_equal(
        out.sharding._to_xla_hlo_sharding(out.ndim),
        out_s._to_xla_hlo_sharding(out.ndim)))

  def test_with_sharding_constraint_pjit(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(pjit, static_argnums=(0, 1))
    def sharded_zeros(shape, pspec):
      out = jnp.zeros(shape, jnp.bfloat16)
      return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

    out = sharded_zeros((4096, 3072), P('x', 'y'))
    out_s = NamedSharding(mesh, P('x', 'y'))
    self.assertTrue(op_shardings.are_op_shardings_equal(
        out.sharding._to_xla_hlo_sharding(out.ndim),
        out_s._to_xla_hlo_sharding(out.ndim)))

  def test_jit_with_sharding_constraint_committed_inp_error(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    s = NamedSharding(mesh, P('x', 'y'))

    @jax.jit
    def sharded_inp(inp):
      return jax.lax.with_sharding_constraint(
          inp, NamedSharding(mesh, P('x', 'y')))

    committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
    with self.assertRaisesRegex(
        ValueError,
        "Received incompatible devices for jitted computation. Got argument "
        r"inp of.*sharded_inp with shape bfloat16\[8,2\] and device ids \[0\].*"
        r"sharding_constraint inside jit with device ids.*"):
      sharded_inp(committed_inp)

    @pjit
    def my_nested_pjit(inp1, inp2, inp3):
      @partial(pjit, in_shardings=(s, s, s),
               out_shardings=(s, s, s))
      def f(x, y, z):
        return x * 2, y * 2, z * 2
      return f(inp1, inp2, inp3)
    with self.assertRaisesRegex(
        ValueError,
        "Received incompatible devices for pjitted computation. Got argument "
        r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*"
        r"pjit inside pjit with device ids.*"):
      my_nested_pjit(committed_inp, committed_inp, committed_inp)
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def test_jit_device_with_sharding_constraint_error(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
    def sharded_zeros(shape, pspec):
      out = jnp.zeros(shape, jnp.bfloat16)
      return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

    with self.assertRaisesRegex(
        ValueError,
        "Received incompatible devices for jitted computation. Got explicit "
        r"output sharding with device ids \[0\].*sharding_constraint inside "
        r"jit with device ids.*"):
      sharded_zeros((4096, 3072), P('x', 'y'))
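  # Note: pjit dispatch is expected to be thread-safe; the following test
  # invokes separately wrapped functions concurrently from a thread pool while
  # a mesh context manager is active.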
  def test_concurrent_pjit(self):
    global_mesh = jtu.create_mesh((1,), ('x',))
    sharding = NamedSharding(global_mesh, P('x',))
    n = 10
    with global_mesh:
      fs = [pjit(lambda x, i: x + i, static_argnums=1) for _ in range(n)]

      def _invoke_with_mesh_twice(arg_tuple):
        f, x, i = arg_tuple
        with global_mesh:
          f(x, i)
          return f(x, i)

      xs = [
          array.make_array_from_callback(
              (i,), sharding, lambda idx: np.arange(i, dtype=np.float32))
          for i in range(n)
      ]
      with concurrent.futures.ThreadPoolExecutor() as executor:
        ys = executor.map(_invoke_with_mesh_twice,
                          [(fs[i], x, i) for i, x in enumerate(xs)])
      for i, x, y in zip(range(n), xs, ys):
        self.assertAllClose(x + i, y)

  def test_trivial_computation(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(inp_data, s)
    out = pjit(lambda x: x)(arr)
    self.assertArraysEqual(out, inp_data)

  def test_trivial_computation_with_sharded_const(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    const = jax.device_put(np.arange(16).reshape(8, 2),
                           NamedSharding(mesh, P('x', 'y')))
    with mesh:
      out = pjit(lambda: const)()
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertArraysEqual(out, np.arange(16).reshape(8, 2))

  def test_trivial_computation_with_sharded_const_using_transposed_mesh(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    const = jax.device_put(np.arange(16).reshape(8, 2),
                           NamedSharding(mesh, P('x', 'y')))
    mesh2 = jtu.create_mesh((1, 2), ('x', 'y'))
    with mesh2:
      out = pjit(lambda: const)()
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertArraysEqual(out, np.arange(16).reshape(8, 2))

  def test_trivial_computation_with_replicated_literal(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    with mesh:
      out = pjit(lambda: 1)()
      self.assertEqual(out.sharding, NamedSharding(mesh, P()))
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertEqual(out, 1)
  def test_multi_device_pjit_mul(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    inp_data = np.arange(math.prod(shape)).reshape(shape)
    arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y')))
    arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y')))

    out1, out2 = pjit(lambda x, y: (x @ x.T, y * 2))(arr1, arr2)

    self.assertArraysEqual(out1, inp_data @ inp_data.T)
    self.assertEqual(out1.shape, (8, 8))
    self.assertArraysEqual(out2, inp_data * 2)
    self.assertEqual(out2.shape, (8, 2))

  def test_single_device_pjit_cpp_dispatch(self):
    shape = (8, 2)
    mesh = jtu.create_mesh((1,), ('x',))
    inp_data = np.arange(math.prod(shape)).reshape(shape)

    f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None)
    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(10):
        arr1 = jax.device_put(
            inp_data, jax.sharding.NamedSharding(mesh, P('x')))
        with mesh:
          f(arr1)
    self.assertEqual(count[0], 1)

  def test_single_device_add_single_compile(self):
    f1 = pjit(lambda x, y: x + y)
    a = jax.device_put(jnp.array([1, 2, 3], dtype=jnp.float32),
                       jax.devices()[0])
    b = jax.device_put(jnp.array([4, 5, 6], dtype=jnp.float32),
                       jax.devices()[0])

    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(2):
        f1(a, b)
    self.assertEqual(count[0], 1)

  def test_global_array_to_host_local_array_already_host_local(self):
    inp_shape = (8, 2)
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    pspec = P('x', 'y')

    arr, _ = create_array(inp_shape, mesh, pspec)
    out = multihost_utils.global_array_to_host_local_array(arr, mesh, pspec)
    self.assertEqual(id(arr), id(out))
  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testLowerCompileWithStaticArguments(self):
    @partial(pjit,
             in_shardings=P(('x', 'y'),),
             out_shardings=P(('x', 'y'),), static_argnums=0)
    def f(c, x):
      return x if c == 0 else x + 1

    shape = (8, 8)
    x = jnp.arange(math.prod(shape)).reshape(shape)
    exe = f.lower(1, x).compile()

    self.assertAllClose(exe(x), x + 1, check_dtypes=False)

  def test_vmap_of_jvp_pjit_no_axis_resources(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    pjit_inp1 = jax.device_put(
        jnp.arange(8.), jax.sharding.NamedSharding(mesh, P('x')))
    pjit_inp2 = jax.device_put(
        jnp.arange(8.), jax.sharding.NamedSharding(mesh, P(('x', 'y'))))

    def f_(x, n):
      if n == 0:
        return x * 2.
      return jax.jit(partial(f_, n=n-1))(x - 1)
    f = jax.jit(partial(f_, n=5))
    jit_out1, jit_out2 = jax.vmap(lambda xs, ts: jax.jvp(f, xs, ts))(
        (jnp.arange(8.),), (jnp.arange(8.),))

    def g_(x, n):
      if n == 0:
        return x * 2.
      return pjit(partial(g_, n=n - 1))(x - 1)
    g = pjit(partial(g_, n=5))
    pjit_out1, pjit_out2 = jax.vmap(lambda xs, ts: jax.jvp(g, xs, ts))(
        (pjit_inp1,), (pjit_inp2,))

    self.assertArraysEqual(pjit_out1, jit_out1)
    self.assertArraysEqual(pjit_out2, jit_out2)

  def test_vmap_of_jvp_pjit_no_axis_resources_2d(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    f_inp = jnp.arange(8.).reshape(2, 2, 2)

    # g_inp is sharded with P(None, 'x') because f_inp is sharded with P('x')
    # and then `f` will get vmapped and pjit's batching rule will insert a
    # replicated axis for the batched dimension converting it into P(None, 'x')
    g_inp = jax.device_put(f_inp,
                           jax.sharding.NamedSharding(mesh, P(None, 'x')))

    # Reference pjit with axis_resources
    def f_(x, n):
      if n == 0:
        return x * 2.
      return pjit(
          partial(f_, n=n - 1), in_shardings=P('x'), out_shardings=P('x')
      )(x - 1)
    f = pjit(partial(f_, n=5), in_shardings=P('x'), out_shardings=P('x'))
    with mesh:
      f_out1, f_out2 = jax.vmap(lambda xs, ts: jax.jvp(f, xs, ts))(
          (f_inp,), (f_inp,))

    # pjit with no axis_resources
    def g_(x, n):
      if n == 0:
        return x * 2.
      return pjit(partial(g_, n=n - 1))(x - 1)
    g = pjit(partial(g_, n=5))
    g_out1, g_out2 = jax.vmap(lambda xs, ts: jax.jvp(g, xs, ts))(
        (g_inp,), (g_inp,))

    self.assertArraysEqual(f_out1, g_out1)
    self.assertArraysEqual(f_out2, g_out2)
    self.assertEqual(f_out1.sharding, g_out1.sharding)
    self.assertEqual(f_out2.sharding, g_out2.sharding)
  def test_pjit_on_different_default_device_with_uncommitted_inputs(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices')

    @pjit
    def f(x, y):
      return x + y

    a = jnp.array([1, 2, 3], dtype=jnp.float32)
    self.assertFalse(a._committed)
    out = f(a, a)
    self.assertFalse(out._committed)
    self.assertEqual(out.devices(), {jax.devices()[0]})
    self.assertArraysEqual(out, a * 2)

    with jax.default_device(jax.devices()[1]):
      b = jnp.array([4, 5, 6], dtype=jnp.float32)
      self.assertFalse(b._committed)
      out2 = f(b, b)
      self.assertFalse(out2._committed)
      self.assertEqual(out2.devices(), {jax.devices()[1]})
      self.assertArraysEqual(out2, b * 2)

  def test_pjit_with_static_argnames(self):

    def f(x: str) -> int:
      assert x == 'foo'
      return 1

    f_nums = pjit(f, static_argnums=0)
    assert f_nums('foo') == 1
    assert f_nums(x='foo') == 1

    f_names = pjit(f, static_argnames='x')
    assert f_names('foo') == 1
    assert f_names(x='foo') == 1

  def test_pjit_with_static_argnames_cpp_dispatch(self):
    def f(y, **kwargs):
      self.assertEqual(kwargs, {'x': 'foo'})
      return y * y

    y = jnp.arange(8.)
    with jtu.count_pjit_cpp_cache_miss() as count:
      f_names = pjit(f, static_argnames='x')
      f_names(y, x='foo')
      f_names(y, x='foo')
    self.assertEqual(count[0], 1)

  def test_new_static_argnum_on_keyword_arguments(self):
    f = pjit(lambda x: x, static_argnums=0)
    y = f(x=4)
    assert y == 4

  def test_new_static_argnum_with_default_arguments(self):
    f = pjit(lambda x=4: x, static_argnums=0)
    y = f()
    assert y == 4
  def test_pjit_different_default_device(self):
    if jax.device_count() <= 1:
      self.skipTest('Test requires more >1 device.')

    system_default_device = list(jnp.add(1, 1).devices())[0]
    test_device = jax.devices()[-1]

    f = pjit(lambda x: x + 1)

    f(1)
    with jax.default_device(system_default_device):
      f(1)
    with jax.default_device(test_device):
      f(1)

    with jtu.count_pjit_cpp_cache_miss() as count:
      f(1)

      with jax.default_device(system_default_device):
        f(1)

      with jax.default_device(test_device):
        f(1)

      with jax.default_device(test_device):
        with jax.default_device(system_default_device):
          f(1)

    # The count here is 0 because before `count_pjit_cpp_cache_miss`, `f` was
    # called with `system_default_device` and `test_device` so it was added
    # to the cache. Subsequent calls hit the C++ cache.
    self.assertEqual(count[0], 0)

  def test_pjit_with_mismatched_static_argnames(self):
    x_is_tracer, y_is_tracer = False, False
    def f(x, y):
      assert isinstance(x, core.Tracer) == x_is_tracer
      assert isinstance(y, core.Tracer) == y_is_tracer
      return 1

    # If both static_argnums and static_argnames are provided, they are allowed
    # to disagree and `jit` will respect the user's choices.
    f_nums = pjit(f, static_argnums=1, static_argnames=())
    x_is_tracer, y_is_tracer = True, False
    assert f_nums(2, 3) == 1
    x_is_tracer, y_is_tracer = True, True
    assert f_nums(1, y=2) == 1

    f_names = pjit(f, static_argnums=(), static_argnames='y')
    x_is_tracer, y_is_tracer = True, True
    assert f_names(2, 3) == 1
    x_is_tracer, y_is_tracer = True, False
    assert f_names(1, y=3) == 1

    f_mixed = pjit(f, static_argnums=(1,), static_argnames='x')
    x_is_tracer, y_is_tracer = True, False
    assert f_mixed(2, 3) == 1
    x_is_tracer, y_is_tracer = True, True
    assert f_mixed(1, y=3) == 1
    x_is_tracer, y_is_tracer = False, True
    assert f_mixed(x=2, y=3) == 1
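  # Note: keyword arguments are supported by pjit, but calling the same
  # function with a different positional/keyword split lowers separately, which
  # is what the cache_info checks in the next test verify.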
  def test_pjit_kwargs(self):
    a = jnp.arange(8.)
    b = jnp.arange(4.)
    c = jnp.arange(2.)

    @pjit
    def f(x, y, z):
      return x, y, z

    o1, o2, o3 = f(a, y=b, z=c)
    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
    self.assertArraysEqual(o1, a)
    self.assertArraysEqual(o2, b)
    self.assertArraysEqual(o3, c)

    o4, o5, o6 = f(x=a, y=b, z=c)
    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
    self.assertArraysEqual(o4, a)
    self.assertArraysEqual(o5, b)
    self.assertArraysEqual(o6, c)

    self.assertEqual(cache_info2.hits, cache_info1.hits)
    self.assertEqual(cache_info2.misses, cache_info1.misses + 1)

    o7, o8, o9 = f(a, b, c)
    cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
    self.assertArraysEqual(o7, a)
    self.assertArraysEqual(o8, b)
    self.assertArraysEqual(o9, c)

    self.assertEqual(cache_info3.hits, cache_info2.hits)
    self.assertEqual(cache_info3.misses, cache_info2.misses + 1)

  def test_pjit_kwargs_axis_resources_error(self):
    with self.assertRaisesRegex(
        ValueError,
        "pjit does not support kwargs when in_shardings is specified."):
      pjit(lambda x: x,
           in_shardings=SingleDeviceSharding(jax.devices()[0]))(x=jnp.arange(8.))
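  # Note: keep_unused=True retains every argument in the compiled executable,
  # while the default drops arguments the computation never uses; the two tests
  # below check _kept_var_idx for each case.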
  def test_pjit_keep_unused_true(self):
    @partial(pjit, keep_unused=True)
    def f(x, y, z, a, b, c):  # pylint: disable=unused-argument
      return c @ c.T

    inp = jnp.arange(4)
    unused_inp = jnp.arange(8)

    out = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
    # Run it again to take the C++ dispatch.
    out_again = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)

    self.assertArraysEqual(out, inp @ inp.T)
    self.assertArraysEqual(out_again, inp @ inp.T)

    compiled = f.lower(
        unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp).compile()
    self.assertEqual(compiled._executable._kept_var_idx, {0, 1, 2, 3, 4, 5})
    self.assertLen(compiled._executable.in_avals, 6)

  def test_pjit_keep_unused_default_false(self):
    @pjit
    def f(x, y, z, a, b, c):  # pylint: disable=unused-argument
      return c @ c.T

    inp = jax.device_put(jnp.arange(4), jax.devices()[0])
    unused_inp = jax.device_put(jnp.arange(8), jax.devices()[0])

    out = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
    # Run it again to take the C++ dispatch.
    out_again = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)

    self.assertArraysEqual(out, inp @ inp.T)
    self.assertArraysEqual(out_again, inp @ inp.T)

    compiled = f.lower(
        unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp).compile()
    self.assertEqual(compiled._executable._kept_var_idx, {5})
    self.assertLen(compiled._executable.in_avals, 1)

  def test_pjit_relayout_multi_slice(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @jax.jit
    def mul(x):
      return x @ x.T

    x = jnp.arange(8).reshape(4, 2)
    y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y')))
    compiled = mul.lower(jax.ShapeDtypeStruct(
        y.shape, y.dtype, sharding=y.sharding)).compile()
    out = compiled(y)
    self.assertArraysEqual(out, x @ x.T)
  def test_pjit_with_device_arg(self):
    def mul(x):
      return x @ x.T

    def _check(out, expected_device, expected_out):
      self.assertEqual(out.devices(), {expected_device})
      self.assertLen(out.sharding.device_set, 1)
      self.assertArraysEqual(out, expected_out @ expected_out.T)

    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    with jtu.ignore_warning(category=DeprecationWarning,
                            message="backend and device argument"):
      f = pjit(mul, device=jax.devices()[1])

    x = jnp.arange(8).reshape(4, 2)
    f_out = f(x)
    f_out2 = f(f_out)
    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
    _check(f_out, jax.devices()[1], x)
    _check(f_out2, jax.devices()[1], f_out)

    y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y')))
    out2 = f(y)
    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
    _check(out2, jax.devices()[1], y)

    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)

    with jtu.ignore_warning(category=DeprecationWarning,
                            message="backend and device argument"):
      h = pjit(mul, device=jax.devices()[-1])
    h_out = h(y)
    cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
    _check(h_out, jax.devices()[-1], y)

    self.assertEqual(cache_info3.hits, cache_info2.hits)

    # AOT test
    compiled = f.lower(core.ShapedArray(y.shape, y.dtype)).compile()
    out3 = compiled(y)
    _check(out3, jax.devices()[1], y)

  def test_pjit_with_device_arg_input_from_another_pjit(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    inp = np.arange(8).reshape(4, 2)

    y = jax.device_put(inp, jax.sharding.NamedSharding(mesh, P('x', 'y')))
    out = pjit(lambda x: x * 2)(y)

    expected_device = jax.devices()[2]
    with jtu.ignore_warning(category=DeprecationWarning,
                            message="backend and device argument"):
      final_out = pjit(lambda x: x * 3, device=expected_device)(out)

    self.assertEqual(final_out.devices(), {expected_device})
    self.assertLen(final_out.sharding.device_set, 1)
    self.assertArraysEqual(final_out, inp * 6)
@jtu.run_on_devices("tpu")
|
2022-12-14 15:41:19 -08:00
|
|
|
def test_pjit_with_backend_arg(self):
|
|
|
|
def _check(out, expected_device, expected_out):
|
2023-11-29 16:52:09 -08:00
|
|
|
self.assertEqual(out.devices(), {expected_device})
|
2022-12-14 15:41:19 -08:00
|
|
|
self.assertLen(out.sharding.device_set, 1)
|
|
|
|
self.assertArraysEqual(out, expected_out)
|
|
|
|
|
|
|
|
x = jnp.arange(8)
|
2024-06-12 14:43:14 -07:00
|
|
|
with jtu.ignore_warning(category=DeprecationWarning,
|
|
|
|
message="backend and device argument"):
|
|
|
|
g = pjit(lambda x: x, backend='tpu')
|
2022-12-14 15:41:19 -08:00
|
|
|
g_out = g(x)
|
|
|
|
_check(g_out, jax.devices()[0], x)
|
|
|
|
|
2023-02-14 23:00:40 -08:00
|
|
|
compiled = g.lower(core.ShapedArray(x.shape, x.dtype)).compile()
|
2022-12-14 15:41:19 -08:00
|
|
|
out4 = compiled(x)
|
|
|
|
_check(out4, jax.devices()[0], x)
|
|
|
|
|
|
|
|
def test_autodiff_with_device_arg(self):
|
|
|
|
if jax.device_count() <= 1:
|
|
|
|
self.skipTest('Test requires more >1 device.')
|
|
|
|
# Add a constant captured by the nested pjit to make things more complicated
|
|
|
|
h = jnp.arange(4.)
|
2024-06-12 14:43:14 -07:00
|
|
|
with jtu.ignore_warning(category=DeprecationWarning,
|
|
|
|
message="backend and device argument"):
|
|
|
|
f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
|
|
|
|
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
|
2022-12-14 15:41:19 -08:00
|
|
|
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
|
|
|
|
|
|
|
|
def test_pjit_device_backend_axis_resources_error(self):
|
2023-06-15 15:21:36 -07:00
|
|
|
s = SingleDeviceSharding(jax.devices()[0])
|
2022-12-14 15:41:19 -08:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
'If backend or device is specified on jit, then '
|
2023-02-11 15:29:38 -08:00
|
|
|
'in_shardings should not be specified.'):
|
2024-06-12 14:43:14 -07:00
|
|
|
with jtu.ignore_warning(category=DeprecationWarning,
|
|
|
|
message="backend and device argument"):
|
|
|
|
pjit(lambda x: x, in_shardings=s, backend='cpu')
|
2022-12-14 15:41:19 -08:00
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
'If backend or device is specified on jit, then '
|
2023-02-11 15:29:38 -08:00
|
|
|
'out_shardings should not be specified.'):
|
2024-06-12 14:43:14 -07:00
|
|
|
with jtu.ignore_warning(category=DeprecationWarning,
|
|
|
|
message="backend and device argument"):
|
|
|
|
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
|
2022-12-14 15:41:19 -08:00
|
|
|
|
2024-03-21 17:45:44 -07:00
|
|
|
def test_check_arg_error(self):
|
|
|
|
sds = jax.ShapeDtypeStruct((4, 2), np.int32)
|
|
|
|
inp = np.arange(8).reshape(4, 2)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
r"Argument 'x\['b'\]\['c'\]' of shape int32\[4,2\] of "
|
|
|
|
"type.*ShapeDtypeStruct.*is not a valid JAX type."):
|
|
|
|
jax.jit(lambda x: x)({'a': inp, 'b': {'c': sds}})
|
|
|
|
|
2022-12-14 15:41:19 -08:00
|
|
|
  def test_pjit_device_backend_both_error(self):
    with self.assertRaisesRegex(
        ValueError, "can't specify both a device and a backend for jit"):
      with jtu.ignore_warning(category=DeprecationWarning,
                              message="backend and device argument"):
        pjit(lambda x: x, device=jax.devices()[0], backend='cpu')

  def test_pjit_mesh_with_device_or_backend_error(self):
    mesh = jtu.create_mesh((1,), ('x',))
    with mesh:
      with self.assertRaisesRegex(
          ValueError,
          "Mesh context manager should not be used with jit when backend or "
          "device is also specified as an argument to jit."):
        with jtu.ignore_warning(category=DeprecationWarning,
                                message="backend and device argument"):
          pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
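  # Note: inline=True inlines the traced computation into the caller's jaxpr
  # instead of staging out a pjit call primitive, which the next test checks by
  # inspecting the jaxpr text.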
  def test_pjit_inline(self):
    @partial(pjit, inline=False)
    def f(x):
      return x * 2

    jaxpr = jax.make_jaxpr(f)(3)
    self.assertIn('pjit', str(jaxpr))

    @partial(pjit, inline=True)
    def g(x):
      return x * 2

    jaxpr = jax.make_jaxpr(g)(3)
    self.assertNotIn('pjit', str(jaxpr))

  def test_pmap_in_axis_resources_error(self):
    pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count()))
    self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding)

    with self.assertRaisesRegex(
        ValueError,
        r"One of in_shardings.*got sharding.*which is not allowed."):
      pjit(lambda x: x, in_shardings=pmap_out.sharding)

    with self.assertRaisesRegex(
        ValueError,
        r"One of out_shardings.*got sharding.*which is not allowed."):
      pjit(lambda x: x, out_shardings=pmap_out.sharding)

  def test_pmap_sharding_input_to_pjit_single_device(self):
    pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count()))
    self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding)
    self.assertLen(pmap_out.devices(), jax.device_count())

    out = pjit(lambda x: x * 3)(pmap_out)
    self.assertArraysEqual(out, pmap_out * 3)
    # Even though pmap out is on jax.device_count() number of devices, the
    # output will be 1 device since it will be resharded.
    self.assertLen(out.devices(), 1)

  def test_pmap_sharding_input_to_pjit_multi_device(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count()))
    self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding)

    inp2 = jnp.arange(4)
    with mesh:
      out1, out2 = pjit(lambda x, y: (x * 2, y * 2))(pmap_out, inp2)

    self.assertArraysEqual(out1, pmap_out * 2)
    self.assertArraysEqual(out2, inp2 * 2)
    self.assertLen(out1.devices(), 4)
    self.assertLen(out2.devices(), 4)
    self.assertTrue(op_shardings.is_op_sharding_replicated(
        out1.sharding._to_xla_hlo_sharding(pmap_out.ndim)))
    self.assertTrue(op_shardings.is_op_sharding_replicated(
        out2.sharding._to_xla_hlo_sharding(inp2.ndim)))

  def test_pmap_sharding_input_pjit_in_axis_resources(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count()))
    self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding)

    out = pjit(lambda x: x * 2, in_shardings=NamedSharding(mesh, P('x')))(pmap_out)
    self.assertArraysEqual(out, pmap_out * 2)
    self.assertLen(out.devices(), 4)
  def test_nested_pjit_closing_over_tracer(self):
    @pjit
    def f(x):
      y = jnp.float32(2) * x

      @pjit
      def g(z):
        return jax.pmap(lambda x: x[jnp.newaxis] * y)(z)

      return g(x)

    f(np.arange(1., dtype='float32').reshape((1, 1)))  # doesn't crash
    # Second call is to trigger C++ dispatch.
    f(np.arange(1., dtype='float32').reshape((1, 1)))  # doesn't crash

  def test_aot_nested_pjit_closing_over_const_top_level(self):
    const = jnp.arange(8.)

    @pjit
    def f(x):
      return const * 2 + x

    inp = jnp.arange(8.)
    compiled = f.lower(inp).compile()
    self.assertArraysEqual(compiled(inp), const * 2 + inp)

  def test_nested_pjit_closing_over_const_top_level_and_tracer(self):
    const = jnp.arange(8.)

    @pjit
    def f(x):
      y = jnp.arange(8., 16.) * x + const

      @pjit
      def g(z):
        return z + y * 2 + const

      return g(x)

    f(jnp.arange(8.))  # doesn't crash
    # Second call is to trigger C++ dispatch.
    f(jnp.arange(8.))  # doesn't crash

  def test_nested_pjit_closing_over_top_level_const(self):
    const = jnp.arange(8.)

    @pjit
    def f(x):

      @pjit
      def g(z):
        return z + const

      return g(x)

    inp = jnp.arange(8., 16.)
    f(inp)  # doesn't crash
    # Second call is to trigger C++ dispatch.
    f(inp)  # doesn't crash
  def test_pjit_sin_nested(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))

    @pjit
    def f(x):
      return jnp.sin(x)

    with mesh:
      inp = jnp.arange(8.)
      out = f(inp)
      self.assertArraysAllClose(out, np.sin(inp))
      self.assertLen(out.devices(), 8)

  def test_jit_with_mesh_context_manager(self):
    mesh = jtu.create_mesh((1,), ('x',))
    with self.assertRaisesRegex(
        RuntimeError,
        "jax.jit only supports `Sharding`s being passed to "
        "in_shardings"):
      with mesh:
        jax.jit(lambda x: x, in_shardings=P('x'),
                out_shardings=P('x'))(jnp.arange(8))

  def test_pjit_nested_uncommitted_output(self):
    @pjit
    def f(x):
      @pjit
      def g(y):
        return y * 2
      return g(x)

    out = f(jnp.arange(8))
    self.assertFalse(out._committed)
    self.assertArraysEqual(out, np.arange(8) * 2)

  def test_pjit_disable_jit(self):
    sideeffect = []

    def f(x):
      sideeffect.append(None)
      return x + 1

    f = jax.jit(f)
    for _ in range(2):
      f(1)
      self.assertLen(sideeffect, 1)

    with jax.disable_jit():
      f(1)
      self.assertLen(sideeffect, 2)
  def test_pmap_pjit_axis_index(self):
    @partial(jax.pmap, axis_name='data')
    def _pmapped_fun(inputs):
      del inputs
      return jax.lax.axis_index('data')

    inputs = jnp.zeros(shape=[jax.device_count()])
    with jtu.ignore_warning(
        message=".*Using jit-of-pmap can lead to inefficient data movement"):
      pjit(_pmapped_fun)(inputs)  # doesn't crash
      jax.jit(_pmapped_fun)(inputs)  # doesn't crash

  def test_pjit_function_cache_cpp(self):
    def f(x):
      return x * 2

    inp = jnp.arange(3.)

    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(10):
        pjit(f)(inp)
    self.assertEqual(count[0], 1)
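  # Note: constructing a fresh pjit wrapper with explicit shardings (or a
  # device argument) on every call misses the C++ fast-path cache each time,
  # whereas reusing one wrapped callable hits it after the first call.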
  def test_pjit_no_global_cache_hit_axis_resources(self):
    mesh = jtu.create_mesh((1,), ('x',))
    s = NamedSharding(mesh, P('x'))
    inp = jnp.arange(8.0)

    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(10):
        pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(inp)
    self.assertEqual(count[0], 10)

    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(10):
        with jtu.ignore_warning(category=DeprecationWarning,
                                message="backend and device argument"):
          pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
    self.assertEqual(count[0], 10)

    pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(10):
        pf(inp)
    self.assertEqual(count[0], 1)

    with jtu.ignore_warning(category=DeprecationWarning,
                            message="backend and device argument"):
      pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
    with jtu.count_pjit_cpp_cache_miss() as count:
      for _ in range(10):
        pf1(inp)
    self.assertEqual(count[0], 1)
def test_with_sharding_constraint_spmd_axis_name(self):
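# spmd_axis_name='mdl' should shard the vmapped dimension over 'mdl' in
# addition to the 'data' constraint applied inside f.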
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
|
2023-02-13 14:57:50 -08:00
|
|
|
shape = (8, 4, 2, 2)
|
2023-02-28 12:40:30 -08:00
|
|
|
x = jnp.arange(math.prod(shape)).reshape(shape)
|
2023-02-13 14:57:50 -08:00
|
|
|
|
|
|
|
def f(inp):
|
2023-05-05 10:47:53 -07:00
|
|
|
sharding = NamedSharding(mesh, P('data', None, None))
|
|
|
|
return with_sharding_constraint(inp, sharding)
|
2023-02-13 14:57:50 -08:00
|
|
|
|
2023-05-05 10:47:53 -07:00
|
|
|
out = jax.vmap(jax.jit(f), spmd_axis_name='mdl')(x)
|
|
|
|
ns, _ = op_shardings.get_num_ways_dim_sharded(
|
2023-06-05 13:40:59 -07:00
|
|
|
out.sharding._to_xla_hlo_sharding(out.ndim))
|
2023-05-05 10:47:53 -07:00
|
|
|
self.assertListEqual(ns, [2, 2, 1, 1])
|
2023-02-13 14:57:50 -08:00
|
|
|
|
|
|
|
def apply_with_scan(x):
|
|
|
|
x, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1)
|
|
|
|
return x
|
|
|
|
|
2023-05-05 10:47:53 -07:00
|
|
|
out2 = jax.vmap(apply_with_scan, spmd_axis_name='mdl')(x)
|
|
|
|
ns2, _ = op_shardings.get_num_ways_dim_sharded(
|
2023-06-05 13:40:59 -07:00
|
|
|
out2.sharding._to_xla_hlo_sharding(out2.ndim))
|
2023-05-05 10:47:53 -07:00
|
|
|
self.assertListEqual(ns2, [2, 2, 1, 1])
|
2023-02-13 14:57:50 -08:00
|
|
|
|
2023-02-23 15:37:13 -08:00
|
|
|
def test_device_put_sharding_nondivisible_sharding_error(self):
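# device_put should fail when a dimension is not divisible by the number of
# shards along the sharded mesh axis.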
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2023-02-23 15:37:13 -08:00
|
|
|
s = NamedSharding(mesh, P('x'))
|
|
|
|
|
|
|
|
x = jnp.ones((1,))
|
|
|
|
with self.assertRaisesRegex(
|
2023-04-06 14:51:30 -07:00
|
|
|
ValueError, 'implies that the global size of its dimension 0 should be '
|
2023-02-23 15:37:13 -08:00
|
|
|
'divisible by 2, but it is equal to 1 '):
|
|
|
|
jax.device_put(x, s)
|
|
|
|
|
|
|
|
y = jnp.ones((2,))
|
|
|
|
with self.assertRaisesRegex(
|
2023-04-06 14:51:30 -07:00
|
|
|
ValueError, 'implies that the global size of its dimension 0 should be '
|
2023-02-23 15:37:13 -08:00
|
|
|
'divisible by 2, but it is equal to 1 '):
|
|
|
|
jax.device_put((y, x), s)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
"The sharded dimension must be equal to the number of "
|
|
|
|
"devices passed to PmapSharding. Got sharded dimension 0 with value 1 "
|
|
|
|
r"in shape \(1,\) and the number of devices=2"):
|
|
|
|
s2 = jax.pmap(lambda x: x,
|
|
|
|
devices=list(mesh.devices.flat))(jnp.arange(2)).sharding
|
|
|
|
jax.device_put(x, s2)
|
|
|
|
|
2023-02-24 12:55:35 -08:00
|
|
|
jax.device_put(2., NamedSharding(mesh, P())) # doesn't crash
|
|
|
|
|
2023-03-06 10:45:02 -08:00
|
|
|
def test_with_sharding_constraint_with_two_meshes(self):
|
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest("Requires more than 4 devices.")
|
|
|
|
|
|
|
|
dev0 = jax.devices()[:2]
|
|
|
|
mesh0 = jax.sharding.Mesh(dev0, ('x'))
|
|
|
|
|
|
|
|
dev1 = jax.devices()[2:4]
|
|
|
|
mesh1 = jax.sharding.Mesh(dev1, ('x'))
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
y = x * 2
|
|
|
|
y = jax.lax.with_sharding_constraint(y, P('x'))
|
|
|
|
return y + 2
|
|
|
|
|
|
|
|
with mesh0:
|
|
|
|
x = np.ones((32, 4))
|
|
|
|
out0 = pjit(f)(x)
|
|
|
|
self.assertListEqual(sorted([d.id for d in out0.devices()]),
|
|
|
|
[d.id for d in dev0])
|
|
|
|
|
|
|
|
with mesh1:
|
|
|
|
x = np.ones((32, 4))
|
|
|
|
out1 = pjit(f)(x)
|
|
|
|
self.assertListEqual(sorted([d.id for d in out1.devices()]),
|
|
|
|
[d.id for d in dev1])
|
|
|
|
|
2023-03-21 08:39:46 -07:00
|
|
|
def test_device_assignment_mismatch_apply_primitive(self):
|
|
|
|
if jax.device_count() < 2:
|
|
|
|
self.skipTest("Requires >=2 devices.")
|
|
|
|
arr = jax.device_put(np.arange(8), jax.devices()[0])
|
|
|
|
arr2 = jax.device_put(np.arange(8), jax.devices()[1])
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
2023-11-29 18:06:36 -08:00
|
|
|
"Received incompatible devices for jitted computation. Got argument.*"
|
|
|
|
r"of concatenate with shape int.*\[8\].*and argument.*"):
|
2023-03-21 08:39:46 -07:00
|
|
|
jnp.concatenate([arr, arr2])
|
|
|
|
|
2023-03-24 21:09:45 -07:00
|
|
|
def test_device_put_grad(self):
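# device_put between meshes inside the computation: the primal output lands on
# mesh2's devices while the gradient lands back on the input mesh's devices.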
|
|
|
|
if jax.device_count() < 8:
|
|
|
|
self.skipTest("Requires >=8 devices.")
|
|
|
|
|
|
|
|
def _test(fun, inp, np_inp, in_s):
|
|
|
|
out = fun(inp)
|
|
|
|
self.assertArraysEqual(out, np.sum(np_inp ** 2 * 3))
|
|
|
|
self.assertArraysEqual(
|
|
|
|
[d.id for d in out.sharding._device_assignment], [4, 5, 6, 7])
|
|
|
|
|
|
|
|
gout = jax.grad(fun)(inp)
|
|
|
|
self.assertTrue(gout.sharding.is_equivalent_to(in_s, gout.ndim))
|
|
|
|
self.assertArraysEqual(
|
|
|
|
[d.id for d in gout.sharding._device_assignment], [0, 1, 2, 3])
|
|
|
|
self.assertArraysEqual(gout, jax.grad(fun)(np_inp))
|
|
|
|
|
|
|
|
mesh1 = jax.sharding.Mesh(jax.devices()[:4], 'x')
|
|
|
|
mesh2 = jax.sharding.Mesh(jax.devices()[4:8], 'x')
|
|
|
|
|
|
|
|
@pjit
|
|
|
|
def stage1(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
@pjit
|
|
|
|
def stage2(x):
|
|
|
|
return x * 3
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
y = stage1(x)
|
|
|
|
y = jax.device_put(y, device=NamedSharding(mesh2, P('x')))
|
|
|
|
z = stage2(y)
|
|
|
|
return jnp.sum(z)
|
|
|
|
|
|
|
|
def g(x):
|
|
|
|
y = stage1(x)
|
|
|
|
y = jax.device_put(y, src=NamedSharding(mesh1, P('x')),
|
|
|
|
device=NamedSharding(mesh2, P('x')))
|
|
|
|
z = stage2(y)
|
|
|
|
return jnp.sum(z)
|
|
|
|
|
|
|
|
np_inp = np.arange(4.)
|
|
|
|
in_s = NamedSharding(mesh1, P('x'))
|
|
|
|
arr = jax.device_put(np_inp, in_s)
|
|
|
|
|
|
|
|
_test(f, arr, np_inp, in_s)
|
|
|
|
|
|
|
|
_test(g, arr, np_inp, in_s)
|
|
|
|
# Test second order autodiff with src argument specified in device_put.
|
|
|
|
jtu.check_grads(g, (arr,), order=2)
|
|
|
|
|
2023-04-09 15:41:32 -07:00
|
|
|
def test_pjit_out_sharding_preserved(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
|
|
|
|
|
|
|
|
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
|
|
|
|
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
|
|
|
|
|
|
|
|
def mul(x):
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
f = pjit(mul, out_shardings=ns)
|
|
|
|
f2 = pjit(mul, out_shardings=ps)
|
|
|
|
|
|
|
|
with jtu.count_pjit_cpp_cache_miss() as count:
|
|
|
|
out = f(arr)
|
|
|
|
cache_info1 = pxla._cached_compilation.cache_info()
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
|
|
|
|
out = f(arr)
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
|
|
|
|
with jtu.count_pjit_cpp_cache_miss() as count:
|
|
|
|
out2 = f2(arr)
|
|
|
|
cache_info2 = pxla._cached_compilation.cache_info()
|
|
|
|
self.assertIsInstance(out2.sharding, PositionalSharding)
|
|
|
|
|
|
|
|
out2 = f2(arr)
|
|
|
|
self.assertIsInstance(out2.sharding, PositionalSharding)
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
|
|
|
|
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
|
|
|
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
|
|
|
|
|
|
|
out3 = jnp.squeeze(arr, axis=-1)
|
|
|
|
cache_info3 = pxla._cached_compilation.cache_info()
|
|
|
|
self.assertIsInstance(out3.sharding, NamedSharding)
|
|
|
|
|
|
|
|
out4 = jnp.squeeze(arr2, axis=-1)
|
|
|
|
cache_info4 = pxla._cached_compilation.cache_info()
|
2023-04-11 16:27:08 -07:00
|
|
|
self.assertIsInstance(out4.sharding, PositionalSharding)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2024-08-09 23:16:54 -07:00
|
|
|
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
|
|
|
|
self.assertEqual(cache_info4.misses, cache_info3.misses)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
|
|
|
def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
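# A committed Array vs a plain numpy input misses the C++ dispatch cache but
# still hits the pjit lowering cache.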
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
np_arr = np.arange(8, dtype=np.float32).reshape(8, 1)
|
|
|
|
arr = jax.device_put(np_arr, ns)
|
|
|
|
|
|
|
|
def mul(x):
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
f = pjit(mul, in_shardings=ns, out_shardings=ns)
|
|
|
|
|
|
|
|
with jtu.count_pjit_cpp_cache_miss() as count:
|
|
|
|
out = f(arr)
|
|
|
|
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
|
|
|
|
out2 = f(np_arr)
|
|
|
|
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
|
|
|
|
self.assertIsInstance(out2.sharding, NamedSharding)
|
|
|
|
|
|
|
|
# Drops out of C++ cache i.e. cache miss
|
|
|
|
self.assertEqual(count[0], 2)
|
|
|
|
# Still gets a hit on pjit_lower cache.
|
|
|
|
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
|
|
|
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
|
|
|
|
2024-05-22 13:55:30 -07:00
|
|
|
def test_list_in_pspec(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2024-05-22 13:55:30 -07:00
|
|
|
with mesh:
|
|
|
|
out = with_sharding_constraint(jnp.arange(8), P(['x']))
|
|
|
|
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
|
|
|
|
2023-04-09 15:41:32 -07:00
|
|
|
def test_sharding_preserved_trivial(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
|
|
|
|
|
|
|
|
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
|
|
|
|
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
|
|
|
|
|
|
|
|
def identity(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
out = pjit(identity)(arr)
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
|
|
|
|
out2 = pjit(identity)(arr2)
|
|
|
|
self.assertIsInstance(out2.sharding, PositionalSharding)
|
|
|
|
|
|
|
|
def test_sharding_preserved_aot(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
|
|
|
|
|
|
|
|
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
|
|
|
|
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
|
|
|
|
|
|
|
|
compiled = pjit(lambda x: x * 2).lower(arr).compile()
|
|
|
|
out = compiled(arr)
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
|
|
|
|
out2 = compiled(arr2)
|
|
|
|
# The sharding won't be PositionalSharding since the pjit was already
# compiled, which bakes in the output sharding.
|
|
|
|
self.assertIsInstance(out2.sharding, NamedSharding)
|
|
|
|
|
|
|
|
def test_sharding_on_output_with_vmap(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
arr = jax.device_put(
|
|
|
|
np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x')))
|
|
|
|
|
2024-08-09 20:03:06 -07:00
|
|
|
with jtu.count_jit_and_pmap_lowerings() as count:
|
2023-05-26 08:56:56 -07:00
|
|
|
vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns))
|
|
|
|
out = vf(arr)
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2023-05-26 08:56:56 -07:00
|
|
|
out2 = vf(out)
|
|
|
|
self.assertIsInstance(out2.sharding, NamedSharding)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2023-05-26 08:56:56 -07:00
|
|
|
out3 = vf(out2)
|
|
|
|
self.assertIsInstance(out3.sharding, NamedSharding)
|
|
|
|
self.assertEqual(count[0], 1)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
|
|
|
def test_jit_mul_sum_sharding_preserved(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
|
|
|
|
|
|
|
|
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
|
|
|
|
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
|
|
|
|
|
|
|
|
f = jax.jit(lambda x: x * 2)
|
|
|
|
out = f(arr)
|
|
|
|
cache_info1 = pxla._cached_compilation.cache_info()
|
|
|
|
pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
|
2023-04-11 16:27:08 -07:00
|
|
|
with jtu.count_pjit_cpp_cache_miss() as count:
|
|
|
|
out2 = f(arr2)
|
|
|
|
cache_info2 = pxla._cached_compilation.cache_info()
|
|
|
|
pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
|
|
|
|
self.assertIsInstance(out2.sharding, PositionalSharding)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2023-04-11 16:27:08 -07:00
|
|
|
# This will hit the cpp cache.
|
|
|
|
out3 = f(out2)
|
|
|
|
self.assertIsInstance(out3.sharding, PositionalSharding)
|
|
|
|
self.assertEqual(count[0], 1)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2024-08-09 23:16:54 -07:00
|
|
|
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
|
|
|
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2023-04-11 16:27:08 -07:00
|
|
|
self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits)
|
2023-04-09 15:41:32 -07:00
|
|
|
self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1)
|
|
|
|
|
|
|
|
out4 = jnp.sum(arr)
|
|
|
|
self.assertIsInstance(out4.sharding, NamedSharding)
|
|
|
|
|
|
|
|
def test_single_device_sharding_preserved(self):
|
|
|
|
if jax.device_count() < 2:
|
|
|
|
self.skipTest('Test requires >=2 devices')
|
|
|
|
|
|
|
|
x = jnp.arange(8)
|
|
|
|
|
|
|
|
# trivial computation
|
|
|
|
out = jax.jit(lambda x: x)(x)
|
|
|
|
self.assertIsInstance(out.sharding, SingleDeviceSharding)
|
|
|
|
|
|
|
|
# trivial computation with committed inp
|
|
|
|
y = jax.device_put(x, jax.devices()[1])
|
|
|
|
out2 = jax.jit(lambda x: x)(y)
|
|
|
|
self.assertIsInstance(out2.sharding, SingleDeviceSharding)
|
2023-11-29 16:52:09 -08:00
|
|
|
self.assertEqual(out2.devices(), {jax.devices()[1]})
|
2023-04-09 15:41:32 -07:00
|
|
|
|
|
|
|
out3 = jax.jit(lambda x: x * 2)(x)
|
|
|
|
self.assertIsInstance(out3.sharding, SingleDeviceSharding)
|
|
|
|
|
|
|
|
out4 = jax.jit(lambda x: x * 3,
|
|
|
|
out_shardings=SingleDeviceSharding(jax.devices()[1]))(x)
|
|
|
|
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
|
2023-11-29 16:52:09 -08:00
|
|
|
self.assertEqual(out4.devices(), {jax.devices()[1]})
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2023-04-10 08:42:18 -07:00
|
|
|
def test_none_out_sharding(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-10 08:42:18 -07:00
|
|
|
x = jnp.arange(8)
|
|
|
|
with mesh:
|
|
|
|
out = pjit(lambda x: x * 2, out_shardings=None)(x)
|
|
|
|
self.assertEqual(out.sharding.mesh, mesh)
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
self.assertEqual(out.sharding.spec, P())
|
|
|
|
|
|
|
|
x2 = jax.device_put(x, NamedSharding(mesh, P()))
|
|
|
|
out2 = pjit(lambda x: x * 2)(x2)
|
|
|
|
self.assertIsInstance(out2.sharding, NamedSharding)
|
|
|
|
self.assertEqual(out2.sharding.mesh, mesh)
|
|
|
|
self.assertEqual(out2.sharding.spec, P())
|
|
|
|
|
2023-04-09 15:41:32 -07:00
|
|
|
def test_sharding_preserved_apply_primitive(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2023-04-09 15:41:32 -07:00
|
|
|
ns = NamedSharding(mesh, P('x'))
|
|
|
|
|
|
|
|
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
|
|
|
|
|
|
|
|
out = jnp.copy(arr)
|
|
|
|
self.assertIsInstance(out.sharding, NamedSharding)
|
|
|
|
|
2023-04-10 12:40:26 -07:00
|
|
|
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
|
|
|
|
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
|
|
|
|
out2 = jnp.copy(arr2)
|
2023-04-11 16:27:08 -07:00
|
|
|
self.assertIsInstance(out2.sharding, PositionalSharding)
|
2023-04-10 12:40:26 -07:00
|
|
|
|
|
|
|
arr3 = jnp.arange(8)
|
|
|
|
out3 = jnp.copy(arr3)
|
|
|
|
self.assertIsInstance(out3.sharding, SingleDeviceSharding)
|
|
|
|
|
|
|
|
arr4 = jax.device_put(jnp.arange(8), jax.devices()[1])
|
|
|
|
out4 = jnp.copy(arr4)
|
|
|
|
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
|
2023-11-29 16:52:09 -08:00
|
|
|
self.assertEqual(out4.devices(), {jax.devices()[1]})
|
2023-04-09 15:41:32 -07:00
|
|
|
|
2023-05-01 11:46:19 -07:00
|
|
|
def test_same_named_sharding_pspec_on_eager_ops(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((1, 8, 1), ('x', 'y', 'z'))
|
2023-05-01 11:46:19 -07:00
|
|
|
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y', 'z'))
|
|
|
|
x = jax.device_put(jnp.arange(32).reshape(1, -1, 1), sharding)
|
|
|
|
y = x + 1
|
|
|
|
self.assertEqual(x.sharding, y.sharding)
|
|
|
|
|
2023-05-01 17:39:16 -07:00
|
|
|
def test_different_named_sharding_object_replicated(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((1, 2), ('x', 'y'))
|
2023-05-01 17:39:16 -07:00
|
|
|
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
|
|
|
x = jax.device_put(np.arange(16).reshape(8, 2), sharding)
|
|
|
|
y = jnp.sum(x)
|
|
|
|
self.assertNotEqual(x.sharding, y.sharding)
|
|
|
|
|
2023-05-03 11:54:46 -07:00
|
|
|
def test_vmap_pjit_single_device(self):
|
2024-06-12 14:43:14 -07:00
|
|
|
with jtu.ignore_warning(category=DeprecationWarning,
|
|
|
|
message="backend and device argument"):
|
|
|
|
jf = pjit(lambda x: x, device=jax.devices()[0])
|
2023-05-03 11:54:46 -07:00
|
|
|
out = jax.vmap(jf)(jnp.ones((3,))) # doesn't crash
|
|
|
|
self.assertIsInstance(out.sharding, SingleDeviceSharding)
|
|
|
|
|
2023-05-09 14:23:49 -07:00
|
|
|
def test_to_gspmd_sharding_cache_with_and_without_device(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2023-05-09 14:23:49 -07:00
|
|
|
np_inp = jnp.arange(4)
|
|
|
|
|
|
|
|
def identity(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
# Fill up the to_gspmd_sharding cache so that the next jit will miss it.
|
|
|
|
out = jax.jit(identity,
|
|
|
|
in_shardings=SingleDeviceSharding(jax.devices()[0]))(np_inp)
|
2023-11-29 16:52:09 -08:00
|
|
|
self.assertEqual(out.devices(), {jax.devices()[0]})
|
2023-05-09 14:23:49 -07:00
|
|
|
self.assertArraysEqual(out, np_inp)
|
|
|
|
|
2024-06-12 14:43:14 -07:00
|
|
|
with jtu.ignore_warning(category=DeprecationWarning,
|
|
|
|
message="backend and device argument"):
|
|
|
|
out2 = jax.jit(identity, device=jax.devices()[0])(
|
|
|
|
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
|
2023-11-29 16:52:09 -08:00
|
|
|
self.assertEqual(out2.devices(), {jax.devices()[0]})
|
2023-05-09 14:23:49 -07:00
|
|
|
self.assertArraysEqual(out2, np_inp)
|
|
|
|
|
2023-10-30 15:27:17 -07:00
|
|
|
def test_jit_submhlo_cached(self):
|
|
|
|
@jax.jit
|
|
|
|
def nest(x):
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def top(x):
|
|
|
|
y = nest(x)
|
|
|
|
z = nest(y)
|
|
|
|
a = nest(z)
|
|
|
|
b = nest(a)
|
|
|
|
return b
|
|
|
|
|
2024-01-15 02:12:52 -08:00
|
|
|
with jtu.count_subjaxpr_to_hlo_conversion(fun_name='nest') as count:
|
2023-10-30 15:27:17 -07:00
|
|
|
top(jnp.arange(8))
|
|
|
|
|
|
|
|
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
|
2023-05-17 11:49:31 -07:00
|
|
|
def test_wsc_eager(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2023-05-17 11:49:31 -07:00
|
|
|
np_inp = np.arange(8)
|
|
|
|
inp = jax.device_put(np_inp, NamedSharding(mesh, P()))
|
|
|
|
out = with_sharding_constraint(inp, NamedSharding(mesh, P('x')))
|
|
|
|
self.assertArraysEqual(out, np_inp)
|
|
|
|
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
|
|
|
for s in out.addressable_shards:
|
|
|
|
self.assertArraysEqual(s.data, np_inp[s.index])
|
|
|
|
|
|
|
|
def test_wsc_eager_no_resharding(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2023-05-17 11:49:31 -07:00
|
|
|
np_inp = np.arange(8)
|
|
|
|
inp = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
|
|
|
out = with_sharding_constraint(inp, NamedSharding(mesh, P('x')))
|
|
|
|
self.assertEqual(id(out), id(inp))
|
|
|
|
|
|
|
|
def test_wsc_eager_different_order_devices(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh1 = jtu.create_mesh((2,), ('x',))
|
2023-05-17 11:49:31 -07:00
|
|
|
mesh2 = jax.sharding.Mesh([jax.devices()[1], jax.devices()[0]], 'x')
|
|
|
|
inp = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Received incompatible devices for jitted computation"):
|
|
|
|
with_sharding_constraint(inp, NamedSharding(mesh2, P('x')))
|
|
|
|
|
2023-05-26 08:56:56 -07:00
|
|
|
def test_jaxpr_as_fun_fast_path(self):
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
return x * 2
|
|
|
|
inp = jax.device_put(jnp.arange(8), jax.devices()[0])
|
|
|
|
jaxpr = jax.make_jaxpr(f)(inp)
|
|
|
|
|
|
|
|
with jtu.count_pjit_cpp_cache_miss() as count:
|
|
|
|
out1 = core.jaxpr_as_fun(jaxpr)(inp)
|
|
|
|
out2 = core.jaxpr_as_fun(jaxpr)(inp)
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
self.assertArraysEqual(out1[0], inp * 2)
|
|
|
|
self.assertArraysEqual(out2[0], inp * 2)
|
|
|
|
|
2023-06-05 10:06:30 -07:00
|
|
|
def test_most_recent_executable_outer_inner_cache(self):
|
|
|
|
x = np.zeros((20, 20), dtype=jnp.float64)
|
|
|
|
|
|
|
|
def trace_to_jaxpr(x):
|
|
|
|
jnp.pad(x, [(0, 1), (0, 0)], mode='wrap')
jnp.pad(x, [(0, 0), (1, 0)], mode='constant',
        constant_values=((0.0, 0.0), (0.0, 0.0)))
|
|
|
|
|
|
|
|
jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x)
|
|
|
|
jax.core.jaxpr_as_fun(jaxpr)(x)
|
|
|
|
|
|
|
|
jnp.pad(x, [(0, 1), (0, 0)], mode='wrap')
jnp.pad(x, [(0, 1), (0, 0)], mode='wrap')  # doesn't crash
|
|
|
|
|
2023-06-14 09:39:54 -07:00
|
|
|
def test_shape_dtype_struct_as_const_error(self):
|
|
|
|
const = jax.ShapeDtypeStruct((8,), jnp.int32)
|
|
|
|
with self.assertRaisesRegex(TypeError,
|
|
|
|
r"Argument.*is not a valid JAX type"):
|
|
|
|
jax.jit(lambda x: (x, const))(jnp.arange(8))
|
|
|
|
|
2023-06-15 15:21:36 -07:00
|
|
|
def test_jit_out_shardings_none(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2023-06-15 15:21:36 -07:00
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
inp = jax.device_put(np_inp, s)
|
|
|
|
out = jax.jit(lambda x: x * 2, out_shardings=None)(inp)
|
|
|
|
self.assertArraysEqual(out, np_inp * 2)
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
|
|
|
|
def test_jit_in_shardings_none(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2023-06-15 15:21:36 -07:00
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
inp = jax.device_put(np_inp, s)
|
|
|
|
|
|
|
|
out = jax.jit(lambda x: x * 2, in_shardings=None)(inp)
|
|
|
|
self.assertArraysEqual(out, np_inp * 2)
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
|
|
|
|
out2 = jax.jit(lambda x: x * 2, in_shardings=None)(np_inp)
|
|
|
|
self.assertArraysEqual(out2, np_inp * 2)
|
|
|
|
self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))
|
|
|
|
|
2024-08-08 11:23:50 -07:00
|
|
|
def test_device_put_in_jit_default_mem_kind_no_op(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), 'x')
|
2024-08-08 11:23:50 -07:00
|
|
|
np_inp = np.arange(8)
|
|
|
|
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
y = x * 2
|
|
|
|
return jax.device_put(y, NamedSharding(mesh, P()))
|
|
|
|
|
|
|
|
lowered_text = f.lower(arr).as_text()
|
|
|
|
self.assertNotIn('@Sharding', lowered_text)
|
|
|
|
self.assertNotIn('@annotate_device_placement', lowered_text)
|
|
|
|
|
2023-06-15 15:21:36 -07:00
|
|
|
def test_jit_both_shardings_none(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2023-06-15 15:21:36 -07:00
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
inp = jax.device_put(np_inp, s)
|
|
|
|
|
|
|
|
out = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(inp)
|
|
|
|
self.assertArraysEqual(out, np_inp * 2)
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
|
|
|
|
out2 = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(np_inp)
|
|
|
|
self.assertArraysEqual(out2, np_inp * 2)
|
|
|
|
self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))
|
|
|
|
|
2023-06-26 21:46:02 -07:00
|
|
|
def test_jit_lower_shape_dtype_struct_sharding_none(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2023-06-26 21:46:02 -07:00
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
|
|
|
|
lower_inp1 = jax.ShapeDtypeStruct((8, 2), np.int32, sharding=s)
|
|
|
|
# Will be considered uncommitted and resharded over all the devices of
# the mesh.
|
|
|
|
lower_inp2 = jax.ShapeDtypeStruct((8, 2), np.int32)
|
|
|
|
|
|
|
|
compiled = jax.jit(lambda x, y: (x * 2, y * 2)).lower(
|
|
|
|
lower_inp1, lower_inp2).compile()
|
|
|
|
|
|
|
|
np_inp = np.arange(16, dtype=np.int32).reshape(8, 2)
|
|
|
|
inp = jax.device_put(np_inp, s)
|
|
|
|
out1, out2 = compiled(inp, np_inp)
|
|
|
|
|
|
|
|
self.assertArraysEqual(out1, np_inp * 2)
|
|
|
|
self.assertArraysEqual(out2, np_inp * 2)
|
|
|
|
self.assertTupleEqual(out1.sharding._device_assignment,
|
|
|
|
s.mesh._flat_devices_tuple)
|
|
|
|
self.assertTupleEqual(out2.sharding._device_assignment,
|
|
|
|
s.mesh._flat_devices_tuple)
|
|
|
|
|
2023-08-29 20:58:20 -07:00
|
|
|
def test_vmap_spmd_axis_name_error(self):
|
|
|
|
s = SingleDeviceSharding(jax.devices()[0])
|
|
|
|
|
|
|
|
def f(inp):
|
|
|
|
return with_sharding_constraint(inp, s)
|
|
|
|
|
|
|
|
arr = jax.device_put(np.arange(8), s)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
2024-07-24 19:01:31 -07:00
|
|
|
'If you are using spmd_axis_name parameter of jax.vmap, please'
|
2023-08-29 20:58:20 -07:00
|
|
|
' make sure to run your jitted function inside the mesh context'
|
|
|
|
' manager.*SingleDeviceSharding'):
|
|
|
|
jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr)
|
|
|
|
|
2023-11-17 20:48:22 -08:00
|
|
|
def test_no_output_multiple_devices(self):
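# A jitted function that returns nothing should still run under a
# multi-device mesh.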
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2023-11-17 20:48:22 -08:00
|
|
|
|
|
|
|
@pjit
|
|
|
|
def f():
|
|
|
|
return
|
|
|
|
|
|
|
|
with mesh:
|
|
|
|
f() # doesn't crash
|
|
|
|
|
2023-12-08 14:35:27 -08:00
|
|
|
def test_lowering_cache_hit_different_devices(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest('b/358322664: different axis names result in '
              'a cache miss with Shardy.')
|
2023-12-08 14:35:27 -08:00
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Requires >=4 devices')
|
|
|
|
|
|
|
|
mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
|
2024-08-09 23:16:54 -07:00
|
|
|
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
|
2023-12-08 14:35:27 -08:00
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
def g(a):
|
|
|
|
a = jax.device_put(a, NamedSharding(mesh1, P('x')))
|
|
|
|
out_a = f(a) # lowering cached
|
|
|
|
|
|
|
|
# same num_devices but different devices.
|
2024-08-09 23:16:54 -07:00
|
|
|
b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
|
2023-12-08 14:35:27 -08:00
|
|
|
f(b) # lowering cache *hit*
|
|
|
|
|
2024-08-09 20:03:06 -07:00
|
|
|
with jtu.count_jit_and_pmap_lowerings() as count:
|
2023-12-14 09:13:43 -08:00
|
|
|
g(np.arange(8))
|
|
|
|
self.assertEqual(count[0], 1)
|
2023-12-08 14:35:27 -08:00
|
|
|
|
|
|
|
def test_lowering_cache_miss_different_devices_and_sharding(self):
|
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Requires >=4 devices')
|
|
|
|
|
|
|
|
mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
|
|
|
|
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
def g(a):
|
|
|
|
a = jax.device_put(a, NamedSharding(mesh1, P('x')))
|
|
|
|
out_a = f(a) # lowering cached
|
|
|
|
|
|
|
|
# same num_devices but different devices and sharding
|
|
|
|
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
|
|
|
f(b) # lowering cache *miss*
|
|
|
|
|
2024-08-09 20:03:06 -07:00
|
|
|
with jtu.count_jit_and_pmap_lowerings() as count:
|
2023-12-08 14:35:27 -08:00
|
|
|
g(np.arange(8))
|
|
|
|
self.assertEqual(count[0], 2)
|
|
|
|
|
2024-01-23 21:28:33 -08:00
|
|
|
def test_single_device_named_sharding_preserved(self):
|
|
|
|
mesh = jax.sharding.Mesh([jax.devices()[0]], 'x')
|
|
|
|
s = NamedSharding(mesh, P('x'))
|
|
|
|
np_inp = np.arange(8)
|
|
|
|
inp = jax.device_put(np_inp, s)
|
|
|
|
|
|
|
|
out = jax.jit(lambda x: x)(inp)
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
self.assertArraysEqual(out, np_inp)
|
|
|
|
|
2024-02-23 10:02:19 -08:00
|
|
|
def test_mpmd_device_put_fast_path(self):
|
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Needs >= 4 devices')
|
|
|
|
|
|
|
|
dev_count = jax.device_count()
|
|
|
|
mesh1 = jax.sharding.Mesh(jax.devices()[:dev_count//2], 'x')
|
|
|
|
mesh2 = jax.sharding.Mesh(jax.devices()[dev_count//2:], 'x')
|
|
|
|
inp = np.arange(8)
|
|
|
|
arr1 = jax.device_put(inp, NamedSharding(mesh1, P('x')))
|
|
|
|
|
|
|
|
# This guards against changes to Array's shard_arg_handler, which checks
# indices to take the fast path for resharding. If the handler is changed to
# check shardings instead of indices, this test will fail, and that is
# expected.
|
|
|
|
with jtu.count_device_put_fast_path_hit() as count:
|
|
|
|
out = jax.device_put(arr1, NamedSharding(mesh2, P('x')))
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
self.assertTupleEqual(out.sharding._device_assignment,
|
|
|
|
mesh2._flat_devices_tuple)
|
|
|
|
self.assertArraysEqual(out, inp)
|
|
|
|
|
2024-02-28 14:36:20 -08:00
|
|
|
def test_prng_sharding_propagation(self):
|
|
|
|
input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2024-02-28 14:36:20 -08:00
|
|
|
spec = P('x', 'y')
|
|
|
|
|
|
|
|
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def make_keys(seeds):
|
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
|
|
|
key = make_key(seeds)
|
|
|
|
return key.T
|
|
|
|
|
|
|
|
out = make_keys(seeds)
|
|
|
|
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
|
|
|
|
|
|
|
|
base_array = jax.random.key_data(out)
|
|
|
|
self.assertEqual(base_array.shape, (2, 8, 2))
|
|
|
|
self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None)))
|
|
|
|
|
|
|
|
lowered_text = make_keys.lower(seeds).as_text()
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
|
|
|
|
else:
|
|
|
|
self.assertIn('unspecified_dims=[0,1]', lowered_text)
|
2024-02-28 14:36:20 -08:00
|
|
|
|
|
|
|
def test_prng_sharding_propagation_with_nested_jit(self):
|
|
|
|
input_shape = (8, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
2024-02-28 14:36:20 -08:00
|
|
|
spec = P('x', 'y')
|
|
|
|
|
|
|
|
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def make_keys(seeds):
|
|
|
|
@partial(jax.jit, out_shardings=NamedSharding(mesh, P('y')))
|
|
|
|
def f():
|
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
|
|
|
return make_key(seeds)
|
|
|
|
x = f()
|
|
|
|
return x.T
|
|
|
|
|
|
|
|
out = make_keys(seeds)
|
|
|
|
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y')))
|
|
|
|
|
|
|
|
base_array = jax.random.key_data(out)
|
|
|
|
self.assertEqual(base_array.shape, (2, 8, 2))
|
|
|
|
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None)))
|
|
|
|
|
|
|
|
lowered_text = make_keys.lower(seeds).as_text()
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
|
|
|
|
else:
|
|
|
|
self.assertIn('unspecified_dims=[0,1]', lowered_text)
|
2024-02-28 14:36:20 -08:00
|
|
|
|
2024-03-06 11:41:34 -08:00
|
|
|
def test_partial_sharded_prng_key_inp(self):
|
|
|
|
input_shape = (8, 2, 2)
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
2024-03-06 11:41:34 -08:00
|
|
|
spec = P('x', 'y', None)
|
|
|
|
|
|
|
|
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def make_keys(seeds):
|
|
|
|
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
|
|
|
key = make_key(seeds)
|
|
|
|
return key.T
|
|
|
|
|
|
|
|
make_keys(seeds)
|
|
|
|
out = make_keys(seeds) # cpp dispatch
|
|
|
|
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
|
|
|
|
|
|
|
|
base_array = jax.random.key_data(out)
|
|
|
|
self.assertEqual(base_array.shape, (2, 2, 8, 2))
|
|
|
|
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
|
|
|
|
|
|
|
|
lowered_text = make_keys.lower(seeds).as_text()
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text)
|
|
|
|
else:
|
|
|
|
self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
|
2024-03-06 11:41:34 -08:00
|
|
|
|
2024-03-04 08:50:58 -08:00
|
|
|
def test_jit_partially_specified_shardings(self):
|
|
|
|
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
2024-03-04 08:50:58 -08:00
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
s2 = NamedSharding(mesh, P('x'))
|
|
|
|
arr = jax.device_put(np_inp, s)
|
|
|
|
arr2 = jax.device_put(np_inp, s2)
|
|
|
|
|
|
|
|
@partial(jax.jit, in_shardings=(s, None, s2, UNSPECIFIED, UNSPECIFIED),
|
|
|
|
out_shardings=(s2, None, None, s, None))
|
|
|
|
def f(x, y, z, a, b):
|
|
|
|
return x * 2, y @ y.T, z ** 2, a * 3, b.T
|
|
|
|
|
|
|
|
out1, out2, out3, out4, out5 = f(arr, np_inp, arr2, np_inp, arr)
|
|
|
|
self.assertArraysEqual(out1, np_inp * 2)
|
|
|
|
self.assertArraysEqual(out2, np_inp @ np_inp.T)
|
|
|
|
self.assertArraysEqual(out3, np_inp ** 2)
|
|
|
|
self.assertArraysEqual(out4, np_inp * 3)
|
|
|
|
self.assertArraysEqual(out5, np_inp.T)
|
|
|
|
|
2024-04-12 17:52:08 -07:00
|
|
|
def test_input_shardings_aot(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2024-04-12 17:52:08 -07:00
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x, y):
|
|
|
|
return x * 2, y.T
|
|
|
|
|
|
|
|
arg_shardings, _ = f.lower(arr, np_inp).compile().input_shardings
|
|
|
|
for s in arg_shardings:
|
|
|
|
self.assertIsInstance(s, NamedSharding)
|
|
|
|
|
2024-03-14 15:09:07 -07:00
|
|
|
def test_parameter_tupled_jit(self):
|
|
|
|
if not jtu.test_device_matches(["tpu"]):
|
|
|
|
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
|
|
|
|
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
2024-03-14 15:09:07 -07:00
|
|
|
s = NamedSharding(mesh, P('x'))
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(*args):
|
|
|
|
return args * 2
|
|
|
|
|
|
|
|
inp = np.arange(8)
|
|
|
|
arr = jax.device_put(inp, s)
|
|
|
|
inps = [arr, *[inp] * 2001]
|
|
|
|
f(inps) # doesn't crash
|
|
|
|
|
2024-04-14 10:30:10 -07:00
|
|
|
def test_spmd_preserves_input_sharding_vmap_grad(self):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest("Shardy doesn't support PositionalSharding")
|
2024-04-14 10:30:10 -07:00
|
|
|
# https://github.com/google/jax/issues/20710
|
|
|
|
n_devices = jax.device_count()
|
|
|
|
sharding = PositionalSharding(jax.devices())
|
|
|
|
|
|
|
|
def model(params, x):
|
|
|
|
return x @ params
|
|
|
|
|
|
|
|
feature_dim = 3
|
|
|
|
batch_size_total = 8
|
|
|
|
|
|
|
|
# Get example data
|
|
|
|
x = jnp.ones((batch_size_total, feature_dim))
|
|
|
|
params = jnp.ones(feature_dim)
|
|
|
|
|
|
|
|
# Shard data, replicate params
|
|
|
|
x = jax.device_put(x, sharding.reshape(n_devices, 1))
|
|
|
|
params = jax.device_put(params, sharding.replicate(axis=0))
|
|
|
|
|
|
|
|
model(params, x) # doesn't crash
|
|
|
|
|
|
|
|
jax.vmap(model, in_axes=(None, 0))(params, x) # doesn't crash
|
|
|
|
|
|
|
|
jax.grad(lambda p: model(p, x).sum())(params) # doesn't crash
|
|
|
|
|
|
|
|
jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x) # doesn't crash
|
|
|
|
|
2024-05-10 10:11:55 -07:00
|
|
|
def test_jit_token_input(self):
|
|
|
|
x = jnp.arange(8)
|
|
|
|
token = jax.lax.create_token(None)
|
|
|
|
device = jax.devices()[0]
|
|
|
|
x = jax.device_put(x, device=device)
|
|
|
|
out1, out2 = jax.jit(lambda x, t: (x, t))(x, token)
|
|
|
|
self.assertArraysEqual(out1, x)
|
|
|
|
self.assertIsInstance(out2, core.Token)
|
|
|
|
|
2024-05-16 07:47:02 -07:00
|
|
|
def test_uneven_sharding_wsc(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh(
|
2024-05-16 07:47:02 -07:00
|
|
|
(2, 1, 1, 1, 1), ('data', 'expert', 'fsdp', 'seq', 'model')
|
|
|
|
)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def fn(key):
|
|
|
|
x = jnp.arange(113003)
|
|
|
|
x = with_sharding_constraint(x, P('data'))
|
|
|
|
y = jnp.arange(65536)
|
|
|
|
y = with_sharding_constraint(y.reshape(-1), P('data'))
|
|
|
|
z = jnp.concatenate([x, y], axis=0)
|
|
|
|
z = with_sharding_constraint(z, P('data'))
|
|
|
|
return x, y, z
|
|
|
|
|
|
|
|
with mesh:
|
|
|
|
x, y, z = fn(jax.random.key(42))
|
|
|
|
|
|
|
|
expected_x = np.arange(113003)
|
|
|
|
expected_y = np.arange(65536)
|
|
|
|
expected_z = np.concatenate([x, y], axis=0)
|
|
|
|
|
|
|
|
self.assertArraysEqual(expected_x.max(), x.max())
|
|
|
|
self.assertArraysEqual(expected_y.max(), y.max())
|
|
|
|
self.assertArraysEqual(expected_z.max(), z.max())
|
|
|
|
|
2024-05-24 09:14:43 -07:00
|
|
|
def test_threefry_partitionable_context_within_jit(self):
|
|
|
|
with jax.threefry_partitionable(False):
|
|
|
|
def f(x):
|
|
|
|
return x + jax.random.randint(jax.random.key(72), (), 0, 10)
|
|
|
|
|
|
|
|
def g(x):
|
|
|
|
with jax.threefry_partitionable(True): # False by default
|
|
|
|
return x + jax.random.randint(jax.random.key(72), (), 0, 10)
|
|
|
|
|
|
|
|
h = jax.jit(g)
|
|
|
|
|
|
|
|
self.assertNotEqual(f(1), g(1))
|
|
|
|
self.assertEqual(g(1), h(1))
|
|
|
|
|
2024-05-30 17:42:14 -07:00
|
|
|
def test_wsc_vmap_unconstrained_spmd_axis_name(self):
|
|
|
|
def get_wsc_eqn_sharding(jaxpr):
|
|
|
|
for eqn in jaxpr.eqns:
|
|
|
|
if str(eqn.primitive) == 'sharding_constraint':
|
|
|
|
return eqn.params['sharding'], eqn.params['unconstrained_dims']
|
|
|
|
for s in core.subjaxprs(jaxpr):
|
|
|
|
return get_wsc_eqn_sharding(s)
|
|
|
|
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
2024-05-30 17:42:14 -07:00
|
|
|
inp = jnp.ones((10, 10))
|
|
|
|
|
|
|
|
def a_function(x):
|
|
|
|
return with_sharding_constraint(x, NamedSharding(mesh, P(P.UNCONSTRAINED)))
|
|
|
|
|
|
|
|
def vmap_the_function_spmd(y):
|
|
|
|
return jax.vmap(a_function, spmd_axis_name='x')(y)
|
|
|
|
|
|
|
|
f1 = jax.jit(vmap_the_function_spmd)
|
|
|
|
f1(inp) # doesn't crash
|
|
|
|
jaxpr1 = jax.make_jaxpr(f1)(inp)
|
|
|
|
s1, u1 = get_wsc_eqn_sharding(jaxpr1)
|
|
|
|
self.assertEqual(s1.spec, P('x', P.UNCONSTRAINED))
|
|
|
|
self.assertEqual(u1, {1})
|
|
|
|
|
|
|
|
def vmap_the_function_no_spmd(y):
|
|
|
|
return jax.vmap(a_function)(y)
|
|
|
|
|
|
|
|
f2 = jax.jit(vmap_the_function_no_spmd)
|
|
|
|
f2(inp) # doesn't crash
|
|
|
|
jaxpr2 = jax.make_jaxpr(f2)(inp)
|
|
|
|
s2, u2 = get_wsc_eqn_sharding(jaxpr2)
|
|
|
|
self.assertEqual(s2.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
|
|
|
|
self.assertEqual(u2, {0, 1})
|
|
|
|
|
2024-06-03 11:44:05 -07:00
|
|
|
def test_aot_sharding_dce(self):
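# The unused argument is not dropped from the compiled input shardings.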
|
|
|
|
inp = np.arange(8)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x, y):
|
|
|
|
return x
|
|
|
|
|
|
|
|
input_shardings, _ = f.lower(inp, inp).compile().input_shardings
|
|
|
|
self.assertLen(input_shardings, 2)
|
|
|
|
|
2024-06-04 18:50:02 -07:00
|
|
|
def test_aot_out_info(self):
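# out_info from lower() carries per-output shape, dtype, and sharding
# (None here, since the inputs carry no sharding).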
|
|
|
|
inp = np.arange(8, dtype=np.int32)
|
|
|
|
out_info = jax.jit(lambda x: x).lower((inp, inp)).out_info
|
|
|
|
self.assertEqual(out_info[0].shape, (8,))
|
|
|
|
self.assertEqual(out_info[1].shape, (8,))
|
|
|
|
self.assertEqual(out_info[0].dtype, np.int32)
|
|
|
|
self.assertEqual(out_info[1].dtype, np.int32)
|
|
|
|
self.assertEqual(out_info[0].sharding, None)
|
|
|
|
self.assertEqual(out_info[1].sharding, None)
|
|
|
|
|
2024-06-06 17:42:25 -07:00
|
|
|
def test_jit_trace(self):
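# jit(f).trace(...) exposes the jaxpr, out_info, and in_avals of the traced call.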
|
2024-06-05 17:45:34 -07:00
|
|
|
def f(x):
|
|
|
|
return x * 2
|
|
|
|
|
2024-06-06 17:42:25 -07:00
|
|
|
traced = jax.jit(f).trace(jnp.arange(8, dtype=jnp.int32))
|
|
|
|
self.assertLen(traced.jaxpr.eqns, 1)
|
|
|
|
self.assertEqual(jax.tree.structure(traced.out_info).num_leaves, 1)
|
|
|
|
self.assertEqual(traced.out_info.shape, (8,))
|
|
|
|
self.assertEqual(traced.out_info.dtype, jnp.int32)
|
2024-06-05 17:45:34 -07:00
|
|
|
# one for args, one for kwargs (though kwargs is empty)
|
2024-06-06 17:42:25 -07:00
|
|
|
self.assertLen(traced.in_avals, 2)
|
|
|
|
self.assertLen(traced.in_avals[0], 1)
|
|
|
|
self.assertLen(traced.in_avals[1], 0) # empty kwarg
|
2024-06-05 17:45:34 -07:00
|
|
|
|
2024-06-06 17:42:25 -07:00
|
|
|
def test_jit_trace_lower_and_compile(self):
|
2024-06-05 19:54:08 -07:00
|
|
|
def f(x):
|
|
|
|
return x * 2
|
|
|
|
|
2024-06-06 17:42:25 -07:00
|
|
|
lowered = jax.jit(f).trace(jnp.arange(8)).lower()
|
2024-06-05 19:54:08 -07:00
|
|
|
self.assertEqual(lowered.args_info[0][0].shape, (8,))
|
|
|
|
|
|
|
|
compiled = lowered.compile()
|
|
|
|
out = compiled(jnp.arange(8))
|
|
|
|
self.assertArraysEqual(out, np.arange(8) * 2)
|
|
|
|
|
|
|
|
# fast-forward
|
|
|
|
lowered2 = jax.jit(f).lower(jnp.arange(8))
|
|
|
|
self.assertEqual(lowered2.args_info[0][0].shape, (8,))
|
|
|
|
|
|
|
|
compiled2 = lowered2.compile()
|
|
|
|
out2 = compiled2(jnp.arange(8))
|
|
|
|
self.assertArraysEqual(out2, np.arange(8) * 2)
|
|
|
|
|
2024-07-01 13:13:53 -07:00
|
|
|
def test_device_put_efficient_reshard_single_host(self):
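# Resharding between meshes over the same devices in a different order should
# not require an explicit transfer.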
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest(
|
|
|
|
'_different_device_order_reshard is creating a GSPMDSharding')
|
2024-07-01 13:13:53 -07:00
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Requires >= 4 devices')
|
|
|
|
|
|
|
|
dev = jax.devices()
|
|
|
|
mesh1 = Mesh(np.array([dev[0], dev[1], dev[2], dev[3]]).reshape(2, 2),
|
|
|
|
('x', 'y'))
|
|
|
|
mesh2 = Mesh(np.array([dev[3], dev[2], dev[1], dev[0]]).reshape(2, 2),
|
|
|
|
('x', 'y'))
|
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
s1 = NamedSharding(mesh1, P('x', 'y'))
|
|
|
|
s2 = NamedSharding(mesh2, P('x'))
|
|
|
|
|
|
|
|
x_s1 = jax.device_put(np_inp, s1)
|
|
|
|
|
|
|
|
with jax.transfer_guard('disallow_explicit'):
|
|
|
|
out = jax.device_put(x_s1, s2)
|
|
|
|
self.assertArraysEqual(out, np_inp)
|
|
|
|
self.assertEqual(out.sharding, s2)
|
|
|
|
|
2024-07-20 09:08:16 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
("8_2", (8, 2)),
|
|
|
|
("8_384", (8, 384)),
|
|
|
|
)
|
|
|
|
def test_device_put_efficient_reshard_complex_mesh(self, shape):
|
2024-08-23 06:50:14 -07:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest(
|
|
|
|
'_different_device_order_reshard is creating a GSPMDSharding')
|
2024-07-20 09:08:16 -07:00
|
|
|
if jax.device_count() < 8:
|
|
|
|
self.skipTest('Requires >= 8 devices')
|
|
|
|
|
|
|
|
dev = jax.devices()
|
|
|
|
mesh1 = jax.sharding.Mesh(
|
|
|
|
np.asarray(dev).reshape([1, 2, 2, 2]),
|
|
|
|
('replica', 'data', 'seq', 'model'))
|
|
|
|
mesh2 = jax.sharding.Mesh(
|
|
|
|
np.asarray(jax.devices())
|
|
|
|
.reshape([1, 1, 2, 2, 2, 1])
|
|
|
|
.swapaxes(2, 3)
|
|
|
|
.reshape([1, 1, 4, 2, 1]),
|
|
|
|
('replica', 'data', 'seq', 'model_q', 'model_kv'))
|
|
|
|
|
|
|
|
np_inp = jnp.arange(math.prod(shape)).reshape(shape)
|
|
|
|
s1 = NamedSharding(mesh1, P('model'))
|
|
|
|
s2 = NamedSharding(mesh2, P())
|
|
|
|
|
|
|
|
x_s1 = jax.device_put(np_inp, s1)
|
2024-07-22 13:55:55 -07:00
|
|
|
# Reshard!
|
2024-07-20 09:08:16 -07:00
|
|
|
out = jax.device_put(x_s1, s2)
|
|
|
|
self.assertArraysEqual(out, np_inp)
|
|
|
|
self.assertEqual(out.sharding, s2)
|
2024-07-22 13:55:55 -07:00
|
|
|
del out
|
|
|
|
|
|
|
|
s3 = NamedSharding(mesh2, P('model_q'))
|
|
|
|
x_s3 = jax.device_put(np_inp, s3)
|
|
|
|
# Reshard to iota device assignment!
|
|
|
|
out2 = jax.device_put(x_s3, s1)
|
|
|
|
self.assertArraysEqual(out2, np_inp)
|
|
|
|
self.assertEqual(out2.sharding, s1)
|
2024-07-20 09:08:16 -07:00
|
|
|
|
2024-07-09 07:32:38 -07:00
|
|
|
def test_convert_element_type_sharding(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
2024-07-09 07:32:38 -07:00
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
inp = np.arange(16).reshape(8, 2)
|
|
|
|
|
|
|
|
out = lax_internal._convert_element_type(
|
|
|
|
inp, new_dtype=np.float32, weak_type=False, sharding=s)
|
|
|
|
self.assertArraysEqual(out, inp.astype('float32'))
|
|
|
|
self.assertEqual(out.dtype, np.float32)
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
|
|
|
|
def test_jnp_array_sharding(self):
|
2024-09-03 14:30:37 -07:00
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Requires >=4 devices')
|
|
|
|
mesh = jax.make_mesh((2, 2), ('x', 'y'), devices=jax.devices()[:4])
|
2024-07-09 07:32:38 -07:00
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
inp = np.arange(16).reshape(8, 2)
|
|
|
|
|
|
|
|
out = jnp.array(inp, device=s)
|
|
|
|
self.assertArraysEqual(out, inp)
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
|
|
|
|
def test_jnp_array_inside_jit_sharding(self):
|
2024-09-03 14:30:37 -07:00
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Requires >=4 devices')
|
|
|
|
mesh = jax.make_mesh((2, 2), ('x', 'y'))
|
2024-07-09 07:32:38 -07:00
|
|
|
s = NamedSharding(mesh, P('x', 'y'))
|
|
|
|
inp = np.arange(16).reshape(8, 2)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f():
|
|
|
|
return jnp.array(inp, dtype=np.float32, device=s)
|
|
|
|
|
|
|
|
out = f()
|
|
|
|
|
|
|
|
self.assertArraysEqual(out, inp.astype('float32'))
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
self.assertEqual(out.dtype, np.float32)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def g(x):
|
|
|
|
return jnp.array(x, dtype=np.float32, device=s)
|
|
|
|
|
|
|
|
out2 = g(inp)
|
|
|
|
self.assertArraysEqual(out2, inp.astype('float32'))
|
|
|
|
self.assertEqual(out2.sharding, s)
|
|
|
|
self.assertEqual(out2.dtype, np.float32)
|
|
|
|
|
|
|
|
def test_jnp_array_reshard_error(self):
|
|
|
|
if jax.device_count() < 2:
|
|
|
|
self.skipTest('Requires >=2 devices')
|
|
|
|
arr = jax.device_put(np.arange(8), jax.devices()[0])
|
|
|
|
with self.assertRaisesRegex(ValueError, "Received incompatible devices.*"):
|
|
|
|
jnp.array(arr, device=jax.devices()[1])
|
|
|
|
|
|
|
|
def test_jnp_array_sharded_array_no_op(self):
|
|
|
|
inp = np.arange(16).reshape(8, 2)
|
|
|
|
arr = jax.device_put(inp, jax.devices()[0])
|
|
|
|
|
|
|
|
out = lax_internal._convert_element_type(
|
|
|
|
arr, sharding=SingleDeviceSharding(jax.devices()[0]))
|
|
|
|
self.assertArraysEqual(out, inp)
|
|
|
|
self.assertEqual(out.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer())
|
|
|
|
|
|
|
|
def test_wsc_named_sharding_nullary(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2,), ('x',))
|
2024-07-09 07:32:38 -07:00
|
|
|
s = NamedSharding(mesh, P())
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f():
|
|
|
|
return jax.lax.with_sharding_constraint(jnp.arange(8), s)
|
|
|
|
|
|
|
|
out = f()
|
|
|
|
self.assertEqual(out.sharding, s)
|
|
|
|
|
2024-08-02 08:15:01 -07:00
|
|
|
@jtu.run_on_devices('tpu', 'gpu')
|
|
|
|
def test_aot_device_mismatch(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((1,), 'x')
|
2024-08-02 08:15:01 -07:00
|
|
|
np_inp = np.arange(8)
|
|
|
|
arr = jax.device_put(np_inp, NamedSharding(mesh, P()))
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
compiled = f.lower(arr).compile()
|
|
|
|
|
|
|
|
cpu_arr = jax.device_put(np_inp, jax.devices('cpu')[0])
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
"Compiled object called with input sharding.*does not match"):
|
|
|
|
compiled(cpu_arr)
|
|
|
|
|
|
|
|
@unittest.skipIf(xla_extension_version < 281,
|
|
|
|
'Requires xla_extension_version >= 281')
|
|
|
|
def test_different_devices_wsc_abstract_mesh_cache_hit(self):
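# Using an AbstractMesh in with_sharding_constraint keeps tracing and lowering
# cache hits when only the concrete devices change.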
|
|
|
|
if jax.device_count() < 4:
|
|
|
|
self.skipTest('Requires >=4 devices')
|
|
|
|
|
|
|
|
mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
|
|
|
|
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x')
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
x = with_sharding_constraint(
|
|
|
|
x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x')))
|
|
|
|
return jnp.sin(x)
|
|
|
|
|
|
|
|
with (
|
|
|
|
jtu.count_jit_tracing_cache_miss() as tracing_count,
|
|
|
|
jtu.count_jit_and_pmap_lowerings() as lowering_count,
|
|
|
|
jtu.count_jit_compilation_cache_miss() as compilation_count,
|
|
|
|
):
|
|
|
|
a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P()))
|
|
|
|
out_a = f(a) # tracing and lowering cached
|
|
|
|
|
|
|
|
# same num_devices but different devices.
|
|
|
|
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
|
|
|
f(b) # tracing and lowering cache *hit*
|
|
|
|
self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin`
|
|
|
|
self.assertEqual(lowering_count[0], 1)
|
|
|
|
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
|
|
|
|
|
|
|
|
@unittest.skipIf(xla_extension_version < 281,
|
|
|
|
'Requires xla_extension_version >= 281')
|
|
|
|
def test_wsc_abstract_mesh(self):
|
2024-09-03 16:22:23 -07:00
|
|
|
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
|
|
|
np_inp = np.arange(16).reshape(8, 2)
|
|
|
|
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
|
|
|
|
|
|
|
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
|
|
|
|
return x * 2
|
|
|
|
|
|
|
|
out = jax.jit(f)(arr)
|
|
|
|
self.assertArraysEqual(out, np_inp * 2)
|
|
|
|
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
|
|
|
|
|
|
|
out_eager = f(arr)
|
|
|
|
self.assertArraysEqual(out_eager, np_inp * 2)
|
|
|
|
self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))
|
|
|
|
|
|
|
|
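  # eval_shape only sees avals, so an abstract-mesh constraint is also expected
  # to work when the input is a jax.ShapeDtypeStruct rather than a concrete
  # jax.Array.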
  @unittest.skipIf(xla_extension_version < 281,
                   'Requires xla_extension_version >= 281')
  def test_wsc_sds_abstract_mesh(self):
    mesh = jtu.create_mesh((2,), 'x')
    s = NamedSharding(mesh, P())
    abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple)

    @jax.jit
    def f(x):
      x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
      return x * 2

    sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
    f.eval_shape(sds)  # doesn't crash

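  # mesh.abstract_mesh gives the AbstractMesh corresponding to a concrete Mesh;
  # the test below checks that an abstract-mesh constraint composes with vmap,
  # both with and without spmd_axis_name.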
  @unittest.skipIf(xla_extension_version < 281,
                   'Requires xla_extension_version >= 281')
  def test_wsc_vmap_abstract_mesh(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    def f(x):
      x = with_sharding_constraint(x, NamedSharding(mesh.abstract_mesh, P('x')))
      return x * 2

    out = jax.jit(jax.vmap(f))(arr)  # doesn't crash
    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x')))

    out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr)
    self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x')))

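  # In eager mode an abstract-mesh constraint needs a committed jax.Array whose
  # sharding is a NamedSharding over a mesh of the same shape; the cases below
  # check the errors raised when that is not the case.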
  @unittest.skipIf(xla_extension_version < 281,
                   'Requires xla_extension_version >= 281')
  def test_wsc_abstract_mesh_errors(self):
    mesh = jtu.create_mesh((2,), ('x',))
    np_inp = np.arange(8)
    abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
    s_abs = NamedSharding(abstract_mesh, P('x'))

    with self.assertRaisesRegex(
        ValueError, ".*requires the input passed should be a `jax.Array`.*"):
      with_sharding_constraint(np_inp, s_abs)

    with self.assertRaisesRegex(
        TypeError, "The sharding on the input must be a `NamedSharding`"):
      with_sharding_constraint(jnp.arange(8), s_abs)

    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
    abs_mesh2 = mesh_lib.AbstractMesh(
        jtu.create_mesh((2,), 'y').shape_tuple)
    with self.assertRaisesRegex(
        ValueError,
        'Mesh shape of the input.*does not'
        ' match the mesh shape of the target sharding.*'):
      with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))


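# Escapes the parentheses in str(PartitionSpec(...)) so a spec can be embedded
# verbatim in the regexes passed to assertRaisesRegex below (e.g. the escaped
# form of P('x', None) matches the literal text "PartitionSpec('x', None)").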
def spec_regex(s):
  return str(s).replace(r"(", r"\(").replace(r")", r"\)")


class ShardingInTypesTest(jtu.JaxTestCase):

  @config.sharding_in_types(True)
  def test_basic_mul(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      self.assertEqual(x.sharding.spec, s.spec)
      x = x * 2
      self.assertEqual(x.sharding.spec, s.spec)
      x = x * x
      self.assertEqual(x.sharding.spec, s.spec)
      return x

    out = f(arr)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))

    lowered_text = f.lower(arr).as_text()
    if config.use_shardy_partitioner.value:
      self.assertIn('sdy.sharding_constraint', lowered_text)
    else:
      self.assertEqual(lowered_text.count('@Sharding'), 2)


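# PJitErrorTest checks the user-facing errors pjit raises for malformed
# shardings: shapes not divisible by the mesh, unknown resource axes, specs
# whose rank exceeds the value's rank, duplicated mesh axes, and in/out
# sharding pytrees that don't match the arguments.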
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):

  @check_1d_2d_mesh(set_mesh=True)
  def testNonDivisibleArgs(self, mesh, resources):
    x = jnp.ones((3, 2))
    spec = P(resources, None)
    mesh_size = str(math.prod([dim[1] for dim in mesh]))
    error = re.compile(
        r"One of pjit arguments with pytree key path x.*" + spec_regex(spec) + r".*"
        r"implies that the global size of its dimension 0 should be "
        r"divisible by " + mesh_size + r", but it is equal to 3 "
        r"\(full shape: \(3, 2\)\)", re.M | re.S)
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)

  @check_1d_2d_mesh(set_mesh=True)
  def testNonDivisibleOuts(self, mesh, resources):
    x = jnp.ones((3, 2))
    spec = P(resources, None)
    mesh_size = str(math.prod([dim[1] for dim in mesh]))
    error = re.compile(
        r"One of pjit outputs with pytree key path \['rrr'\].*" + spec_regex(spec) + r".*"
        r"implies that the global size of its dimension 0 should be "
        r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S)
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: {'rrr': x}, in_shardings=None,
           out_shardings=P(resources, None))(x)

  @check_1d_2d_mesh(set_mesh=False)
  @jtu.with_mesh([('z', 1)])
  def testUndefinedResourcesArgs(self, mesh, resources):
    x = jnp.ones((2, 2))
    spec = P(resources,)
    with self.assertRaisesRegex(
        ValueError,
        r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
      pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)

  @check_1d_2d_mesh(set_mesh=False)
  @jtu.with_mesh([('z', 1)])
  def testUndefinedResourcesOuts(self, mesh, resources):
    x = jnp.ones((2, 2))
    spec = P(resources,)
    with self.assertRaisesRegex(
        ValueError,
        r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
      pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)

  @check_1d_2d_mesh(set_mesh=False)
  @jtu.with_mesh([('z', 1)])
  def testUndefinedResourcesConstraint(self, mesh, resources):
    x = jnp.ones((2, 2))
    spec = P(resources,)
    with self.assertRaisesRegex(
        ValueError,
        r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
      pjit(
          lambda x: with_sharding_constraint(x, spec),
          in_shardings=None,
          out_shardings=None,
      )(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testRankTooLowArgs(self):
    x = jnp.arange(2)
    spec = P('x', 'y')
    error = re.compile(
        r"One of pjit arguments.*" + spec_regex(spec) +
        r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S)
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: x.sum(), in_shardings=spec, out_shardings=None)(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testRankTooLowArgsAxisResourcesNone(self):
    x = jnp.arange(2)
    spec = P(None, None)
    error = re.compile(
        r"One of pjit arguments.*" + spec_regex(spec) +
        r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S)
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: x.sum(), in_shardings=spec, out_shardings=None)(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testRankTooLowOuts(self):
    x = jnp.arange(2)
    spec = P('x', 'y')
    error = re.compile(
        r"One of pjit outputs.*" + spec_regex(spec) +
        r".*rank at least 2, but was applied to a value of rank 0", re.M | re.S)
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: x.sum(), in_shardings=None, out_shardings=spec)(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testRankTooLowConstraint(self):
    x = jnp.arange(2)
    spec = P('x', 'y')
    error = re.compile(
        r"One of with_sharding_constraint arguments" + r".*" + spec_regex(spec) +
        r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S)
    with self.assertRaisesRegex(ValueError, error):
      pjit(
          lambda x: with_sharding_constraint(x, spec), in_shardings=None,
          out_shardings=None,
      )(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testRepeatedInResources(self):
    x = jnp.arange(2)
    for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
      error = (r"A single in_shardings specification can map every mesh "
               r"axis to at most one positional dimension, but " +
               spec_regex(spec) + " has duplicate entries for `x`")
      with self.assertRaisesRegex(ValueError, error):
        pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testRepeatedOutResources(self):
    x = jnp.arange(2)
    for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
      error = (r"A single out_shardings specification can map every mesh "
               r"axis to at most one positional dimension, but " +
               spec_regex(spec) + " has duplicate entries for `x`")
      with self.assertRaisesRegex(ValueError, error):
        pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)

  def testEmptyMesh(self):
    out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
    self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0]))

  def test_pspec_to_wsc_without_mesh(self):
    error = (
        r'with_sharding_constraint requires a non-empty mesh if you are '
        r'passing `PartitionSpec`s or `None` to shardings.*')
    with self.assertRaisesRegex(RuntimeError, error):
      pjit(lambda x: with_sharding_constraint(x, P('x')))(jnp.arange(4))

  @jtu.with_mesh([('x', 2)])
  def testAxisResourcesMismatch(self):
    x = jnp.ones([])
    p = [None, None, None]

    pjit(lambda x: x, (p,), p)([x, x, x])  # OK

    error = re.escape(
        "pjit in_shardings specification must be a tree prefix of the "
        "positional arguments tuple passed to the `pjit`-decorated function. "
        "In particular, pjit in_shardings must either be a None, a "
        "PartitionSpec, or a tuple of length equal to the number of positional "
        "arguments. But pjit in_shardings is the wrong length: got a "
        "tuple or list of length 3 for an args tuple of length 2.")
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x, y: x, p, p)(x, x)

    Foo = namedtuple('Foo', ['x'])
    error = "in_shardings is not a tuple.*might need to be wrapped"
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: x, Foo(None), Foo(None))(Foo(x))

    pjit(lambda x: x, (Foo(None),), Foo(None))(Foo(x))  # OK w/ singleton tuple

    # TODO(apaszke,mattjj): Disable implicit list casts and enable this
    # error = ("it looks like pjit in_axis_resources might need to be wrapped in "
    #          "a singleton tuple.")
    # with self.assertRaisesRegex(ValueError, error):
    #   pjit(lambda x, y: x, p, p)([x, x, x])

    # TODO(apaszke): Disable implicit list casts and enable this
    # error = re.escape(
    #     r"pjit in_axis_resources specification must be a tree prefix of the "
    #     r"corresponding value, got specification (None, None, None) for value "
    #     r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
    #     r"are non-trivial pytrees should always be wrapped in a tuple representing "
    #     r"the argument list. In particular, you're passing in a single argument "
    #     r"which means that pjit in_axis_resources might need to be wrapped in a "
    #     r"singleton tuple.")
    # with self.assertRaisesRegex(ValueError, error):
    #   pjit(lambda x: x, p, p)([x, x, x])  # Error, but make sure we hint at singleton tuple

    error = re.escape(
        "pytree structure error: different lengths of list at "
        "key path\n"
        "    pjit out_shardings\n")
    with self.assertRaisesRegex(ValueError, error):
      pjit(lambda x: x, (p,), [p, None])([x, x, x])  # Error, we raise a generic tree mismatch message

  @jtu.with_mesh([('x', 2)])
  def testNestedDifferentResources(self):
    @partial(pjit, in_shardings=P('x'), out_shardings=None)
    def f(x):
      with jax.sharding.Mesh(np.array([jax.local_devices()[0]]), ('x')):
        @partial(pjit, in_shardings=P('x'), out_shardings=None)
        def h(x):
          return x
        return h(x)
    xshape = (2, 5, 6)
    x = jnp.arange(math.prod(xshape)).reshape(xshape)
    with self.assertRaisesRegex(
        ValueError, "Received incompatible devices for pjitted computation.*"):
      f(x)

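  # Deleting an input either before the first call or between calls should
  # surface an "Array has been deleted" error rather than crashing.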
  @parameterized.named_parameters(
      ("committed", True),
      ("uncommitted", False),
  )
  def test_pjit_with_deleted_input_at_first_call(self, committed):
    shape = (8,)
    mesh = jtu.create_mesh((1,), ('x',))
    inp_data = np.arange(math.prod(shape)).reshape(shape)
    if committed:
      s = NamedSharding(mesh, P('x',))
      x = jax.device_put(inp_data, s)
    else:
      x = jax.device_put(inp_data)
    f = pjit(lambda x: x + 1)
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      x.delete()
      _ = f(x)

  @parameterized.named_parameters(
      ("committed", True),
      ("uncommitted", False),
  )
  def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
    shape = (8,)
    mesh = jtu.create_mesh((1,), ('x',))
    inp_data = np.arange(math.prod(shape)).reshape(shape)
    if committed:
      s = NamedSharding(mesh, P('x',))
      x = jax.device_put(inp_data, s)
    else:
      x = jax.device_put(inp_data)
    f = pjit(lambda x: x + 1)
    _ = f(x)
    with self.assertRaisesRegex((RuntimeError, ValueError),
                                '.*(Array|buffer|Buffer) has been deleted.*'):
      x.delete()
      _ = f(x)

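  # An AOT-compiled executable is specialized to the avals it was lowered for,
  # even for arguments that DCE dropped from the jaxpr, so calling it with a
  # differently-shaped `y` should raise a TypeError.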
  def test_aot_error_on_dced_avals_mismatch(self):
    x, y1, y2 = jnp.ones(4), jnp.ones(4), jnp.ones(1)

    @jax.jit
    def f(x, y):
      return x + 1 if y.shape[0] > 2 else x + 2

    f_out1 = f(x, y1)
    f(x, y2)

    g = f.lower(x, y1).compile()
    g_out1 = g(x, y1)
    self.assertArraysEqual(f_out1, g_out1)

    with self.assertRaisesRegex(
        TypeError,
        'Argument types differ from the types for which this computation was'
        ' compiled'):
      g(x, y2)

  def test_dce_no_array(self):
    mesh = jtu.create_mesh((2,), ('x',))
    arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x')))

    @jax.jit
    def f(a, b, c):
      return a, c

    f(arr, 2., 3.)
    f(arr, 2., 3.)  # doesn't crash


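# UtilTest covers sharding utility helpers: OpSharding <-> PartitionSpec
# round-trips, array-mapping conversion, and OpSharding/HloSharding equality
# and hashing.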
@jtu.pytest_mark_if_available('multiaccelerator')
class UtilTest(jtu.JaxTestCase):

  def testOpShardingRoundTrip(self):
    FakeDevice = namedtuple('FakeDevice', ['id'])
    mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
    mesh_axes, mesh_shape = unzip2(mesh_named_shape.items())
    devices = [FakeDevice(i) for i in range(math.prod(mesh_shape))]
    mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))

    dims = 5
    aval = core.ShapedArray((len(devices),) * dims, jnp.float32)
    def roundtrip(spec):
      hlo_sharding = NamedSharding(mesh, spec)._to_xla_hlo_sharding(aval.ndim)
      parsed_spec = parse_flatten_op_sharding(hlo_sharding, mesh)[0].partitions
      self.assertEqual(parsed_spec[:len(spec)], spec)
      self.assertEqual(parsed_spec[len(spec):], ((),) * (len(parsed_spec) - len(spec)))

    special_specs = [P()]
    for spec in special_specs:
      roundtrip(spec)

    rng = self.rng()
    for i in range(100):
      spec = [()] * dims
      for axis in rng.permutation(mesh_axes)[:rng.randint(low=1, high=len(mesh_axes) + 1)]:
        spec[rng.choice(dims)] += (axis,)
      while spec and spec[-1] == ():
        spec.pop()
      roundtrip(P(*spec))

@parameterized.named_parameters(
|
2023-05-02 08:55:12 -07:00
|
|
|
("linear", {'x': 0, 'y': 1, 'z': 2}, P('x', 'y', 'z')),
|
|
|
|
("combine", {'x': 0, 'y': 0, 'z': 1}, P(('x', 'y'), 'z')),
|
|
|
|
("skip", {'x': 0, 'y': 0, 'z': 2}, P(('x', 'y'), None, 'z')),
|
|
|
|
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, P('x', 'y', None, 'z')),
|
2021-11-12 22:41:42 -08:00
|
|
|
)
|
|
|
|
def test_array_mapping_to_axis_resources(self, inp, expected_out):
|
2023-04-10 10:15:08 -07:00
|
|
|
self.assertEqual(
|
|
|
|
sharding_impls.array_mapping_to_axis_resources(inp), expected_out
|
|
|
|
)
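  # OpSharding protos with identical tile assignments compare equal and hash
  # equal; a different tiling compares (and hashes) unequal.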
  def test_op_sharding_equality_and_hash_equality(self):
    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.OTHER
    op1.tile_assignment_dimensions = [2, 2]
    op1.tile_assignment_devices = [0, 1, 2, 3]

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.OTHER
    op2.tile_assignment_dimensions = [2, 2]
    op2.tile_assignment_devices = [0, 1, 2, 3]

    op3 = xc.OpSharding()
    op3.type = xc.OpSharding.Type.OTHER
    op3.tile_assignment_dimensions = [4, 2]
    op3.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]

    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
    self.assertFalse(op_shardings.are_op_shardings_equal(op1, op3))
    self.assertFalse(op_shardings.are_op_shardings_equal(op2, op3))

    hs1 = xc.HloSharding.from_proto(op1)
    hs2 = xc.HloSharding.from_proto(op2)
    hs3 = xc.HloSharding.from_proto(op3)

    self.assertEqual(hs1, xc.HloSharding.iota_tile((2, 2)))
    self.assertEqual(hs2, xc.HloSharding.iota_tile((2, 2)))
    self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2)))
    self.assertEqual(hs1.num_devices(), 4)
    self.assertEqual(hs1.num_dimensions(), 2)
    self.assertEqual(hs1.tile_assignment_dimensions(), [2, 2])
    self.assertEqual(hs1.tile_assignment_devices(), [0, 1, 2, 3])
    self.assertTrue(hs1.is_tiled())
    self.assertFalse(hs1.replicate_on_last_tile_dim())
    self.assertEqual(hash(hs1), hash(hs2))
    self.assertNotEqual(hash(hs1), hash(hs3))
    self.assertNotEqual(hash(hs2), hash(hs3))

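  # Partially replicated shardings: the permuted device order [0, 2, 1, 3]
  # matches an iota tile of (4, 1) reshaped to (2, 2) and transposed, with a
  # replicated subgroup on the last tile dimension.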
  def test_op_sharding_partial_sharding(self):
    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.OTHER
    op1.tile_assignment_dimensions = [4, 1]
    op1.tile_assignment_devices = [0, 2, 1, 3]
    op1.last_tile_dims = [xc.OpSharding.Type.REPLICATED]

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.OTHER
    op2.tile_assignment_dimensions = [4, 1]
    op2.tile_assignment_devices = [0, 2, 1, 3]
    op2.last_tile_dims = [xc.OpSharding.Type.REPLICATED]

    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))

    hs1 = xc.HloSharding.from_proto(op1)
    hs2 = xc.HloSharding.from_proto(op2)
    self.assertEqual(
        hs1,
        xc.HloSharding.iota_tile(
            (4, 1),
            reshape_dims=(2, 2),
            transpose_perm=(1, 0),
            subgroup_types=[xc.OpSharding.Type.REPLICATED],
        ),
    )
    self.assertFalse(hs1.subgroup_types())
    self.assertTrue(hs1.is_tiled())
    self.assertEqual(
        hs2,
        xc.HloSharding.iota_tile(
            (4, 1),
            reshape_dims=(2, 2),
            transpose_perm=(1, 0),
            subgroup_types=[xc.OpSharding.Type.REPLICATED],
        ),
    )
    self.assertFalse(hs2.subgroup_types())
    self.assertTrue(hs2.is_tiled())
    self.assertEqual(hash(hs1), hash(hs2))

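  # Tuple shardings are order sensitive: swapping the element shardings gives
  # an unequal proto and a different HloSharding hash.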
  def test_op_sharding_tuple_shardings(self):
    top1 = xc.OpSharding()
    top1.type = xc.OpSharding.Type.OTHER
    top1.tile_assignment_dimensions = [4, 1]
    top1.tile_assignment_devices = [0, 1, 2, 3]
    top1.replicate_on_last_tile_dim = True

    top2 = xc.OpSharding()
    top2.type = xc.OpSharding.Type.OTHER
    top2.tile_assignment_dimensions = [2, 2]
    top2.tile_assignment_devices = [0, 1, 2, 3]
    top2.replicate_on_last_tile_dim = True

    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.TUPLE
    op1.tuple_shardings = [top1, top2]

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.TUPLE
    op2.tuple_shardings = [top2, top1]

    self.assertFalse(op_shardings.are_op_shardings_equal(op1, op2))

    hs1 = xc.HloSharding.from_proto(op1)
    hs2 = xc.HloSharding.from_proto(op2)
    self.assertNotEqual(hash(hs1), hash(hs2))

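  # Invalid iota_tile arguments are reported as INVALID_ARGUMENT runtime
  # errors rather than producing a malformed sharding.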
  def test_hlo_sharding_iota_tile_error(self):
    self.assertRaisesRegex(
        xla_extension.XlaRuntimeError,
        'INVALID_ARGUMENT: `dims` should not be empty.',
        lambda: xc.HloSharding.iota_tile(())
    )
    self.assertRaisesRegex(
        xla_extension.XlaRuntimeError,
        'INVALID_ARGUMENT: Cannot reshape from',
        lambda: xc.HloSharding.iota_tile(
            (2, 2),
            reshape_dims=(2, 4),
            transpose_perm=(1, 0),
        ),
    )
    self.assertRaisesRegex(
        xla_extension.XlaRuntimeError,
        'INVALID_ARGUMENT: `reshape_dims` and `transpose_perm` should have the'
        ' same size',
        lambda: xc.HloSharding.iota_tile(
            (2, 2),
            transpose_perm=(1, 0),
        ),
    )
    self.assertRaisesWithLiteralMatch(
        xla_extension.XlaRuntimeError,
        'INVALID_ARGUMENT: `subgroup_types`(3) should not have more dimensions '
        'than `dims`(2).',
        lambda: xc.HloSharding.iota_tile(
            (2, 2),
            subgroup_types=(
                xc.OpSharding.Type.REPLICATED,
                xc.OpSharding.Type.MANUAL,
                xc.OpSharding.Type.REPLICATED,
            ),
        ),
    )

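  # devices_indices_map results are memoized, so repeated lookups (including
  # from a semantically equivalent replicated sharding) hit the shared cache.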
  def test_device_indices_cache(self):
    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.OTHER
    op1.tile_assignment_dimensions = [1, 1, 2, 1]
    op1.tile_assignment_devices = [0, 1]
    op1.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.REPLICATED

    shape = (8, 4)
    devices = jax.devices()

    ops = GSPMDSharding(devices, op1)
    ops.devices_indices_map(shape)
    cache_info1 = common_devices_indices_map.cache_info()

    ops.devices_indices_map(shape)
    cache_info2 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)

    ops = GSPMDSharding(devices, op2)
    ops.devices_indices_map(shape)
    cache_info3 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info3.hits, cache_info2.hits + 1)

    ops.devices_indices_map(shape)
    cache_info4 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info4.hits, cache_info3.hits + 1)

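  # All of these protos describe a fully replicated sharding, whether written
  # as Type.REPLICATED, a single-device tiling, or a tiling with a replicated
  # last tile dimension.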
  def test_op_sharding_semantically_replicated(self):
    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.OTHER
    op1.tile_assignment_dimensions = [1, 1, 2]
    op1.tile_assignment_devices = [0, 1]
    op1.last_tile_dims = [xc.OpSharding.Type.REPLICATED]

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.REPLICATED

    op3 = xc.OpSharding()
    op3.type = xc.OpSharding.Type.OTHER
    op3.tile_assignment_dimensions = [1, 1, 1, 1]
    op3.tile_assignment_devices = [0]
    op3.last_tile_dims = [xc.OpSharding.Type.REPLICATED]

    op4 = xc.OpSharding()
    op4.type = xc.OpSharding.Type.OTHER
    op4.tile_assignment_dimensions = [1]
    op4.tile_assignment_devices = [0]

    self.assertTrue(op_shardings.is_op_sharding_replicated(op1))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op2))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op3))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op4))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op2, op3))
    self.assertTrue(op_shardings.are_op_shardings_equal(op3, op4))

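  # MANUAL subgroups do not change replication semantics: both subgroup
  # orderings are considered fully replicated and equal to Type.REPLICATED.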
  def test_op_sharding_manual_replicated(self):
    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.OTHER
    op1.tile_assignment_dimensions = [1, 1, 2, 1]
    op1.tile_assignment_devices = [0, 1]
    op1.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]

    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.OTHER
    op2.tile_assignment_dimensions = [1, 1, 1, 2]
    op2.tile_assignment_devices = [0, 1]
    op2.last_tile_dims = [xc.OpSharding.Type.MANUAL, xc.OpSharding.Type.REPLICATED]

    op3 = xc.OpSharding()
    op3.type = xc.OpSharding.Type.REPLICATED

    self.assertTrue(op_shardings.is_op_sharding_replicated(op1))
    self.assertTrue(op_shardings.is_op_sharding_replicated(op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
    self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3))

    hs1 = xc.HloSharding.from_proto(op1)
    self.assertEqual(
        hs1,
        xc.HloSharding.iota_tile(
            (1, 1, 2, 1),
            subgroup_types=(
                xc.OpSharding.Type.REPLICATED,
                xc.OpSharding.Type.MANUAL,
            ),
        )
    )
    self.assertTrue(hs1.is_replicated())
    self.assertFalse(hs1.replicate_on_last_tile_dim())

    hs2 = xc.HloSharding.from_proto(op2)
    self.assertEqual(
        xc.HloSharding.from_proto(op2),
        xc.HloSharding.iota_tile(
            (1, 1, 1, 2),
            subgroup_types=(
                xc.OpSharding.Type.MANUAL,
                xc.OpSharding.Type.REPLICATED,
            ),
        )
    )
    self.assertTrue(hs2.is_replicated())
    self.assertFalse(hs2.replicate_on_last_tile_dim())
    self.assertEqual(
        xc.HloSharding.from_proto(op3), xc.HloSharding.replicate()
    )

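  # Direct HloSharding constructors: manual(), replicate(), and iota_tile()
  # with subgroup types expose the expected predicates and tile assignments.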
  def test_hlo_sharding_manual_replicated(self):
    hs1 = xc.HloSharding.manual()
    self.assertTrue(hs1.is_manual())
    self.assertFalse(hs1.tile_assignment_devices())

    hs2 = xc.HloSharding.replicate()
    self.assertTrue(hs2.is_replicated())
    self.assertFalse(hs2.tile_assignment_devices())

    hs3 = xc.HloSharding.iota_tile(
        (3, 3),
        subgroup_types=(
            xc.OpSharding.Type.MANUAL,
            xc.OpSharding.Type.REPLICATED,
        ),
    )
    self.assertFalse(hs3.is_manual())
    self.assertFalse(hs3.is_replicated())
    self.assertEqual(hs3.num_dimensions(), 2)
    self.assertEqual(hs3.tile_assignment_dimensions(), [3, 3])
    self.assertEqual(hs3.num_devices(), 9)
    self.assertEqual(hs3.tile_assignment_devices(), list(range(0, 9)))
    self.assertEqual(
        hs3.subgroup_types(),
        [xc.OpSharding.Type.MANUAL, xc.OpSharding.Type.REPLICATED],
    )
    self.assertFalse(hs3.replicate_on_last_tile_dim())
    self.assertTrue(hs3.is_tiled())

    hs4 = xc.HloSharding.iota_tile(
        (3, 4), subgroup_types=[xc.OpSharding.Type.REPLICATED]
    )
    self.assertTrue(hs4.replicate_on_last_tile_dim())
    self.assertFalse(hs4.subgroup_types())
    self.assertTrue(hs4.is_tiled())

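  # Lowering the same NamedSharding twice should return the cached HloSharding
  # object (same id) and register a cache hit, not a new miss.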
  def test_op_sharding_cache_on_mesh_pspec_sharding(self):
    ndim = 2
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    mps1 = NamedSharding(mesh, P('x', 'y'))
    op1 = mps1._to_xla_hlo_sharding(ndim)
    cache_info1 = sharding_impls.named_sharding_to_xla_hlo_sharding.cache_info()

    mps2 = NamedSharding(mesh, P('x', 'y'))
    op2 = mps2._to_xla_hlo_sharding(ndim)
    cache_info2 = sharding_impls.named_sharding_to_xla_hlo_sharding.cache_info()

    self.assertEqual(id(op1), id(op2))
    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)
    self.assertEqual(cache_info2.currsize, cache_info1.currsize)

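  # A PartitionSpec survives the round trip through an XLA HLO sharding and
  # parse_flatten_op_sharding, modulo trailing None entries.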
  def test_get_partition_spec(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y', None))

    self.assertEqual(s._parsed_pspec.get_partition_spec(), P('x', 'y', None))

    recovered_parsed_pspec = parse_flatten_op_sharding(
        s._to_xla_hlo_sharding(3), mesh)
    self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
                     P('x', 'y'))

    out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec(
        P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC)
    self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
                     P('x', 'y'))

  def test_mesh_with_list_devices(self):
    mesh = jax.sharding.Mesh(jax.devices(), ('x',))
    self.assertIsInstance(mesh.devices, np.ndarray)
    self.assertEqual(mesh.size, jax.device_count())

  def test_mesh_with_string_axis_names(self):
    mesh = jax.sharding.Mesh(jax.devices(), 'dp')
    self.assertTupleEqual(mesh.axis_names, ('dp',))

  def test_sharded_in_place_assignment(self):
    mesh = jtu.create_mesh((8,), ('data',))

    idx = [0, 2, 5, 7, 8, 10, 13, 15]
    n = 16
    def _init():
      w = jnp.zeros((n, n))
      idx1 = jnp.array(idx)
      w = w.at[idx1, jnp.arange(n//2)].set(1)
      return w

    w = jax.jit(_init, out_shardings=NamedSharding(mesh, P(None, 'data')))()

    w_gt = np.zeros((n, n))
    for j, i in enumerate(idx):
      w_gt[i, j] = 1

    self.assertArraysEqual(w, w_gt)


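# Lowering through the Shardy (sdy) MLIR dialect, enabled via the
# jax_use_shardy_partitioner config flag.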
@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):

  # TODO(bartchr): Once JAX is released with SDY, remove setUp.
  def setUp(self):
    if not dialects.sdy:
      raise unittest.SkipTest('Shardy is not available.')

  def test_lowering_input_output_sharding(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    @partial(jax.jit, out_shardings=s)
    def f(x):
      return x * 2

    self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())

  def test_lowering_with_sharding_constraint(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    arr = np.arange(16).reshape(4, 2, 2)

    @jax.jit
    def f(x):
      return jax.lax.with_sharding_constraint(
          x, NamedSharding(mesh, P('x', None, 'y')))
    lowered_str = jax.jit(f).lower(arr).as_text()
    self.assertIn('sdy.sharding_constraint', lowered_str)
    self.assertIn('<@mesh, [{"x"}, {}, {"y"}]>', lowered_str)

  def test_lowering_with_sharding_constraint_unconstrained(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    arr = np.arange(16).reshape(4, 2, 2)

    @jax.jit
    def f(x):
      return jax.lax.with_sharding_constraint(
          x, NamedSharding(mesh, P('x', P.UNCONSTRAINED, 'y')))
    lowered_str = f.lower(arr).as_text()
    self.assertIn('sdy.sharding_constraint', lowered_str)
    self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str)

  # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline.
  @jtu.skip_on_devices('cpu')
  def test_compile_with_inferred_out_sharding(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jax.device_put(np.arange(8 * 4).reshape(8, 4),
                       NamedSharding(mesh, P('x', 'y')))
    y = jax.device_put(np.arange(4 * 16).reshape(4, 16),
                       NamedSharding(mesh, P('y')))

    @jax.jit
    def f(x, y):
      return x @ y

    out = f(x, y)
    self.assertArraysEqual(out, x @ y)
    self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

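  # With AUTO input shardings the lowered module still declares the Shardy
  # mesh and leaves the actual sharding choice to the compiler.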
  def test_fully_automatic_sharding(self):
    mesh = jtu.create_mesh((8,), ('x',))
    x = jax.ShapeDtypeStruct((128, 128), jnp.float32)

    @jax.jit
    def f(x, y):
      return x @ y

    lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
    self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())