Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Commit fbb587232c (parent f348366041)

Rename Piles to Jumbles, to avoid unfortunate Imperial entanglements.
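For orientation, the user-facing effect of this commit is purely a change of names: what the tests below used to call batching.Pile / batching.pile_axis is now batching.Jumble / batching.jumble_axis. A minimal sketch, adapted from test_jumble_escapes in this diff; it assumes jax_dynamic_shapes=True is enabled (as the test class configures) and that the internal modules are importable under the paths shown:

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax._src.interpreters import batching

ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
# Each mapped call returns a differently sized iota; collecting the results
# with out_axes=batching.jumble_axis yields a Jumble (formerly a Pile).
xs = jax.vmap(jax.jit(lambda n: lax.iota('int32', n)),
              out_axes=batching.jumble_axis)(ins)
assert isinstance(xs, batching.Jumble)  # was batching.Pile / batching.pile_axis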
@@ -42,11 +42,11 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

# Piles
# Jumbles

# i:(Fin 3) => f32[[3, 1, 4].i]
@dataclasses.dataclass(frozen=True)
class PileTy:
class JumbleTy:
binder: core.Var
length: Union[int, Tracer, core.Var]
elt_ty: core.DShapedArray

@@ -63,41 +63,41 @@ class IndexedAxisSize:
return f'{str(self.lengths)}.Var{id(self.idx)}'
replace = dataclasses.replace

# Pile(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
# Jumble(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Pile:
aval: PileTy
class Jumble:
aval: JumbleTy
data: Array

# To vmap over a pile, one must specify the axis as PileAxis.
class PileAxis: pass
pile_axis = PileAxis()
# To vmap over a jumble, one must specify the axis as JumbleAxis.
class JumbleAxis: pass
jumble_axis = JumbleAxis()

# As a temporary measure before we have more general JITable / ADable interfaces
# (analogues to vmappable), to enable Piles to be used with other
# (analogues to vmappable), to enable Jumbles to be used with other
# transformations and higher-order primitives (primarily jit, though also grad
# with allow_int=True) we register them as pytrees.
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
def _pile_flatten(pile):
def _jumble_flatten(jumble):
lengths = []
new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
if type(d) is IndexedAxisSize else d
for d in pile.aval.elt_ty.shape]
elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
aval = pile.aval.replace(elt_ty=elt_ty)
return (lengths, pile.data), aval
def _pile_unflatten(aval, x):
for d in jumble.aval.elt_ty.shape]
elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
aval = jumble.aval.replace(elt_ty=elt_ty)
return (lengths, jumble.data), aval
def _jumble_unflatten(aval, x):
lengths, data = x
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
if type(d) is IndexedAxisSize else d
for d in aval.elt_ty.shape]
elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
aval = aval.replace(elt_ty=elt_ty)
return Pile(aval, data)
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
return Jumble(aval, data)
register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten)

def _pile_result(axis_size, stacked_axis, ragged_axes, x):
def _jumble_result(axis_size, stacked_axis, ragged_axes, x):
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
if stacked_axis != 0:
raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0
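The pytree registration above is what lets a Jumble flow through jit and other tree-aware machinery. A rough sketch of the round trip it enables, continuing with the xs Jumble from the earlier sketch; the exact leaf layout is an assumption read off _jumble_flatten above:

leaves, treedef = jax.tree_util.tree_flatten(xs)
# Per _jumble_flatten, the leaves are the ragged-axis length arrays followed by
# the padded data array; the aval rides along in the treedef as auxiliary data.
xs_roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)
assert isinstance(xs_roundtrip, batching.Jumble)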
@@ -106,14 +106,14 @@ def _pile_result(axis_size, stacked_axis, ragged_axes, x):
for ragged_axis, segment_lens in ragged_axes:
shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
return Pile(PileTy(binder, axis_size, elt_ty), x)
return Jumble(JumbleTy(binder, axis_size, elt_ty), x)

@dataclasses.dataclass(frozen=True)
class RaggedAxis:
stacked_axis: int
# For each axis, we store its index and the corresponding segment lengths.
# For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: tuple[tuple[int, Array], ...]
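To make that comment concrete, a hypothetical batch-dimension descriptor for the example jumble could be built like this (lens1 and lens2 are made-up length arrays, and batching.RaggedAxis refers to the dataclass defined just above, assuming batching is jax._src.interpreters.batching):

lens1 = jnp.array([3, 1, 4])  # ragged sizes of axis 1, one per element
lens2 = jnp.array([2, 5, 2])  # ragged sizes of axis 3, one per element
bdim = batching.RaggedAxis(stacked_axis=0,
                           ragged_axes=((1, lens1), (3, lens2)))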
@@ -234,9 +234,9 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
handler = to_elt_handlers.get(type(x))
if handler:
return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
elif type(x) is Pile:
if spec is not pile_axis:
raise TypeError("pile input without using pile_axis in_axes spec")
elif type(x) is Jumble:
if spec is not jumble_axis:
raise TypeError("jumble input without using jumble_axis in_axes spec")
ias: IndexedAxisSize # Not present in the AxisSize union in core.py
(d, ias), = ((i, sz) # type: ignore
for i, sz in enumerate(x.aval.elt_ty.shape)

@@ -259,10 +259,10 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
x_ = trace.full_raise(x)
val, bdim = x_.val, x_.batch_dim
if type(bdim) is RaggedAxis:
if spec is not pile_axis:
if spec is not jumble_axis:
# TODO(mattjj): improve this error message
raise TypeError("ragged output without using pile_axis out_axes spec")
return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
raise TypeError("ragged output without using jumble_axis out_axes spec")
return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: dict[type, FromEltHandler] = {}
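The two renamed TypeErrors above are the enforcement points: a Jumble argument must be mapped with jumble_axis in in_axes, and a ragged result must be collected with jumble_axis in out_axes. A small sketch of the happy path, reusing the xs Jumble from the first sketch and passing axis_size explicitly, as the jumble_map helper in the tests below does:

doubled = jax.vmap(lambda x: x * 2,
                   in_axes=batching.jumble_axis,
                   out_axes=batching.jumble_axis,
                   axis_size=3)(xs)
assert isinstance(doubled, batching.Jumble)
# Passing in_axes=0 instead would be expected to hit the check above:
#   TypeError: jumble input without using jumble_axis in_axes spec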
@@ -284,7 +284,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
from_elt_handlers[data_type] = from_elt
if make_iota: make_iota_handlers[axis_size_type] = make_iota
vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {PileAxis}
spec_types: set[type] = {JumbleAxis}

def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)

@@ -295,7 +295,7 @@ def unregister_vmappable(data_type: type) -> None:
del make_iota_handlers[axis_size_type]

def is_vmappable(x: Any) -> bool:
return type(x) is Pile or type(x) in vmappables
return type(x) is Jumble or type(x) in vmappables

@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):

@@ -1089,12 +1089,12 @@ def broadcast(x, sz, axis):
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
if dst == pile_axis:
if dst == jumble_axis:
x = bdim_at_front(x, src, sz)
elt_ty = x.aval.update(shape=x.shape[1:])
aval = PileTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
x.shape[0], elt_ty)
return Pile(aval, x)
aval = JumbleTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
x.shape[0], elt_ty)
return Jumble(aval, x)
try:
_ = core.get_aval(x)
except TypeError as e:
@@ -29,9 +29,9 @@ from jax._src.interpreters.batching import (
MakeIotaHandler as MakeIotaHandler,
MapSpec as MapSpec,
NotMapped as NotMapped,
Pile as Pile,
PileAxis as PileAxis,
PileTy as PileTy,
Jumble as Jumble,
JumbleAxis as JumbleAxis,
JumbleTy as JumbleTy,
ToEltHandler as ToEltHandler,
Vmappable as Vmappable,
Zero as Zero,

@@ -60,7 +60,7 @@ from jax._src.interpreters.batching import (
matchaxis as matchaxis,
moveaxis as moveaxis,
not_mapped as not_mapped,
pile_axis as pile_axis,
jumble_axis as jumble_axis,
primitive_batchers as primitive_batchers,
reducer_batcher as reducer_batcher,
register_vmappable as register_vmappable,
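Assuming this second file is the public alias module that re-exports from jax._src.interpreters.batching (the hunks show only the import list, not the file path), downstream code would pick up the renamed symbols from the same place it found the old ones:

# formerly: from jax.interpreters.batching import Pile, PileAxis, PileTy, pile_axis
from jax.interpreters.batching import Jumble, JumbleAxis, JumbleTy, jumble_axis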
@@ -1490,68 +1490,69 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):

@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
jax_traceback_filtering='off')
class PileTest(jtu.JaxTestCase):
class JumbleTest(jtu.JaxTestCase):

@parameterized.parameters((True,), (False,))
def test_internal_pile(self, disable_jit):
def test_internal_jumble(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins)
self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)

def test_pile_escapes(self):
def test_jumble_escapes(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
xs = jax.vmap(jax.jit(lambda n: jax.lax.iota('int32', n)),
out_axes=batching.pile_axis)(ins)
self.assertIsInstance(xs, batching.Pile)
out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(xs, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(xs.data, data, check_dtypes=False)

def test_make_pile_from_dynamic_shape(self):
# We may not want to support returning piles from vmapped functions (instead
# preferring to have a separate API which allows piles). But for now it
# makes for a convenient way to construct piles for the other tests!
def test_make_jumble_from_dynamic_shape(self):
# We may not want to support returning jumbles from vmapped functions
# (instead preferring to have a separate API which allows jumbles). But for
# now it makes for a convenient way to construct jumbles for the other
# tests!
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data, check_dtypes=False)

def test_pile_map_eltwise(self):
def test_jumble_map_eltwise(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
p = pile_map(jax.jit(lambda x: x ** 2))(p)
self.assertIsInstance(p, batching.Pile)
out_axes=batching.jumble_axis)(ins)
p = jumble_map(jax.jit(lambda x: x ** 2))(p)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
self.assertAllClose(p.data, data, check_dtypes=False)

def test_pile_map_vector_dot(self):
def test_jumble_map_vector_dot(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
y = pile_map(jnp.dot)(p, p)
self.assertIsInstance(y, batching.Pile)
out_axes=batching.jumble_axis)(ins)
y = jumble_map(jnp.dot)(p, p)
self.assertIsInstance(y, batching.Jumble)
self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))

@parameterized.parameters((True,), (False,))
def test_pile_map_matrix_dot_ragged_contract(self, disable_jit):
def test_jumble_map_matrix_dot_ragged_contract(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.jumble_axis
)(sizes)
p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.pile_axis
p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.jumble_axis
)(sizes)
y = jax.vmap(jnp.dot, in_axes=batching.pile_axis, out_axes=0,
y = jax.vmap(jnp.dot, in_axes=batching.jumble_axis, out_axes=0,
axis_size=3)(p1, p2)
self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
check_dtypes=False)
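A quick sanity check of the expected value in the ragged-contract test above: each per-element product contracts the ragged dimension, so ones((7, n)) @ ones((n, 7)) equals n * ones((7, 7)) for n in (3, 1, 4), which is exactly what the np.tile expression builds:

import numpy as np

for n in (3, 1, 4):
  assert (np.ones((7, n)) @ np.ones((n, 7)) == n).all()
expected = np.tile(np.array([3, 1, 4])[:, None, None], (7, 7))
assert expected.shape == (3, 7, 7)
assert (expected[0] == 3).all() and (expected[1] == 1).all() and (expected[2] == 4).all()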
@parameterized.parameters((True,), (False,))
def test_pile_map_matrix_dot_ragged_tensor(self, disable_jit):
def test_jumble_map_matrix_dot_ragged_tensor(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
@@ -1559,8 +1560,8 @@ class PileTest(jtu.JaxTestCase):
lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,))
rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1
return jnp.dot(lhs_two_d, rhs)
p = jax.vmap(func, out_axes=batching.pile_axis)(sizes)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(sizes)
self.assertIsInstance(p, batching.Jumble)
self.assertEqual(p.data.shape, (3, 5, 4))

def test_broadcast_in_dim_while_ragged(self):

@@ -1569,8 +1570,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size, 7), (0,))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
self.assertAllClose(p.data, data)
@@ -1580,8 +1581,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(12, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size, 12), (1,))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2)
self.assertAllClose(p.data, data)

@@ -1595,7 +1596,7 @@ class PileTest(jtu.JaxTestCase):
return two_d
msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)"
with self.assertRaisesRegex(TypeError, msg):
jax.vmap(func, out_axes=batching.pile_axis)(ins)
jax.vmap(func, out_axes=batching.jumble_axis)(ins)

def test_broadcast_in_dim_to_doubly_ragged(self):
ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
@@ -1604,8 +1605,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size1, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size1, size2), (0,))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins1, ins2)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins1, ins2)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1)
self.assertAllClose(p.data, data)

@@ -1616,8 +1617,8 @@ class PileTest(jtu.JaxTestCase):
two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,))
one_again = jax.lax.squeeze(two_d, dimensions=[1])
return one_again
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data)

@@ -1627,8 +1628,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (4, size))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 4, 5), 2)
self.assertAllClose(p.data, data)
@@ -1638,8 +1639,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (size, size))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2)
self.assertAllClose(p.data, data)

@@ -1649,8 +1650,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (7, size))
return jnp.transpose(two_d, [1, 0])
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
self.assertAllClose(p.data, data)

@@ -1662,8 +1663,8 @@ class PileTest(jtu.JaxTestCase):
wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1)
qkv = jnp.einsum('te,ihqe->ithq', x, wqkv)
return qkv
p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(x_sizes)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(x_sizes)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))
@@ -1677,8 +1678,8 @@ class PileTest(jtu.JaxTestCase):
v = jax.lax.broadcast_in_dim(one_d, (ragged_size, 2, 7), [0])
inner = jnp.einsum('tsh,shq->thq', alpha, v)
return inner
p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(ragged_sizes)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(ragged_sizes)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 5, 2, 7))

@@ -1689,14 +1690,14 @@ class PileTest(jtu.JaxTestCase):
two_d = jnp.broadcast_to(one_d, (2, size))
part_1, part_2 = two_d
return part_1
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data)

@parameterized.parameters((True,), (False,))
def test_pile_map_end_to_end_fprop_layer(self, disable_jit):
def test_jumble_map_end_to_end_fprop_layer(self, disable_jit):
config.update('jax_disable_jit', disable_jit)

def fprop_layer(params, x):
@@ -1731,13 +1732,12 @@ class PileTest(jtu.JaxTestCase):
jnp.zeros((420, 128)),
]

def pile_stack(xs: list[jax.Array]) -> batching.Pile:
def jumble_stack(xs: list[jax.Array]) -> batching.Jumble:
max_length = max(len(x) for x in xs)
lengths = jnp.array([len(x) for x in xs])
lengths = jax.lax.convert_element_type(lengths, core.bint(max_length))
xs_padded = jnp.stack([jnp.zeros((max_length, 128), dtype=x.dtype
).at[:x.shape[0]].set(x) for x in xs])
# jax.vmap(lambda l, xp: xp[:l, :], out_axes=pile_axis)(lengths, xs_padded)

# binder = i
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))

@@ -1745,26 +1745,26 @@ class PileTest(jtu.JaxTestCase):
elt_ty = core.DShapedArray((batching.IndexedAxisSize(binder, lengths), 128),
xs_padded.dtype)
# aval = i:(Fin 3) => f32[[3, 1, 4].i, 128]
aval = batching.PileTy(binder, len(xs), elt_ty)
xs_pile = batching.Pile(aval, xs_padded)
return xs_pile
aval = batching.JumbleTy(binder, len(xs), elt_ty)
xs_jumble = batching.Jumble(aval, xs_padded)
return xs_jumble

xs_pile = pile_stack(xs)
xs_jumble = jumble_stack(xs)

fprop_batched = jax.vmap(fprop_layer,
in_axes=(None, batching.pile_axis),
out_axes=batching.pile_axis,
in_axes=(None, batching.jumble_axis),
out_axes=batching.jumble_axis,
axis_size=3)
result_jumble = fprop_batched(params, xs_jumble)
self.assertIsInstance(result_jumble, batching.Jumble)
regex = r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]'
self.assertRegex(str(result_jumble.aval), regex)
self.assertAllClose(result_jumble.data.shape, (3, 512, 128))

result_pile = fprop_batched(params, xs_pile)
self.assertIsInstance(result_pile, batching.Pile)
self.assertRegex(str(result_pile.aval), r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]')
self.assertAllClose(result_pile.data.shape, (3, 512, 128))

def pile_map(f):
def mapped(*piles):
return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
axis_size=piles[0].aval.length)(*piles)
def jumble_map(f):
def mapped(*jumbles):
return jax.vmap(f, in_axes=batching.jumble_axis, out_axes=batching.jumble_axis,
axis_size=jumbles[0].aval.length)(*jumbles)
return mapped

if __name__ == '__main__':
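For reference, the renamed jumble_map helper defined above wraps jax.vmap with jumble_axis on both the input and output side; test_jumble_map_vector_dot, for example, boils down to (names as in the test file, with partial from functools):

from functools import partial

p = jax.vmap(partial(jnp.arange, dtype='int32'),
             out_axes=batching.jumble_axis)(ins)  # elements of length 3, 1, 4
y = jumble_map(jnp.dot)(p, p)                     # per-element ragged dot product
assert isinstance(y, batching.Jumble)
# y.data == jnp.array([5, 0, 14], dtype='int32')  # 0+1+4, 0, 0+1+4+9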