Implement lax.map
Fixes GH-1113
parent a26963fe87
commit 8c628a267b
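For orientation, a minimal usage sketch of the API this commit adds, based on the docstring and test below (the ``jnp`` alias and example values are illustrative, not part of the diff)::

  import jax.numpy as jnp
  from jax import lax

  # Apply f to each slice of xs along the leading axis; semantically like
  # jnp.stack([f(x) for x in xs]), but traced once and run as a single scan.
  xs = jnp.arange(10)
  ys = lax.map(lambda x: x ** 2, xs)  # same values as xs ** 2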
@@ -133,6 +133,7 @@ Control flow operators
 
    cond
    fori_loop
+   map
    scan
    while_loop
 
@@ -40,7 +40,7 @@ from jax.util import partial, unzip2, safe_map, safe_zip
 from jax.tree_util import build_tree, tree_unflatten, tree_map
 from jax import ad_util
 
-map = safe_map
+_map = safe_map
 zip = safe_zip
 
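The module-level ``map = safe_map`` alias is renamed to ``_map`` so it no longer collides with the public ``map`` combinator added at the end of this file; the remaining hunks update the internal call sites accordingly. A rough illustration of what the alias provides, assuming the documented behavior of ``jax.util.safe_map`` (not part of the diff)::

  from jax.util import safe_map as _map

  # safe_map behaves like Python's map but checks that all argument sequences
  # have the same length and returns a list rather than a lazy iterator.
  assert _map(lambda a, b: a + b, [1, 2], [10, 20]) == [11, 22]
  # _map(lambda a, b: a + b, [1, 2], [10, 20, 30])  # would fail the length check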
@@ -220,7 +220,7 @@ def _while_loop_batching_rule(batched_args, batch_dims,
   init_val, cond_consts, body_consts = batched_args
   init_val_bd, cond_consts_bd, body_consts_bd = batch_dims
 
-  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
+  sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
   size = sizes.pop()
   assert not sizes
 
@@ -253,7 +253,7 @@ def _while_loop_batching_rule(batched_args, batch_dims,
 def _jaxtupletree_select(pred, on_true, on_false):
   aval = core.get_aval(on_true)
   if type(aval) is core.AbstractTuple:
-    return core.pack(map(partial(_jaxtupletree_select, pred), on_true, on_false))
+    return core.pack(_map(partial(_jaxtupletree_select, pred), on_true, on_false))
   elif isinstance(aval, UnshapedArray):
     return lax.select(pred, on_true, on_false)
   else:
@@ -402,7 +402,7 @@ xla.initial_style_translations[cond_p] = _cond_translation_rule
 
 def _maybe_tracer_tuple_to_abstract_tuple(tup):
   if isinstance(tup, pe.JaxprTracerTuple):
-    return core.AbstractTuple(list(map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
+    return core.AbstractTuple(list(_map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
   elif isinstance(tup, core.AbstractValue):
     return tup
   elif tup is None:
@@ -424,10 +424,10 @@ def _convert_zeros(instantiate, example, tangent):
       raise TypeError(tangent)  # not clear if ever reachable
   elif t is tuple:
     if type(tangent) is ad.TangentTuple:
-      return core.pack(map(_convert_zeros, instantiate, example, tangent))
+      return core.pack(_map(_convert_zeros, instantiate, example, tangent))
     elif tangent is ad_util.zero:
       zeros = [ad_util.zero] * len(instantiate)
-      return core.pack(map(_convert_zeros, instantiate, example, zeros))
+      return core.pack(_map(_convert_zeros, instantiate, example, zeros))
     else:
       raise TypeError(tangent)
   else:
@@ -436,20 +436,20 @@ def _convert_zeros(instantiate, example, tangent):
 def _demote_aval_rank(xs):
   assert isinstance(xs, core.AbstractValue)
   if isinstance(xs, core.AbstractTuple):
-    return core.AbstractTuple(map(_demote_aval_rank, xs))
+    return core.AbstractTuple(_map(_demote_aval_rank, xs))
   else:
     return ShapedArray(xs.shape[1:], xs.dtype)
 
 def _promote_aval_rank(n, xs):
   assert isinstance(xs, core.AbstractValue)
   if isinstance(xs, core.AbstractTuple):
-    return core.AbstractTuple(map(partial(_promote_aval_rank, n), xs))
+    return core.AbstractTuple(_map(partial(_promote_aval_rank, n), xs))
   else:
     return ShapedArray((n,) + xs.shape, xs.dtype)
 
 def _leading_dim_size(xs):
   if isinstance(xs, core.JaxTuple):
-    sizes = set(map(_leading_dim_size, xs))
+    sizes = set(_map(_leading_dim_size, xs))
     if len(sizes) == 1:
       return sizes.pop()
     elif len(sizes) > 1:
@@ -468,21 +468,21 @@ def _leading_dim_size(xs):
 def _empty_arrays(aval):
   assert isinstance(aval, core.AbstractValue)
   if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(_empty_arrays, aval))
+    return core.pack(_map(_empty_arrays, aval))
   else:
     return lax.full(aval.shape, 0, aval.dtype)
 
 def _index_arrays(i, aval, xs):
   assert isinstance(aval, core.AbstractValue)
   if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(partial(_index_arrays, i), aval, xs))
+    return core.pack(_map(partial(_index_arrays, i), aval, xs))
   else:
     return lax.dynamic_index_in_dim(xs, i, keepdims=False)
 
 def _update_arrays(i, aval, xs, x):
   assert isinstance(aval, core.AbstractValue)
   if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(partial(_update_arrays, i), aval, xs, x))
+    return core.pack(_map(partial(_update_arrays, i), aval, xs, x))
   else:
     x = lax.reshape(x, (1,) + onp.shape(x))
     return lax.dynamic_update_index_in_dim(xs, x, i, axis=0)
@@ -543,7 +543,7 @@ def scan(f, init, xs):
     loop carry value and the second element represents the stacked outputs of
     the second output of ``f`` when scanned over the leading axis of the inputs.
   """
-  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
+  (init, xs), in_trees = unzip2(_map(pytree_to_jaxtupletree, (init, xs)))
   f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
   carry_pval = carry_aval, _ = _abstractify(init)
   xs_aval, _ = _abstractify(xs)
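The context lines above come from ``scan``'s docstring, which describes its return convention: the final loop carry plus the stacked per-step outputs. A minimal illustration of that convention, with example values of my own choosing::

  import jax.numpy as jnp
  from jax import lax

  def step(carry, x):
    total = carry + x
    return total, total  # (new carry, per-step output to be stacked)

  final, totals = lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
  # final == 10.0; totals == [1., 3., 6., 10.]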
@@ -639,11 +639,11 @@ def _binary_lattice_fold(f, pack, a, b):
   recur = partial(_binary_lattice_fold, f, pack)
   t = (type(a), type(b))
   if t == (tuple, tuple):
-    return pack(map(recur, a, b))
+    return pack(_map(recur, a, b))
   elif t == (tuple, bool):
-    return pack(map(recur, a, (b,) * len(a)))
+    return pack(_map(recur, a, (b,) * len(a)))
   elif t == (bool, tuple):
-    return pack(map(recur, (a,) * len(b), b))
+    return pack(_map(recur, (a,) * len(b), b))
   elif t == (bool, bool):
     return f(a, b)
   else:
@@ -659,7 +659,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
   forward = kwargs.pop('forward')
   assert not kwargs
   in_pvs, _ = unzip2([t.pval for t in tracers])
-  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)
+  sc_consts, sc_init, sc_xs = _map(pe.unknown, in_pvs)
 
   sc_carry = sc_init
   for i in range(1000):
@@ -702,8 +702,8 @@ def _lift_tracer(trace, tracer, is_unknown):
     else:
       return tracer
   elif t is tuple:
-    tracers = map(trace.full_raise, tracer)
-    return core.pack(map(partial(_lift_tracer, trace), tracers, is_unknown))
+    tracers = _map(trace.full_raise, tracer)
+    return core.pack(_map(partial(_lift_tracer, trace), tracers, is_unknown))
   else:
     raise TypeError(t)
 
@@ -713,7 +713,7 @@ def _put_known_pvs(is_unknown, aval):
   elif is_unknown is True:
     return aval
   else:
-    return pe.JaxprTracerTuple(map(_put_known_pvs, is_unknown, aval))
+    return pe.JaxprTracerTuple(_map(_put_known_pvs, is_unknown, aval))
 
 
 def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
|
||||
@ -836,7 +836,7 @@ def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
|
||||
consts, init, xs = batched_args
|
||||
consts_bdim, init_bdim, xs_bdim = batch_dims
|
||||
|
||||
sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
|
||||
sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
|
||||
size = sizes.pop()
|
||||
assert not sizes
|
||||
|
||||
@@ -889,3 +889,34 @@ ad.primitive_transposes[scan_p] = _scan_transpose
 pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
 xla.initial_style_translations[scan_p] = xla.lower_fun(_scan_impl, initial_style=True)
 batching.primitive_batchers[scan_p] = _scan_batching_rule
+
+
+def map(f, xs):
+  """Map a function over leading array axes.
+
+  Like Python's builtin map, except inputs and outputs are in the form of
+  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
+  need to apply a function element by element for reduced memory usage or
+  heterogeneous computation with other control flow primitives.
+
+  When ``xs`` is an array type, the semantics of ``map`` are given by this
+  Python implementation::
+
+    def map(f, xs):
+      return np.stack([f(x) for x in xs])
+
+  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
+  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
+  nested pytree type, and the mapped computation is compiled only once.
+
+  Args:
+    f: a Python function to apply element-wise over the first axis or axes of
+      ``xs``.
+    xs: values over which to map along the leading axis.
+
+  Returns:
+    Mapped values.
+  """
+  g = lambda _, x: ((), f(x))
+  _, ys = scan(g, (), xs)
+  return ys
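Because ``map`` is defined as a ``scan`` whose carry is empty, it should inherit ``scan``'s pytree handling. A sketch of mapping over a nested input, under that assumption (example values are illustrative, not part of the commit)::

  import jax.numpy as jnp
  from jax import lax

  # Each call of the mapped function sees one leading-axis slice of every leaf.
  xs = {'a': jnp.arange(5.0), 'b': jnp.ones((5, 2))}
  ys = lax.map(lambda d: d['a'] + d['b'].sum(), xs)
  assert ys.shape == (5,)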
@@ -816,6 +816,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
     api.pmap(f, axis_name="i")(np.ones((num_devices, 4)))  # doesn't crash
 
+  def testMap(self):
+    f = lambda x: x ** 2
+    xs = np.arange(10)
+    expected = xs ** 2
+    actual = lax.map(f, xs)
+    self.assertAllClose(actual, expected, check_dtypes=True)
+
 
 if __name__ == '__main__':
   absltest.main()
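The docstring recommends ``jax.vmap`` when vectorization is acceptable; a hedged extra check (not part of the committed test) that the two agree on the example used in ``testMap``::

  import jax.numpy as jnp
  from jax import lax, vmap

  f = lambda x: x ** 2
  xs = jnp.arange(10)
  # vmap vectorizes f across the axis in one batched computation, while
  # lax.map applies f slice by slice via scan; here the results agree.
  assert jnp.array_equal(lax.map(f, xs), vmap(f)(xs))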