Merge pull request #14206 from jakevdp:jax-shapedarray

PiperOrigin-RevId: 505788784
jax authors 2023-01-30 13:52:13 -08:00
commit c7b1b6cb1e
11 changed files with 48 additions and 48 deletions

View File

@@ -626,7 +626,7 @@ def bench_pjit_check_aval_sharding(state):
if mesh is None:
return
s = sharding.NamedSharding(mesh, pxla.PartitionSpec('x', 'y'))
- aval = jax.ShapedArray((8, 2), np.int32)
+ aval = jax.core.ShapedArray((8, 2), np.int32)
while state:
pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)

View File

@@ -111,7 +111,7 @@ from jax._src.api import (
pure_callback as pure_callback,
pxla, # TODO(phawkins): update users to avoid this.
remat as remat,
- ShapedArray as ShapedArray,
+ ShapedArray, # TODO(jakevdp): update users to avoid this.
ShapeDtypeStruct as ShapeDtypeStruct,
value_and_grad as value_and_grad,
vjp as vjp,

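The hunk above keeps ShapedArray re-exported from the top-level jax namespace but flags it for eventual removal; the rest of this commit moves callers to the jax.core spelling. A minimal sketch of what the migration looks like in downstream code (illustrative only, not part of this diff):

    import numpy as np
    import jax

    # Preferred spelling after this change: go through jax.core.
    aval = jax.core.ShapedArray((8, 2), np.int32)

    # The top-level alias still works for now (see the TODO above), but new
    # code should avoid it; both names refer to the same class.
    legacy_aval = jax.ShapedArray((8, 2), np.int32)
    assert aval == legacy_aval
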
View File

@@ -483,7 +483,7 @@ class PmapSharding(XLACompatibleSharding):
"""
# The dtype doesn't matter here. Its only used for creating the
# sharding_spec.
- aval = jax.ShapedArray(shape, np.int32)
+ aval = jax.core.ShapedArray(shape, np.int32)
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
num_ways_sharded = None

View File

@@ -110,7 +110,7 @@ def _handle_array_process_allgather(inp, tiled):
if host_np_arr.ndim == 0 or not tiled:
host_np_arr = np.expand_dims(host_np_arr, axis=0)
- aval = jax.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
+ aval = jax.core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
global_aval = global_mesh._local_to_global(
pxla._get_array_mapping(pspec), aval)
@@ -322,7 +322,7 @@ def host_local_array_to_global_array(local_inputs: Any,
))
global_aval = _local_to_global_aval(
- jax.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
+ jax.core.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
return array.ArrayImpl(
global_aval, jax.sharding.NamedSharding(global_mesh, pspec),

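For context, host_local_array_to_global_array lives in jax.experimental.multihost_utils. A hedged single-process sketch of how it is typically called (the single-device mesh below is an assumption for illustration; real multi-host runs build a mesh over all global devices and each process passes only its local shard):

    import numpy as np
    import jax
    from jax.experimental import multihost_utils
    from jax.sharding import Mesh, PartitionSpec as P

    # Mesh over one local device, purely for illustration.
    mesh = Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
    host_local = np.arange(16, dtype=np.float32).reshape(8, 2)
    global_arr = multihost_utils.host_local_array_to_global_array(
        host_local, mesh, P('x', 'y'))
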
View File

@@ -95,7 +95,7 @@ def jax_to_ir(fn, input_shapes, *, constants=None, format):
Args:
fn: Function to convert.
- input_shapes: List of tuples (arg name, jax.ShapedArray),
+ input_shapes: List of tuples (arg name, jax.core.ShapedArray),
indicating the shapes of the arguments to fn. The order of parameters in
the resulting XLA program will match the order in this list.
constants: Dict mapping function argument name to a Python value. Specified
@@ -213,7 +213,7 @@ def parse_shape_str(s):
shape = tuple(int(d.strip()) for d in match.group(2).split(","))
else:
shape = ()
- return jax.ShapedArray(shape, dtype)
+ return jax.core.ShapedArray(shape, dtype)
_DT = {'pred': jnp.bool_,
'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,

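A hedged usage sketch of the jax_to_ir API documented above (the import path and the format='HLO' value are assumptions based on the jax.tools.jax_to_ir module; the exact return value depends on the requested format):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.tools.jax_to_ir import jax_to_ir

    def axpy(a, x):
      return a * x + jnp.float32(1.0)

    # input_shapes pairs argument names with abstract jax.core.ShapedArray
    # values; their order fixes the parameter order of the emitted program.
    result = jax_to_ir(
        axpy,
        input_shapes=[('a', jax.core.ShapedArray((), np.float32)),
                      ('x', jax.core.ShapedArray((8,), np.float32))],
        format='HLO')
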
View File

@@ -962,7 +962,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
for obj in [lowered, compiled]:
self.assertEqual(
obj.in_avals,
- ((jax.ShapedArray([], expected_dtype, weak_type=True),), {}))
+ ((jax.core.ShapedArray([], expected_dtype, weak_type=True),), {}))
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
def test_jit_lower_duck_typing(self):

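The assertion above relies on the AOT objects reporting the abstract values they were specialized for. A hedged standalone sketch of that pattern (the exact dtype and weak_type of the reported aval depend on the argument used for lowering):

    import jax

    lowered = jax.jit(lambda x: x * 2).lower(1.0)  # lower against a Python float
    compiled = lowered.compile()

    # Both stages expose the same abstract input signature; the entries are
    # jax.core.ShapedArray instances.
    assert lowered.in_avals == compiled.in_avals
    args_avals, kwargs_avals = lowered.in_avals
    print(args_avals[0])  # e.g. ShapedArray(float32[], weak_type=True)
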
View File

@@ -144,7 +144,7 @@ class JaxArrayTest(jtu.JaxTestCase):
self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
for i, s in enumerate(arr.addressable_shards):
self.assertEqual(s.data.aval,
- jax.ShapedArray(expected_shard_shape, s.data.dtype))
+ jax.core.ShapedArray(expected_shard_shape, s.data.dtype))
self.assertArraysEqual(s.data, global_input_data[s.index])
self.assertArraysEqual(s.data, arr.addressable_data(i))
@@ -317,13 +317,13 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 4'):
- array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
+ array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
with self.assertRaisesRegex(
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 16'):
- array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
+ array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
def test_arrays_not_in_device_assignment(self):
if jax.device_count() < 4:
@@ -341,7 +341,7 @@ class JaxArrayTest(jtu.JaxTestCase):
"Sharding contains devices {0, 1} that are not present in per-device "
"arrays. Per-device arrays contain devices {2, 3} that are not present "
"in the sharding."):
- array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+ array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_more_devices_in_sharding_than_arrays(self):
shape = (8, 2)
@@ -356,7 +356,7 @@ class JaxArrayTest(jtu.JaxTestCase):
"Addressable devices and per-device arrays devices do not match. "
r"Sharding contains devices \{1\} that are not present in per-device "
"arrays."):
- array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+ array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_different_devices_in_arrays_than_sharding(self):
if jax.device_count() < 3:
@@ -374,7 +374,7 @@ class JaxArrayTest(jtu.JaxTestCase):
r"Sharding contains devices \{2\} that are not present in per-device "
r"arrays. Per-device arrays contain devices \{0\} that are not present "
"in the sharding."):
- array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+ array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
@@ -409,7 +409,7 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
"Input buffers to `Array` must have matching dtypes. "
"Got int32, expected float32"):
- array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+ array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_array_iter_pmap_sharding(self):
if jax.device_count() < 2:
@@ -980,7 +980,7 @@ class RngShardingTest(jtu.JaxTestCase):
fun,
in_axis_resources=P('data'),
out_axis_resources=P(None, 'data'),
- ).lower(jax.ShapedArray(shape=(8, 8), dtype=np.float32))
+ ).lower(jax.core.ShapedArray(shape=(8, 8), dtype=np.float32))
def verify_serialization(lowered):
serialized, in_tree, out_tree = compile_and_serialize(lowered)

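The tests above construct array.ArrayImpl directly from an aval, a sharding, and per-device buffers. Downstream code would normally go through the public constructors instead; a hedged sketch of the same idea (single-device mesh assumed for illustration):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    data = np.arange(16, dtype=np.float32).reshape(8, 2)

    # Public analogue of the ArrayImpl(...) calls above: assemble a jax.Array
    # from per-shard data described by the sharding.
    arr = jax.make_array_from_callback(data.shape, sharding, lambda idx: data[idx])
    for shard in arr.addressable_shards:
      # Each shard's abstract value is a jax.core.ShapedArray, as the test checks.
      assert isinstance(shard.data.aval, jax.core.ShapedArray)
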
View File

@@ -37,9 +37,9 @@ class InfeedTest(jtu.JaxTestCase):
def f(x):
token = lax.create_token(x)
(y,), token = lax.infeed(
- token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
+ token, shape=(jax.core.ShapedArray((3, 4), jnp.float32),))
(z,), _ = lax.infeed(
- token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
+ token, shape=(jax.core.ShapedArray((3, 1, 1), jnp.float32),))
return x + y + z
x = np.float32(1.5)
@@ -55,8 +55,8 @@ class InfeedTest(jtu.JaxTestCase):
x = np.float32(1.5)
y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
to_infeed = dict(a=x, b=y)
- to_infeed_shape = dict(a=jax.ShapedArray((), dtype=np.float32),
- b=jax.ShapedArray((3, 4), dtype=np.int16))
+ to_infeed_shape = dict(a=jax.core.ShapedArray((), dtype=np.float32),
+ b=jax.core.ShapedArray((3, 4), dtype=np.int16))
@jax.jit
def f(x):
token = lax.create_token(x)
@@ -77,7 +77,7 @@ class InfeedTest(jtu.JaxTestCase):
def f(x):
token = lax.create_token(x)
y, token = lax.infeed(
- token, shape=jax.ShapedArray((3, 4), jnp.float32))
+ token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
token = lax.outfeed(token, y + np.float32(1))
return x - 1
@@ -97,7 +97,7 @@ class InfeedTest(jtu.JaxTestCase):
def doubler(_, token):
y, token = lax.infeed(
- token, shape=jax.ShapedArray((3, 4), jnp.float32))
+ token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
return lax.outfeed(token, y * np.float32(2))
@jax.jit

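For reference, lax.infeed declares its expected layout with abstract values, which is why these tests construct ShapedArrays. A minimal hedged sketch of the pattern (feeding the actual data from the host side is omitted here):

    import jax
    import jax.numpy as jnp
    from jax import lax

    @jax.jit
    def f(x):
      token = lax.create_token(x)
      # The expected infeed shape/dtype is declared with an abstract
      # ShapedArray, now spelled via jax.core.
      (y,), token = lax.infeed(
          token, shape=(jax.core.ShapedArray((3, 4), jnp.float32),))
      return x + jnp.sum(y)
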
View File

@@ -538,7 +538,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
f = pjit.pjit(lambda x, y: (x, y),
in_axis_resources=experimental.PartitionSpec("x", "y"),
out_axis_resources=experimental.PartitionSpec("x", "y"))
- inp_aval = jax.ShapedArray((8, 2), jnp.int32)
+ inp_aval = jax.core.ShapedArray((8, 2), jnp.int32)
# `ShapedArray` is considered global when lowered and compiled.
# Hence it can bypass the contiguous mesh restriction.
compiled = f.lower(inp_aval, gda1).compile()

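The same ahead-of-time pattern recurs throughout the pjit tests below: lowering against an abstract jax.core.ShapedArray instead of concrete data. A hedged standalone sketch (the single-device mesh and the trivial computation are assumptions for illustration):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.pjit import pjit
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
    with mesh:
      f = pjit(lambda x: x + 1,
               in_axis_resources=P('x', 'y'),
               out_axis_resources=P('x', 'y'))
      # No concrete array is needed to lower and compile.
      aval = jax.core.ShapedArray((8, 2), jnp.int32)
      compiled = f.lower(aval).compile()
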
View File

@@ -710,11 +710,11 @@ class PJitTest(jtu.BufferDonationTestCase):
def f_for_jit(x):
token = lax.create_token(x)
(y,), token = lax.infeed(
- token, shape=(jax.ShapedArray(x.shape, np.float32),))
+ token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
(z,), token = lax.infeed(
- token, shape=(jax.ShapedArray(x.shape, np.float32),))
+ token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
(w,), token = lax.infeed(
- token, shape=(jax.ShapedArray(x.shape, np.float32),))
+ token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
return x + y + z + w
@@ -744,17 +744,17 @@ class PJitTest(jtu.BufferDonationTestCase):
# A replicated infeed
(y,), token = lax.infeed(
token,
- shape=(jax.ShapedArray(x.shape, np.float32),),
+ shape=(jax.core.ShapedArray(x.shape, np.float32),),
partitions=(None,))
# An infeed sharded on first axis
(z,), token = lax.infeed(
token,
- shape=(jax.ShapedArray(x.shape, np.float32),),
+ shape=(jax.core.ShapedArray(x.shape, np.float32),),
partitions=(P(nr_devices, 1),))
# An infeed sharded on second axis
(w,), token = lax.infeed(
token,
- shape=(jax.ShapedArray(x.shape, np.float32),),
+ shape=(jax.core.ShapedArray(x.shape, np.float32),),
partitions=(P(1, nr_devices),))
return x + y + z + w
@@ -838,7 +838,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertEqual(lowered.in_avals, compiled.in_avals)
self.assertEqual(
lowered.in_avals,
- ((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
+ ((jax.core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
splits = np.split(expected, 4)
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
@@ -1027,7 +1027,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
- aval = jax.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
+ aval = jax.core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
x = jnp.arange(np.prod(shape)).reshape(shape)
exe = f.lower(aval, x).compile()
self.assertIsInstance(exe, stages.Compiled)
@@ -1473,7 +1473,7 @@ class GDAPjitTest(jtu.JaxTestCase):
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
- compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
+ compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
with self.assertRaisesRegex(
ValueError, "GDA sharding does not match the input sharding."):
compiled(input_gda)
@@ -1485,7 +1485,7 @@ class GDAPjitTest(jtu.JaxTestCase):
g1, _ = create_gda(global_input_shape, global_mesh, P(None,))
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=P(None), out_axis_resources=P('x'))
- compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
+ compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
compiled(g1) # no error
@parallel_functions_output_gda(True)
@@ -1541,7 +1541,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, in_axis_resources=AUTO,
out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1568,7 +1568,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, in_axis_resources=AUTO,
out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1592,7 +1592,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
with ctx(True):
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=AUTO, out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
different_pspec = (P('y', 'x')
@@ -1616,7 +1616,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
with global_mesh:
f = pjit(lambda x, y, z: (x, y, z), in_axis_resources=AUTO,
out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp, inp).compile()
self.assertLen(compiled.output_shardings, 3)
self.assertLen(compiled.input_shardings[0], 3)
@@ -1644,7 +1644,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp).compile()
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1674,7 +1674,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1699,7 +1699,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, in_axis_resources=AUTO,
out_axis_resources=AUTO)
- inp = jax.ShapedArray(input_data.shape, input_data.dtype)
+ inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1951,7 +1951,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y'))
a2, _ = create_array(global_input_shape, global_mesh, P('x'))
- aval = jax.ShapedArray(global_input_shape, np.float32)
+ aval = jax.core.ShapedArray(global_input_shape, np.float32)
with jax_array(True):
with global_mesh:
@@ -2075,7 +2075,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jax_array(True):
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=NamedSharding(global_mesh, P(None,)))
- compiled = f.lower(jax.ShapedArray(input_shape, jnp.float32)).compile()
+ compiled = f.lower(jax.core.ShapedArray(input_shape, jnp.float32)).compile()
compiled(a1) # no error
@jax_array(True)
@@ -2201,7 +2201,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
di_map = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[di_map[d]], d)
for d in jax.local_devices()]
- arr = array.ArrayImpl(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)
+ arr = array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
f = pjit(lambda x: x, out_axis_resources=s)
out = f(arr)
@@ -2302,7 +2302,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
arr = array.ArrayImpl(
- jax.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
+ jax.core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
with self.assertRaisesRegex(
NotImplementedError,
@@ -2935,7 +2935,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(cache_info3.hits, cache_info2.hits)
# AOT test
- compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
+ compiled = f.lower(jax.core.ShapedArray(y.shape, y.dtype)).compile()
out3 = compiled(y)
_check(out3, jax.devices()[1], y)
@@ -2965,7 +2965,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
g_out = g(x)
_check(g_out, jax.devices()[0], x)
- compiled = g.lower(jax.ShapedArray(x.shape, x.dtype)).compile()
+ compiled = g.lower(jax.core.ShapedArray(x.shape, x.dtype)).compile()
out4 = compiled(x)
_check(out4, jax.devices()[0], x)
@@ -3570,7 +3570,7 @@ class UtilTest(jtu.JaxTestCase):
def test_mesh_sharding_spec(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
array_mapping = pxla._get_array_mapping(P('x', 'y'))
- aval = jax.ShapedArray((1, 1), jnp.int32)
+ aval = jax.core.ShapedArray((1, 1), jnp.int32)
with self.assertRaisesRegex(
ValueError,
'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '

View File

@@ -204,7 +204,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
- self.assertEqual(obj.in_avals, ((jax.ShapedArray(x.shape, x.dtype),), {}))
+ self.assertEqual(obj.in_avals, ((jax.core.ShapedArray(x.shape, x.dtype),), {}))
def testLowerCompileInTreeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')