partial progress

Matthew Johnson 2019-01-31 22:08:51 -08:00
parent 549a1f2e59
commit 08dc6994f5
10 changed files with 142 additions and 125 deletions

View File

@@ -266,11 +266,10 @@ def pjit(fun, axis_name, in_axes=0, out_axes=0, mesh_axis=0):
jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
check_args(jaxtupletree_args)
f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
chunksize = pxla.chunk_size(axis_name, mesh_axis, in_axes, jaxtupletree_args)
# transform spmd primitives in f to first act locally then cross-replica
in_axes_ = pxla.canonicalize_in_axis_spec(in_trees, in_axes)
out_axes_ = OutAxesThunk(out_tree, out_axes) # for pretty-printing
chunksize = pxla.chunk_size(axis_name, mesh_axis, in_axes_, jaxtupletree_args)
f = pxla.chunk_transform(f, chunksize, axis_name, in_axes_, out_axes_)
jaxtupletree_out = pxla.xla_pcall(f, *jaxtupletree_args,

View File

@@ -18,12 +18,13 @@ from __future__ import print_function
from . import partial_eval as pe
from . import xla
from . import pxla
from .. import core as core
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, zero, Zero)
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node
from ..tree_util import process_pytree, build_tree, register_pytree_node, prune
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
from six.moves import builtins, reduce
@@ -184,6 +185,13 @@ class JVPTrace(Trace):
tangents = [t.tangent for t in tracers]
nonzero_tangents, in_tree_def = tree_to_jaxtuples(tangents)
f, out_tree_def = traceable(jvp_subtrace(f, self.master), in_tree_def)
if call_primitive is pxla.xla_pcall_p:
in_axes, out_axes = params['in_axes'], params['out_axes']
jvp_in_axes = (in_axes, prune(in_tree_def, in_axes))
def jvp_out_axes():
_, tangent_out_tree = out_tree_def().children
return (out_axes(), prune(tangent_out_tree, out_axes()))
params = dict(params, in_axes=jvp_in_axes, out_axes=jvp_out_axes)
result = call_primitive.bind(f, pack(primals), nonzero_tangents, **params)
primal_out, tangent_out = build_tree(out_tree_def(), result)
return JVPTracer(self, primal_out, tangent_out)
@@ -386,13 +394,25 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
fun = wrap_init(backward_pass)
fun, out_tree_def = transposed_fun(fun, jaxpr, in_tree_def)
all_args = pack((pack(consts), pack(freevar_vals), ct))
# TODO(dougalm): consider signalling to bind that there are no traces in the closure
ans = primitive.bind(fun, all_args, **params)
if primitive is pxla.xla_pcall_p:
in_axes, out_axes = params['in_axes'], params['out_axes']()
ct_axes = prune(ct_tree, out_axes)
transpose_in_axes = ((None,) * len(consts), (None,) * len(freevar_vals), ct_axes),
def transpose_out_axes():
return prune(out_tree_def(), (in_axes, (None,) * len(freevar_vals)))
new_params = dict(params, in_axes=transpose_in_axes, out_axes=transpose_out_axes)
ans = primitive.bind(fun, all_args, **params)
import ipdb; ipdb.set_trace()
else:
# TODO(dougalm): consider signalling to bind that there are no traces in the closure
ans = primitive.bind(fun, all_args, **params)
return build_tree(out_tree_def(), ans)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.compiled_call_p] = partial(call_transpose, pe.compiled_call_p)
primitive_transposes[xla.xla_call_p] = partial(call_transpose, xla.xla_call_p)
primitive_transposes[pxla.xla_pcall_p] = partial(call_transpose, pxla.xla_pcall_p)
tree_to_jaxtuples = partial(process_pytree, pack)

View File

@@ -161,47 +161,7 @@ class PmapTrace(Trace):
return PmapTracer(self, name, vals, axis)
def unbound_name_error(primitive_name, *args, **kwargs):
axis_name = kwargs['axis_name']
msg = "axis name '{}' is unbound for primitive {}."
raise NameError(msg.format(axis_name, primitive_name))
def PmapPrimitive(name):
prim = Primitive(name)
prim.def_impl(partial(unbound_name_error, name))
prim.def_abstract_eval(lambda x, *args, **kwargs: x) # default
return prim
pmap_primitive_rules = {}
parallel_translation_rules = {}
def psum(x, axis_name):
return psum_p.bind(x, axis_name=axis_name)
def psum_pmap_rule(val, axis):
return val.sum(axis), None
def psum_parallel_translation_rule(c, val, device_groups):
if len(device_groups) > 1:
return c.CrossReplicaSum(val, device_groups)
else:
return c.CrossReplicaSum(val)
psum_p = PmapPrimitive('psum')
pmap_primitive_rules[psum_p] = psum_pmap_rule
parallel_translation_rules[psum_p] = psum_parallel_translation_rule
def gather(x, axis_name):
return gather_p.bind(x, axis_name=axis_name)
def gather_pmap_rule(val, axis):
return val, None
gather_p = PmapPrimitive('gather')
pmap_primitive_rules[gather_p] = gather_pmap_rule
### axis variable splitting and computation chunking

View File

@@ -37,7 +37,6 @@ from ..lib import xla_bridge as xb
from .xla import (xla_shape, xla_destructure, translation_rule, abstractify,
xla_shape_to_result_shape, jaxpr_computation)
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
from .parallel import parallel_translation_rules
from .batching import moveaxis
from . import parallel
from . import xla
@@ -75,7 +74,7 @@ def chunk_aval(chunksize, aval, axis):
shape[axis] = chunksize
return ShapedArray(tuple(shape), aval.dtype)
# TODO these next two functions became pretty trivial, maybe prune
# TODO these next two functions became pretty trivial, maybe prune them
def canonicalize_in_axis_spec(in_trees, spec_tree_prefix):
"""Given argument list in_trees, canonicalize and flatten an in_axes spec."""
in_tree = tree_util.PyTreeDef(tree_util.node_types[tuple], None, in_trees)
@@ -108,6 +107,16 @@ def flatten_axis_spec_tree(spec_tree):
spec_flat, _ = tree_util.tree_flatten(spec_tree)
return tuple(None if i is no_mapped_axis else i for i in spec_flat)
def tree_flatten_axes(maybe_tree, axes):
if type(maybe_tree) is core.JaxTuple:
if maybe_tree:
flat_children = map(tree_flatten_axes, maybe_tree, axes)
return it.chain.from_iterable(flat_children)
else:
return []
else:
return [axes]
# We use a special symbol for 'no mapped axis' instead of using None because
# tree_util.py treats None as a tree node.
class NoMappedAxis(object): pass
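For reference, a minimal standalone sketch of the tree_flatten_axes helper added above, using plain Python tuples in place of JaxTuple; the stand-in values are illustrative, not the jax internals. It broadcasts an axes spec against an argument's tuple structure, yielding one axis entry per flat leaf.

from itertools import chain

def tree_flatten_axes(maybe_tree, axes):
    # Tuples stand in for JaxTuple: recurse elementwise, pairing each child
    # with the corresponding entry of the axes spec; a leaf yields one entry.
    if type(maybe_tree) is tuple:
        if maybe_tree:
            flat_children = map(tree_flatten_axes, maybe_tree, axes)
            return list(chain.from_iterable(flat_children))
        return []
    return [axes]

print(tree_flatten_axes(1.0, 0))                              # [0]
print(tree_flatten_axes((1.0, (2.0, 3.0)), (0, (1, None))))   # [0, 1, None]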
@@ -127,7 +136,7 @@ def unshard_output(mesh_spec, mesh_axis, out_axis, out_shards):
"""Collect and concatenate sharded device results."""
_, ids = onp.unique(shard_assignments(mesh_spec, mesh_axis), return_index=True)
shards = [out_shards[i] for i in ids]
return onp.concatenate(shards, out_axis) # TODO device persistence
return onp.concatenate(shards, out_axis)
def shard_assignments(mesh_spec, mesh_axis):
"""Given a mesh axis long which to shard data, compute replica assignments."""
@@ -204,6 +213,10 @@ def device_mesh(spec):
yield
_mesh_spec = prev_spec
# axis environments are tiny, so we don't worry about the cost of copying keys
def new_axis_env(d): return d
def extend_axis_env(d1, d2): return dict(d1, **d2)
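A quick illustration of the copy-on-extend behavior described in the comment above, using the two helpers just defined; the axis environment contents here are hypothetical.

env = new_axis_env({'i': 4})            # hypothetical axis environment
inner = extend_axis_env(env, {'j': 2})
print(env)    # {'i': 4}  -- unchanged: extending copies instead of mutating
print(inner)  # {'i': 4, 'j': 2}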
### xla_pcall
@@ -281,12 +294,14 @@ def xla_pcall_impl(fun, *args, **params):
flat_args, in_trees = unzip2(map(xla.tree_flatten, args))
flat_args = concatenate(flat_args)
fun, out_tree = xla.flatten_fun(fun, in_trees)
import ipdb; ipdb.set_trace() # TODO canonicalize in_axes
in_axes = tuple(concatenate(map(tree_flatten_axes, args, in_axes)))
assert len(flat_args) == len(in_axes)
compiled_fun = xla_parallel_callable(fun, axis_name, in_axes, mesh_axis,
mesh_spec(), *map(abstractify, flat_args))
flat_ans = compiled_fun(out_axes(), *flat_args)
flat_ans = compiled_fun(out_tree(), out_axes(), *flat_args)
if out_tree() is xla.leaf:
return flat_ans
@@ -311,26 +326,18 @@ def xla_parallel_callable(fun, axis_name, in_axes, mesh_axis, mesh_spec,
return partial(execute_replicated, in_axes, mesh_axis, mesh_spec, compiled, pval)
def execute_replicated(in_axes, mesh_axis, mesh_spec, compiled, pval,
out_axes, *args):
out_tree, out_axes, *args):
input_bufs = map(partial(shard_arg, mesh_spec, mesh_axis), in_axes, args)
out_bufs = compiled.ExecutePerReplica(zip(*input_bufs))
out_shards = [merge_pvals(buf.to_py(), pval) for buf in out_bufs] # TODO
if type(out_axes) is int:
input_bufs = zip(*input_bufs) if input_bufs else [[]] * xb.get_replica_count()
out_bufs = compiled.ExecutePerReplica(input_bufs)
out_shards = [merge_pvals(buf.to_py(), pval) for buf in out_bufs]
if out_tree is xla.leaf:
return unshard_output(mesh_spec, mesh_axis, out_axes, out_shards)
else:
raise NotImplementedError
return map(partial(unshard_output, mesh_spec, mesh_axis), out_axes,
zip(*out_shards))
def tree_flatten_axes(maybe_tree, axes):
if type(maybe_tree) is core.JaxTuple:
if maybe_tree:
flat_children = map(tree_flatten_axes, maybe_tree, axes)
return it.chain.from_iterable(flat_children)
else:
return []
else:
return [axes]
xla_pcall_p = core.Primitive('xla_pcall')
xla_pcall = partial(core.call_bind, xla_pcall_p)
@@ -339,6 +346,4 @@ xla_pcall_p.def_impl(xla_pcall_impl)
xla.translations[xla_pcall_p] = xla.xla_call_translation_rule
# axis environments are tiny, so we don't worry about the cost of copying keys
def new_axis_env(d): return d
def extend_axis_env(d1, d2): return dict(d1, **d2)
parallel_translation_rules = {}

View File

@@ -376,14 +376,16 @@ def xla_shape(x):
def flatten_fun(in_trees, *flat_args):
jtuple_trees = tuple(map(partial(build_tree, iter(flat_args)), in_trees))
ans = yield jtuple_trees
if type(ans) is JaxTuple:
aval = core.get_aval(ans)
if type(aval) is AbstractTuple:
ans_flat, out_tree = tree_flatten(ans)
yield pack(ans_flat), out_tree
else:
yield ans, leaf
def tree_flatten(maybe_tree):
if type(maybe_tree) is JaxTuple:
aval = core.get_aval(maybe_tree)
if type(aval) is AbstractTuple:
flat_children, child_specs = unzip2(map(tree_flatten, maybe_tree))
return it.chain.from_iterable(flat_children), JTupleTreeDef(child_specs)
elif core.skip_checks or valid_jaxtype(maybe_tree):
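The change above switches the dispatch from the concrete type (type(ans) is JaxTuple) to the abstract value, presumably so that tracer objects standing in for tuple values also take the tuple path. A rough standalone sketch of that idea, with made-up class names rather than the real jax internals:

class AbstractTuple(object): pass
class AbstractArray(object): pass

class FakeJaxTuple(tuple):
    aval = AbstractTuple()

class FakeTupleTracer(object):  # a tracer wrapping a tuple-valued computation
    aval = AbstractTuple()

def get_aval(x):
    return getattr(x, 'aval', AbstractArray())

for x in (FakeJaxTuple((1, 2)), FakeTupleTracer(), 3.0):
    kind = 'tuple' if type(get_aval(x)) is AbstractTuple else 'leaf'
    print(type(x).__name__, '->', kind)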

View File

@@ -39,6 +39,7 @@ from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
from .api_util import pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
@@ -521,8 +522,8 @@ def broadcasted_eye(dtype, shape, axes):
return EyeConstant(shape, axes, dtype)
def stop_gradient(x):
return stop_gradient_p.bind(x)
def psum(x, axis_name):
return psum_p.bind(x, axis_name=axis_name)
### convenience wrappers around traceables
@@ -2296,7 +2297,6 @@ reduce_sum_p = standard_primitive(_reduce_sum_shape_rule, _input_dtype,
'reduce_sum', _reduce_sum_translation_rule)
ad.deflinear(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
parallel.defreducer(reduce_sum_p, parallel.psum_p)
def _reduce_chooser_shape_rule(operand, axes):
@@ -2902,25 +2902,38 @@ for t in [FilledConstant, IotaConstant, EyeConstant]:
ad_util.jaxval_zeros_likers[t] = zeros_like_array
### stop_gradient
### parallel
def PmapPrimitive(name):
prim = Primitive(name)
prim.def_impl(partial(unbound_name_error, name))
prim.def_abstract_eval(lambda x, *args, **kwargs: x) # default
return prim
def unbound_name_error(primitive_name, *args, **kwargs):
axis_name = kwargs['axis_name']
msg = "axis name '{}' is unbound for primitive {}."
raise NameError(msg.format(axis_name, primitive_name))
def _stop_gradient_jvp_rule(primals, tangents):
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
x, = primals
return stop_gradient(x), ad_util.zero
def psum_transpose_rule(t, axis_name):
return [t]
def _stop_gradient_batch_rule(batched_args, batch_dims):
x, = batched_args
dim, = batch_dims
return stop_gradient(x), dim
def psum_parallel_translation_rule(c, val, device_groups):
if len(device_groups) > 1:
return c.CrossReplicaSum(val, device_groups)
else:
return c.CrossReplicaSum(val)
stop_gradient_p = Primitive('stop_gradient')
stop_gradient_p.def_impl(identity)
stop_gradient_p.def_abstract_eval(identity)
xla.translations[stop_gradient_p] = lambda c, x: x
ad.primitive_jvps[stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[stop_gradient_p] = _stop_gradient_batch_rule
def psum_pmap_rule(val, axis):
return _reduce_sum(val, [axis]), None
psum_p = PmapPrimitive('psum')
parallel.pmap_primitive_rules[psum_p] = psum_pmap_rule
pxla.parallel_translation_rules[psum_p] = psum_parallel_translation_rule
ad.deflinear(psum_p, psum_transpose_rule)
parallel.defreducer(reduce_sum_p, psum_p)
### util

View File

@@ -133,6 +133,15 @@ def tree_structure(tree):
return spec
def prune(treedef, tuple_tree):
if treedef is leaf:
return tuple_tree
elif treedef.children:
return tuple(map(prune, treedef.children, tuple_tree))
else:
return ()
class PyTreeDef(object):
def __init__(self, node_type, node_data, children):
self.node_type = node_type
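A minimal standalone sketch of the prune function added above, with made-up stand-ins for leaf and PyTreeDef rather than the real jax.tree_util objects: spec entries are kept wherever the tree definition keeps a leaf, and childless nodes collapse to an empty tuple.

leaf = object()

class FakeTreeDef(object):
    def __init__(self, children):
        self.children = children  # () means this node kept no children

def prune(treedef, tuple_tree):
    if treedef is leaf:
        return tuple_tree
    elif treedef.children:
        return tuple(map(prune, treedef.children, tuple_tree))
    else:
        return ()

treedef = FakeTreeDef((leaf, FakeTreeDef(())))
print(prune(treedef, (0, 1)))  # (0, ())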

View File

@@ -1,8 +1,8 @@
import numpy as onp
import jax.numpy as np
from jax import grad, pjit, pmap, make_jaxpr
from jax.interpreters.parallel import psum
from jax import jvp, grad, pjit, pmap, make_jaxpr
from jax.lax import psum
# def f(x, y):
@@ -15,11 +15,21 @@ from jax.interpreters.parallel import psum
def f(x):
return x - psum(x, 'i')
return np.cos(x - psum(np.sin(x), 'i'))
x = np.zeros(4)
print grad(lambda x: np.sum(pmap(f, 'i')(x)))(x)
print grad(lambda x: np.sum(x - np.sum(x)))(x)
x = np.ones(4)
print jvp(pmap(f, 'i'), (x,), (x,))
g = pjit(f, axis_name='i')
print grad(lambda x: np.sum(g(x)))(x)
print jvp(g, (x,), (x,))
# def f(x):
# return x - psum(x, 'i')
# x = np.ones(4)
# print grad(lambda x: np.sum(pmap(f, 'i')(x)))(x)
# print grad(lambda x: np.sum(x - np.sum(x)))(x)
# g = pjit(f, axis_name='i')
# print grad(lambda x: np.sum(g(x)))(x)

View File

@@ -25,7 +25,6 @@ from jax import test_util as jtu
from jax import lax
from jax.api import pmap, papply, jit, make_jaxpr, axisvar_split
from jax.linear_util import wrap_init
from jax.interpreters.parallel import psum, scatter_like
from jax.config import config
config.parse_flags_with_absl()
@@ -40,20 +39,20 @@ class PmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceSum(self):
f = lambda x: psum(x, 'i')
f = lambda x: lax.psum(x, 'i')
ans = pmap(f, axis_name='i')(onp.ones(4))
expected = 4 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testLogSoftmax(self):
f = lambda x: x - np.log(psum(np.exp(x), 'i'))
f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
x = onp.log(onp.arange(1., 10., dtype=onp.float32))
ans = pmap(f, axis_name='i')(x)
expected = x - onp.log(onp.sum(onp.exp(x)))
self.assertAllClose(ans, expected, check_dtypes=False)
def testNested(self):
f = lambda x: psum(psum(x, 'i'), 'j')
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
x = onp.ones((2, 2))
ans1 = pmap(pmap(f, 'i'), 'j')(x)
ans2 = pmap(pmap(f, 'j'), 'i')(x)
@@ -80,7 +79,7 @@ class PapplyTest(jtu.JaxTestCase):
pfun, axis_name = papply(np.sum)
jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(lambda x: psum(x, axis_name))(onp.zeros(5))
expected_jaxpr = make_jaxpr(lambda x: lax.psum(x, axis_name))(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)
ans = pmap(pfun, axis_name)(onp.arange(3.))
@@ -95,7 +94,7 @@ class PapplyTest(jtu.JaxTestCase):
pfun, axis_name = papply(fun)
jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(lambda x: x - np.log(psum(np.exp(x), axis_name))
expected_jaxpr = make_jaxpr(lambda x: x - np.log(lax.psum(np.exp(x), axis_name))
)(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)
@@ -127,7 +126,7 @@ class PapplyTest(jtu.JaxTestCase):
class SplitTest(jtu.JaxTestCase):
def testSplitBasic(self):
f = lambda x: psum(np.sin(x), 'i')
f = lambda x: lax.psum(np.sin(x), 'i')
x = onp.ones((2, 2))
fsplit = axisvar_split(f, 'i', ('j', 'k'))
ans = pmap(pmap(fsplit, 'j'), 'k')(x)

View File

@@ -23,7 +23,7 @@ from absl.testing import parameterized
import jax.numpy as np
from jax import test_util as jtu
from jax.api import pjit
from jax.interpreters.parallel import psum
from jax.lax import psum
from jax.config import config
config.parse_flags_with_absl()
@@ -31,32 +31,32 @@ config.parse_flags_with_absl()
class PmapTest(jtu.JaxTestCase):
# @jtu.skip_on_devices("gpu")
# def testBasic(self):
# f = lambda x: x - psum(x, 'i')
# x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
# f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
# ans = f(x)
# expected = x - x.sum(0)
# self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testBasic(self):
f = lambda x: x - psum(x, 'i')
x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
ans = f(x)
expected = x - x.sum(0)
self.assertAllClose(ans, expected, check_dtypes=False)
# @jtu.skip_on_devices("gpu")
# def testTupleOutput(self):
# f = lambda x: (x - psum(x, 'i'),)
# x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
# f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
# ans = f(x)
# expected = (x - x.sum(0),)
# self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testTupleOutput(self):
f = lambda x: (x - psum(x, 'i'),)
x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
ans = f(x)
expected = (x - x.sum(0),)
self.assertAllClose(ans, expected, check_dtypes=False)
# @jtu.skip_on_devices("gpu")
# def testTupleInput(self):
# f = lambda x: x[0] - psum(x[0], 'i')
# x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
# f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
# ans = f((x,))
# expected = x - x.sum(0)
# self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testTupleInput(self):
f = lambda x: x[0] - psum(x[0], 'i')
x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
ans = f((x,))
expected = x - x.sum(0)
self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testNested(self):