Implement lax.map

Fixes GH-1113
commit 8c628a267b
parent a26963fe87
Author: Stephan Hoyer
Date: 2019-08-05 12:13:07 -07:00
3 changed files with 60 additions and 21 deletions
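Most of the diff below is mechanical: the module-level alias ``map = safe_map`` in the control-flow module is renamed to ``_map`` so that the new public ``lax.map`` added at the end of the file cannot collide with it. A minimal sketch of the name shadowing the rename avoids (illustrative only, not code from this commit):

    from jax.util import safe_map

    map = safe_map   # old alias: every internal call site uses the name `map`

    def map(f, xs):  # new public API rebinds `map`, so those internal call
      ...            # sites would now hit this function instead of safe_map;
                     # renaming the alias to `_map` removes the ambiguity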

docs/jax.lax.rst

@@ -133,6 +133,7 @@ Control flow operators
    cond
    fori_loop
+   map
    scan
    while_loop

jax/lax/lax_control_flow.py

@@ -40,7 +40,7 @@ from jax.util import partial, unzip2, safe_map, safe_zip
from jax.tree_util import build_tree, tree_unflatten, tree_map
from jax import ad_util

-map = safe_map
+_map = safe_map
zip = safe_zip
@@ -220,7 +220,7 @@ def _while_loop_batching_rule(batched_args, batch_dims,
  init_val, cond_consts, body_consts = batched_args
  init_val_bd, cond_consts_bd, body_consts_bd = batch_dims
-  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
+  sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes
@@ -253,7 +253,7 @@ def _while_loop_batching_rule(batched_args, batch_dims,
def _jaxtupletree_select(pred, on_true, on_false):
  aval = core.get_aval(on_true)
  if type(aval) is core.AbstractTuple:
-    return core.pack(map(partial(_jaxtupletree_select, pred), on_true, on_false))
+    return core.pack(_map(partial(_jaxtupletree_select, pred), on_true, on_false))
  elif isinstance(aval, UnshapedArray):
    return lax.select(pred, on_true, on_false)
  else:
@@ -402,7 +402,7 @@ xla.initial_style_translations[cond_p] = _cond_translation_rule
def _maybe_tracer_tuple_to_abstract_tuple(tup):
  if isinstance(tup, pe.JaxprTracerTuple):
-    return core.AbstractTuple(list(map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
+    return core.AbstractTuple(list(_map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
  elif isinstance(tup, core.AbstractValue):
    return tup
  elif tup is None:
@@ -424,10 +424,10 @@ def _convert_zeros(instantiate, example, tangent):
      raise TypeError(tangent) # not clear if ever reachable
  elif t is tuple:
    if type(tangent) is ad.TangentTuple:
-      return core.pack(map(_convert_zeros, instantiate, example, tangent))
+      return core.pack(_map(_convert_zeros, instantiate, example, tangent))
    elif tangent is ad_util.zero:
      zeros = [ad_util.zero] * len(instantiate)
-      return core.pack(map(_convert_zeros, instantiate, example, zeros))
+      return core.pack(_map(_convert_zeros, instantiate, example, zeros))
    else:
      raise TypeError(tangent)
  else:
@@ -436,20 +436,20 @@ def _convert_zeros(instantiate, example, tangent):
def _demote_aval_rank(xs):
  assert isinstance(xs, core.AbstractValue)
  if isinstance(xs, core.AbstractTuple):
-    return core.AbstractTuple(map(_demote_aval_rank, xs))
+    return core.AbstractTuple(_map(_demote_aval_rank, xs))
  else:
    return ShapedArray(xs.shape[1:], xs.dtype)

def _promote_aval_rank(n, xs):
  assert isinstance(xs, core.AbstractValue)
  if isinstance(xs, core.AbstractTuple):
-    return core.AbstractTuple(map(partial(_promote_aval_rank, n), xs))
+    return core.AbstractTuple(_map(partial(_promote_aval_rank, n), xs))
  else:
    return ShapedArray((n,) + xs.shape, xs.dtype)

def _leading_dim_size(xs):
  if isinstance(xs, core.JaxTuple):
-    sizes = set(map(_leading_dim_size, xs))
+    sizes = set(_map(_leading_dim_size, xs))
    if len(sizes) == 1:
      return sizes.pop()
    elif len(sizes) > 1:
@@ -468,21 +468,21 @@ def _leading_dim_size(xs):
def _empty_arrays(aval):
  assert isinstance(aval, core.AbstractValue)
  if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(_empty_arrays, aval))
+    return core.pack(_map(_empty_arrays, aval))
  else:
    return lax.full(aval.shape, 0, aval.dtype)

def _index_arrays(i, aval, xs):
  assert isinstance(aval, core.AbstractValue)
  if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(partial(_index_arrays, i), aval, xs))
+    return core.pack(_map(partial(_index_arrays, i), aval, xs))
  else:
    return lax.dynamic_index_in_dim(xs, i, keepdims=False)

def _update_arrays(i, aval, xs, x):
  assert isinstance(aval, core.AbstractValue)
  if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(partial(_update_arrays, i), aval, xs, x))
+    return core.pack(_map(partial(_update_arrays, i), aval, xs, x))
  else:
    x = lax.reshape(x, (1,) + onp.shape(x))
    return lax.dynamic_update_index_in_dim(xs, x, i, axis=0)
@@ -543,7 +543,7 @@ def scan(f, init, xs):
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
-  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
+  (init, xs), in_trees = unzip2(_map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
  carry_pval = carry_aval, _ = _abstractify(init)
  xs_aval, _ = _abstractify(xs)
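For readers skimming the hunk above: ``scan`` returns a pair ``(c, [b])``, the final carry plus the stacked per-step outputs. A minimal cumulative-sum sketch of that contract (illustrative only, not code from this commit):

    import jax.numpy as np
    from jax import lax

    def cumsum(xs):
      def step(carry, x):
        total = carry + x
        return total, total  # (new carry c, per-step output b)
      final, ys = lax.scan(step, 0., xs)
      return ys  # stacked outputs, one per element of xs

    cumsum(np.arange(4.0))  # -> [0., 1., 3., 6.]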
@@ -639,11 +639,11 @@ def _binary_lattice_fold(f, pack, a, b):
  recur = partial(_binary_lattice_fold, f, pack)
  t = (type(a), type(b))
  if t == (tuple, tuple):
-    return pack(map(recur, a, b))
+    return pack(_map(recur, a, b))
  elif t == (tuple, bool):
-    return pack(map(recur, a, (b,) * len(a)))
+    return pack(_map(recur, a, (b,) * len(a)))
  elif t == (bool, tuple):
-    return pack(map(recur, (a,) * len(b), b))
+    return pack(_map(recur, (a,) * len(b), b))
  elif t == (bool, bool):
    return f(a, b)
  else:
@@ -659,7 +659,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
  forward = kwargs.pop('forward')
  assert not kwargs
  in_pvs, _ = unzip2([t.pval for t in tracers])
-  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)
+  sc_consts, sc_init, sc_xs = _map(pe.unknown, in_pvs)
  sc_carry = sc_init
  for i in range(1000):
@@ -702,8 +702,8 @@ def _lift_tracer(trace, tracer, is_unknown):
    else:
      return tracer
  elif t is tuple:
-    tracers = map(trace.full_raise, tracer)
-    return core.pack(map(partial(_lift_tracer, trace), tracers, is_unknown))
+    tracers = _map(trace.full_raise, tracer)
+    return core.pack(_map(partial(_lift_tracer, trace), tracers, is_unknown))
  else:
    raise TypeError(t)
@@ -713,7 +713,7 @@ def _put_known_pvs(is_unknown, aval):
  elif is_unknown is True:
    return aval
  else:
-    return pe.JaxprTracerTuple(map(_put_known_pvs, is_unknown, aval))
+    return pe.JaxprTracerTuple(_map(_put_known_pvs, is_unknown, aval))

def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
@@ -836,7 +836,7 @@ def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
  consts, init, xs = batched_args
  consts_bdim, init_bdim, xs_bdim = batch_dims
-  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
+  sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes
@@ -889,3 +889,34 @@ ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = xla.lower_fun(_scan_impl, initial_style=True)
batching.primitive_batchers[scan_p] = _scan_batching_rule
+
+
+def map(f, xs):
+  """Map a function over leading array axes.
+
+  Like Python's builtin map, except inputs and outputs are in the form of
+  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
+  need to apply a function element by element for reduced memory usage or
+  heterogeneous computation with other control flow primitives.
+
+  When ``xs`` is an array type, the semantics of ``map`` are given by this
+  Python implementation::
+
+    def map(f, xs):
+      return np.stack([f(x) for x in xs])
+
+  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
+  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
+  nested pytree type, and the mapped computation is compiled only once.
+
+  Args:
+    f: a Python function to apply element-wise over the first axis or axes of
+      ``xs``.
+    xs: values over which to map along the leading axis.
+
+  Returns:
+    Mapped values.
+  """
+  g = lambda _, x: ((), f(x))
+  _, ys = scan(g, (), xs)
+  return ys
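A quick usage sketch for the new function (illustrative, not part of the diff). Because ``map`` is built on ``scan``, the per-element computation is traced and compiled once rather than once per element:

    import jax.numpy as np
    from jax import lax

    xs = np.arange(5.0)
    ys = lax.map(lambda x: x * x, xs)  # -> [0., 1., 4., 9., 16.]

For purely element-wise work like this, ``jax.vmap`` is generally the faster choice since it vectorizes rather than loops; ``lax.map`` trades throughput for lower peak memory, as the docstring notes.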

tests/lax_control_flow_test.py

@@ -816,6 +816,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.)
    api.pmap(f, axis_name="i")(np.ones((num_devices, 4)))  # doesn't crash

+  def testMap(self):
+    f = lambda x: x ** 2
+    xs = np.arange(10)
+    expected = xs ** 2
+    actual = lax.map(f, xs)
+    self.assertAllClose(actual, expected, check_dtypes=True)
+

if __name__ == '__main__':
  absltest.main()
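A natural follow-on check, sketched here rather than taken from the commit (the method name testMapMatchesReference is hypothetical), is to compare ``lax.map`` against the Python-loop reference semantics quoted in its docstring:

    # Hypothetical extra test, not in this commit: lax.map should agree with
    # the np.stack-over-a-Python-loop reference from the docstring.
    def testMapMatchesReference(self):
      f = lambda x: x ** 2
      xs = np.arange(10)
      expected = np.stack([f(x) for x in xs])
      actual = lax.map(f, xs)
      self.assertAllClose(actual, expected, check_dtypes=True)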