mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14206 from jakevdp:jax-shapedarray
PiperOrigin-RevId: 505788784
This commit is contained in:
commit
c7b1b6cb1e
@ -626,7 +626,7 @@ def bench_pjit_check_aval_sharding(state):
|
||||
if mesh is None:
|
||||
return
|
||||
s = sharding.NamedSharding(mesh, pxla.PartitionSpec('x', 'y'))
|
||||
aval = jax.ShapedArray((8, 2), np.int32)
|
||||
aval = jax.core.ShapedArray((8, 2), np.int32)
|
||||
|
||||
while state:
|
||||
pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
|
||||
|
@ -111,7 +111,7 @@ from jax._src.api import (
|
||||
pure_callback as pure_callback,
|
||||
pxla, # TODO(phawkins): update users to avoid this.
|
||||
remat as remat,
|
||||
ShapedArray as ShapedArray,
|
||||
ShapedArray, # TODO(jakevdp): update users to avoid this.
|
||||
ShapeDtypeStruct as ShapeDtypeStruct,
|
||||
value_and_grad as value_and_grad,
|
||||
vjp as vjp,
|
||||
|
@ -483,7 +483,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
"""
|
||||
# The dtype doesn't matter here. Its only used for creating the
|
||||
# sharding_spec.
|
||||
aval = jax.ShapedArray(shape, np.int32)
|
||||
aval = jax.core.ShapedArray(shape, np.int32)
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
|
||||
|
||||
num_ways_sharded = None
|
||||
|
@ -110,7 +110,7 @@ def _handle_array_process_allgather(inp, tiled):
|
||||
if host_np_arr.ndim == 0 or not tiled:
|
||||
host_np_arr = np.expand_dims(host_np_arr, axis=0)
|
||||
|
||||
aval = jax.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
|
||||
aval = jax.core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
|
||||
global_aval = global_mesh._local_to_global(
|
||||
pxla._get_array_mapping(pspec), aval)
|
||||
|
||||
@ -322,7 +322,7 @@ def host_local_array_to_global_array(local_inputs: Any,
|
||||
))
|
||||
|
||||
global_aval = _local_to_global_aval(
|
||||
jax.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
|
||||
jax.core.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
|
||||
|
||||
return array.ArrayImpl(
|
||||
global_aval, jax.sharding.NamedSharding(global_mesh, pspec),
|
||||
|
@ -95,7 +95,7 @@ def jax_to_ir(fn, input_shapes, *, constants=None, format):
|
||||
|
||||
Args:
|
||||
fn: Function to convert.
|
||||
input_shapes: List of tuples (arg name, jax.ShapedArray),
|
||||
input_shapes: List of tuples (arg name, jax.core.ShapedArray),
|
||||
indicating the shapes of the arguments to fn. The order of parameters in
|
||||
the resulting XLA program will match the order in this list.
|
||||
constants: Dict mapping function argument name to a Python value. Specified
|
||||
@ -213,7 +213,7 @@ def parse_shape_str(s):
|
||||
shape = tuple(int(d.strip()) for d in match.group(2).split(","))
|
||||
else:
|
||||
shape = ()
|
||||
return jax.ShapedArray(shape, dtype)
|
||||
return jax.core.ShapedArray(shape, dtype)
|
||||
|
||||
_DT = {'pred': jnp.bool_,
|
||||
'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
|
||||
|
@ -962,7 +962,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
for obj in [lowered, compiled]:
|
||||
self.assertEqual(
|
||||
obj.in_avals,
|
||||
((jax.ShapedArray([], expected_dtype, weak_type=True),), {}))
|
||||
((jax.core.ShapedArray([], expected_dtype, weak_type=True),), {}))
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
|
||||
|
||||
def test_jit_lower_duck_typing(self):
|
||||
|
@ -144,7 +144,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
|
||||
for i, s in enumerate(arr.addressable_shards):
|
||||
self.assertEqual(s.data.aval,
|
||||
jax.ShapedArray(expected_shard_shape, s.data.dtype))
|
||||
jax.core.ShapedArray(expected_shard_shape, s.data.dtype))
|
||||
self.assertArraysEqual(s.data, global_input_data[s.index])
|
||||
self.assertArraysEqual(s.data, arr.addressable_data(i))
|
||||
|
||||
@ -317,13 +317,13 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
r'Expected 8 per-device arrays \(this is how many devices are addressable '
|
||||
r'by the sharding\), but got 4'):
|
||||
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'Expected 8 per-device arrays \(this is how many devices are addressable '
|
||||
r'by the sharding\), but got 16'):
|
||||
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
|
||||
|
||||
def test_arrays_not_in_device_assignment(self):
|
||||
if jax.device_count() < 4:
|
||||
@ -341,7 +341,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
"Sharding contains devices {0, 1} that are not present in per-device "
|
||||
"arrays. Per-device arrays contain devices {2, 3} that are not present "
|
||||
"in the sharding."):
|
||||
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
def test_more_devices_in_sharding_than_arrays(self):
|
||||
shape = (8, 2)
|
||||
@ -356,7 +356,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
"Addressable devices and per-device arrays devices do not match. "
|
||||
r"Sharding contains devices \{1\} that are not present in per-device "
|
||||
"arrays."):
|
||||
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
def test_different_devices_in_arrays_than_sharding(self):
|
||||
if jax.device_count() < 3:
|
||||
@ -374,7 +374,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
r"Sharding contains devices \{2\} that are not present in per-device "
|
||||
r"arrays. Per-device arrays contain devices \{0\} that are not present "
|
||||
"in the sharding."):
|
||||
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y"), (2, 2)),
|
||||
@ -409,7 +409,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
"Input buffers to `Array` must have matching dtypes. "
|
||||
"Got int32, expected float32"):
|
||||
array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
def test_array_iter_pmap_sharding(self):
|
||||
if jax.device_count() < 2:
|
||||
@ -980,7 +980,7 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
fun,
|
||||
in_axis_resources=P('data'),
|
||||
out_axis_resources=P(None, 'data'),
|
||||
).lower(jax.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
).lower(jax.core.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
|
||||
def verify_serialization(lowered):
|
||||
serialized, in_tree, out_tree = compile_and_serialize(lowered)
|
||||
|
@ -37,9 +37,9 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
(y,), token = lax.infeed(
|
||||
token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
|
||||
token, shape=(jax.core.ShapedArray((3, 4), jnp.float32),))
|
||||
(z,), _ = lax.infeed(
|
||||
token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
|
||||
token, shape=(jax.core.ShapedArray((3, 1, 1), jnp.float32),))
|
||||
return x + y + z
|
||||
|
||||
x = np.float32(1.5)
|
||||
@ -55,8 +55,8 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
x = np.float32(1.5)
|
||||
y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
|
||||
to_infeed = dict(a=x, b=y)
|
||||
to_infeed_shape = dict(a=jax.ShapedArray((), dtype=np.float32),
|
||||
b=jax.ShapedArray((3, 4), dtype=np.int16))
|
||||
to_infeed_shape = dict(a=jax.core.ShapedArray((), dtype=np.float32),
|
||||
b=jax.core.ShapedArray((3, 4), dtype=np.int16))
|
||||
@jax.jit
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
@ -77,7 +77,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
y, token = lax.infeed(
|
||||
token, shape=jax.ShapedArray((3, 4), jnp.float32))
|
||||
token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
|
||||
token = lax.outfeed(token, y + np.float32(1))
|
||||
return x - 1
|
||||
|
||||
@ -97,7 +97,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
def doubler(_, token):
|
||||
y, token = lax.infeed(
|
||||
token, shape=jax.ShapedArray((3, 4), jnp.float32))
|
||||
token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
|
||||
return lax.outfeed(token, y * np.float32(2))
|
||||
|
||||
@jax.jit
|
||||
|
@ -538,7 +538,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
f = pjit.pjit(lambda x, y: (x, y),
|
||||
in_axis_resources=experimental.PartitionSpec("x", "y"),
|
||||
out_axis_resources=experimental.PartitionSpec("x", "y"))
|
||||
inp_aval = jax.ShapedArray((8, 2), jnp.int32)
|
||||
inp_aval = jax.core.ShapedArray((8, 2), jnp.int32)
|
||||
# `ShapedArray` is considered global when lowered and compiled.
|
||||
# Hence it can bypass the contiguous mesh restriction.
|
||||
compiled = f.lower(inp_aval, gda1).compile()
|
||||
|
@ -710,11 +710,11 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
def f_for_jit(x):
|
||||
token = lax.create_token(x)
|
||||
(y,), token = lax.infeed(
|
||||
token, shape=(jax.ShapedArray(x.shape, np.float32),))
|
||||
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
|
||||
(z,), token = lax.infeed(
|
||||
token, shape=(jax.ShapedArray(x.shape, np.float32),))
|
||||
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
|
||||
(w,), token = lax.infeed(
|
||||
token, shape=(jax.ShapedArray(x.shape, np.float32),))
|
||||
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
|
||||
|
||||
return x + y + z + w
|
||||
|
||||
@ -744,17 +744,17 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
# A replicated infeed
|
||||
(y,), token = lax.infeed(
|
||||
token,
|
||||
shape=(jax.ShapedArray(x.shape, np.float32),),
|
||||
shape=(jax.core.ShapedArray(x.shape, np.float32),),
|
||||
partitions=(None,))
|
||||
# An infeed sharded on first axis
|
||||
(z,), token = lax.infeed(
|
||||
token,
|
||||
shape=(jax.ShapedArray(x.shape, np.float32),),
|
||||
shape=(jax.core.ShapedArray(x.shape, np.float32),),
|
||||
partitions=(P(nr_devices, 1),))
|
||||
# An infeed sharded on second axis
|
||||
(w,), token = lax.infeed(
|
||||
token,
|
||||
shape=(jax.ShapedArray(x.shape, np.float32),),
|
||||
shape=(jax.core.ShapedArray(x.shape, np.float32),),
|
||||
partitions=(P(1, nr_devices),))
|
||||
return x + y + z + w
|
||||
|
||||
@ -838,7 +838,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(lowered.in_avals, compiled.in_avals)
|
||||
self.assertEqual(
|
||||
lowered.in_avals,
|
||||
((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
|
||||
((jax.core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
|
||||
|
||||
splits = np.split(expected, 4)
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
|
||||
@ -1027,7 +1027,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
return x @ y
|
||||
|
||||
shape = (8, 8)
|
||||
aval = jax.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
|
||||
aval = jax.core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
|
||||
x = jnp.arange(np.prod(shape)).reshape(shape)
|
||||
exe = f.lower(aval, x).compile()
|
||||
self.assertIsInstance(exe, stages.Compiled)
|
||||
@ -1473,7 +1473,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
|
||||
compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "GDA sharding does not match the input sharding."):
|
||||
compiled(input_gda)
|
||||
@ -1485,7 +1485,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
g1, _ = create_gda(global_input_shape, global_mesh, P(None,))
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=P(None), out_axis_resources=P('x'))
|
||||
compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled(g1) # no error
|
||||
|
||||
@parallel_functions_output_gda(True)
|
||||
@ -1541,7 +1541,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1568,7 +1568,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1592,7 +1592,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
with ctx(True):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO, out_axis_resources=AUTO)
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
|
||||
different_pspec = (P('y', 'x')
|
||||
@ -1616,7 +1616,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x, y, z: (x, y, z), in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp, inp).compile()
|
||||
self.assertLen(compiled.output_shardings, 3)
|
||||
self.assertLen(compiled.input_shardings[0], 3)
|
||||
@ -1644,7 +1644,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp).compile()
|
||||
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1674,7 +1674,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp).compile()
|
||||
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1699,7 +1699,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1951,7 +1951,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y'))
|
||||
a2, _ = create_array(global_input_shape, global_mesh, P('x'))
|
||||
|
||||
aval = jax.ShapedArray(global_input_shape, np.float32)
|
||||
aval = jax.core.ShapedArray(global_input_shape, np.float32)
|
||||
|
||||
with jax_array(True):
|
||||
with global_mesh:
|
||||
@ -2075,7 +2075,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jax_array(True):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=NamedSharding(global_mesh, P(None,)))
|
||||
compiled = f.lower(jax.ShapedArray(input_shape, jnp.float32)).compile()
|
||||
compiled = f.lower(jax.core.ShapedArray(input_shape, jnp.float32)).compile()
|
||||
compiled(a1) # no error
|
||||
|
||||
@jax_array(True)
|
||||
@ -2201,7 +2201,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
di_map = s.devices_indices_map(shape)
|
||||
bufs = [jax.device_put(inp_data[di_map[d]], d)
|
||||
for d in jax.local_devices()]
|
||||
arr = array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
arr = array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
f = pjit(lambda x: x, out_axis_resources=s)
|
||||
out = f(arr)
|
||||
@ -2302,7 +2302,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
|
||||
arr = array.ArrayImpl(
|
||||
jax.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
|
||||
jax.core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
|
||||
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
@ -2935,7 +2935,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(cache_info3.hits, cache_info2.hits)
|
||||
|
||||
# AOT test
|
||||
compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
|
||||
compiled = f.lower(jax.core.ShapedArray(y.shape, y.dtype)).compile()
|
||||
out3 = compiled(y)
|
||||
_check(out3, jax.devices()[1], y)
|
||||
|
||||
@ -2965,7 +2965,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
g_out = g(x)
|
||||
_check(g_out, jax.devices()[0], x)
|
||||
|
||||
compiled = g.lower(jax.ShapedArray(x.shape, x.dtype)).compile()
|
||||
compiled = g.lower(jax.core.ShapedArray(x.shape, x.dtype)).compile()
|
||||
out4 = compiled(x)
|
||||
_check(out4, jax.devices()[0], x)
|
||||
|
||||
@ -3570,7 +3570,7 @@ class UtilTest(jtu.JaxTestCase):
|
||||
def test_mesh_sharding_spec(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
array_mapping = pxla._get_array_mapping(P('x', 'y'))
|
||||
aval = jax.ShapedArray((1, 1), jnp.int32)
|
||||
aval = jax.core.ShapedArray((1, 1), jnp.int32)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '
|
||||
|
@ -204,7 +204,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
for obj in [lowered, compiled]:
|
||||
self.assertFalse(obj._no_kwargs)
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
|
||||
self.assertEqual(obj.in_avals, ((jax.ShapedArray(x.shape, x.dtype),), {}))
|
||||
self.assertEqual(obj.in_avals, ((jax.core.ShapedArray(x.shape, x.dtype),), {}))
|
||||
|
||||
def testLowerCompileInTreeMismatch(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
|
Loading…
x
Reference in New Issue
Block a user