Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Commit 08dc6994f5 ("partial progress"), parent 549a1f2e59
@@ -266,11 +266,10 @@ def pjit(fun, axis_name, in_axes=0, out_axes=0, mesh_axis=0):
jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
check_args(jaxtupletree_args)
f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
chunksize = pxla.chunk_size(axis_name, mesh_axis, in_axes, jaxtupletree_args)

# transform spmd primitives in f to first act locally then cross-replica
in_axes_ = pxla.canonicalize_in_axis_spec(in_trees, in_axes)
out_axes_ = OutAxesThunk(out_tree, out_axes) # for pretty-printing
chunksize = pxla.chunk_size(axis_name, mesh_axis, in_axes_, jaxtupletree_args)
f = pxla.chunk_transform(f, chunksize, axis_name, in_axes_, out_axes_)

jaxtupletree_out = pxla.xla_pcall(f, *jaxtupletree_args,
@@ -18,12 +18,13 @@ from __future__ import print_function
from . import partial_eval as pe
from . import xla
from . import pxla
from .. import core as core
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, zero, Zero)
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node
from ..tree_util import process_pytree, build_tree, register_pytree_node, prune
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init

from six.moves import builtins, reduce
@@ -184,6 +185,13 @@ class JVPTrace(Trace):
tangents = [t.tangent for t in tracers]
nonzero_tangents, in_tree_def = tree_to_jaxtuples(tangents)
f, out_tree_def = traceable(jvp_subtrace(f, self.master), in_tree_def)
if call_primitive is pxla.xla_pcall_p:
in_axes, out_axes = params['in_axes'], params['out_axes']
jvp_in_axes = (in_axes, prune(in_tree_def, in_axes))
def jvp_out_axes():
_, tangent_out_tree = out_tree_def().children
return (out_axes(), prune(tangent_out_tree, out_axes()))
params = dict(params, in_axes=jvp_in_axes, out_axes=jvp_out_axes)
result = call_primitive.bind(f, pack(primals), nonzero_tangents, **params)
primal_out, tangent_out = build_tree(out_tree_def(), result)
return JVPTracer(self, primal_out, tangent_out)
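For context, a minimal end-to-end sketch of what this JVPTrace change is meant to support: forward-mode differentiation through a mapped computation that uses a collective. The sketch is written against today's public jax.jvp / jax.pmap / lax.psum spelling rather than this commit's internal xla_pcall machinery, so treat it as an illustration under that assumption, not part of the diff.

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  # per-device computation containing a cross-replica sum over axis 'i'
  return jnp.cos(x - lax.psum(jnp.sin(x), 'i'))

n = jax.local_device_count()
x = jnp.ones(n)
# jvp of the mapped function: primal and tangent outputs are both mapped
primals_out, tangents_out = jax.jvp(jax.pmap(f, axis_name='i'), (x,), (x,))
print(primals_out, tangents_out)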
@@ -386,13 +394,25 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
fun = wrap_init(backward_pass)
fun, out_tree_def = transposed_fun(fun, jaxpr, in_tree_def)
all_args = pack((pack(consts), pack(freevar_vals), ct))
# TODO(dougalm): consider signalling to bind that there are no traces in the closure
ans = primitive.bind(fun, all_args, **params)
if primitive is pxla.xla_pcall_p:
in_axes, out_axes = params['in_axes'], params['out_axes']()
ct_axes = prune(ct_tree, out_axes)
transpose_in_axes = ((None,) * len(consts), (None,) * len(freevar_vals), ct_axes),
def transpose_out_axes():
return prune(out_tree_def(), (in_axes, (None,) * len(freevar_vals)))
new_params = dict(params, in_axes=transpose_in_axes, out_axes=transpose_out_axes)
ans = primitive.bind(fun, all_args, **params)

import ipdb; ipdb.set_trace()
else:
# TODO(dougalm): consider signalling to bind that there are no traces in the closure
ans = primitive.bind(fun, all_args, **params)
return build_tree(out_tree_def(), ans)

primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.compiled_call_p] = partial(call_transpose, pe.compiled_call_p)
primitive_transposes[xla.xla_call_p] = partial(call_transpose, xla.xla_call_p)
primitive_transposes[pxla.xla_pcall_p] = partial(call_transpose, pxla.xla_pcall_p)


tree_to_jaxtuples = partial(process_pytree, pack)
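Similarly, a small reverse-mode sketch of what the xla_pcall transpose rule is aiming at, again written with the public jax.grad / jax.pmap / lax.psum names as an assumption rather than this commit's internals.

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  # psum is its own transpose, so this is differentiable in reverse mode
  return x - lax.psum(x, 'i')

n = jax.local_device_count()
x = jnp.arange(float(n))
g = jax.grad(lambda x: jnp.sum(jax.pmap(f, axis_name='i')(x)))
print(g(x))  # each entry equals 1 - n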
@@ -161,47 +161,7 @@ class PmapTrace(Trace):
return PmapTracer(self, name, vals, axis)


def unbound_name_error(primitive_name, *args, **kwargs):
axis_name = kwargs['axis_name']
msg = "axis name '{}' is unbound for primitive {}."
raise NameError(msg.format(axis_name, primitive_name))

def PmapPrimitive(name):
prim = Primitive(name)
prim.def_impl(partial(unbound_name_error, name))
prim.def_abstract_eval(lambda x, *args, **kwargs: x) # default
return prim


pmap_primitive_rules = {}
parallel_translation_rules = {}


def psum(x, axis_name):
return psum_p.bind(x, axis_name=axis_name)

def psum_pmap_rule(val, axis):
return val.sum(axis), None

def psum_parallel_translation_rule(c, val, device_groups):
if len(device_groups) > 1:
return c.CrossReplicaSum(val, device_groups)
else:
return c.CrossReplicaSum(val)

psum_p = PmapPrimitive('psum')
pmap_primitive_rules[psum_p] = psum_pmap_rule
parallel_translation_rules[psum_p] = psum_parallel_translation_rule


def gather(x, axis_name):
return gather_p.bind(x, axis_name=axis_name)

def gather_pmap_rule(val, axis):
return val, None

gather_p = PmapPrimitive('gather')
pmap_primitive_rules[gather_p] = gather_pmap_rule


### axis variable splitting and computation chunking
@@ -37,7 +37,6 @@ from ..lib import xla_bridge as xb
from .xla import (xla_shape, xla_destructure, translation_rule, abstractify,
xla_shape_to_result_shape, jaxpr_computation)
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
from .parallel import parallel_translation_rules
from .batching import moveaxis
from . import parallel
from . import xla
@@ -75,7 +74,7 @@ def chunk_aval(chunksize, aval, axis):
shape[axis] = chunksize
return ShapedArray(tuple(shape), aval.dtype)

# TODO these next two functions became pretty trivial, maybe prune
# TODO these next two functions became pretty trivial, maybe prune them
def canonicalize_in_axis_spec(in_trees, spec_tree_prefix):
"""Given argument list in_trees, canonicalize and flatten an in_axes spec."""
in_tree = tree_util.PyTreeDef(tree_util.node_types[tuple], None, in_trees)
@@ -108,6 +107,16 @@ def flatten_axis_spec_tree(spec_tree):
spec_flat, _ = tree_util.tree_flatten(spec_tree)
return tuple(None if i is no_mapped_axis else i for i in spec_flat)

def tree_flatten_axes(maybe_tree, axes):
if type(maybe_tree) is core.JaxTuple:
if maybe_tree:
flat_children = map(tree_flatten_axes, maybe_tree, axes)
return it.chain.from_iterable(flat_children)
else:
return []
else:
return [axes]

# We use a special symbol for 'no mapped axis' instead of using None because
# tree_util.py treats None as a tree node.
class NoMappedAxis(object): pass
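As a reading aid, here is a standalone sketch of the axis-spec flattening idea in tree_flatten_axes above, with plain Python tuples standing in for JaxTuple; the function and example values are illustrative assumptions, not part of the diff.

# Standalone illustration only: plain tuples stand in for JaxTuple.
# An axis spec mirrors the tuple structure of the argument; each array leaf
# contributes one entry to the flat list of mapped axes.
from itertools import chain

def tree_flatten_axes(maybe_tree, axes):
  if isinstance(maybe_tree, tuple):
    if maybe_tree:
      # pair each child of the tuple tree with the matching spec entry
      flat_children = map(tree_flatten_axes, maybe_tree, axes)
      return list(chain.from_iterable(flat_children))
    else:
      return []
  else:
    return [axes]

args = (1.0, (2.0, 3.0))       # nested "tuple tree" of leaves
axes = (0, (0, None))          # per-leaf mapped axes, mirroring the structure
print(tree_flatten_axes(args, axes))   # [0, 0, None]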
@@ -127,7 +136,7 @@ def unshard_output(mesh_spec, mesh_axis, out_axis, out_shards):
"""Collect and concatenate sharded device results."""
_, ids = onp.unique(shard_assignments(mesh_spec, mesh_axis), return_index=True)
shards = [out_shards[i] for i in ids]
return onp.concatenate(shards, out_axis) # TODO device persistence
return onp.concatenate(shards, out_axis)

def shard_assignments(mesh_spec, mesh_axis):
"""Given a mesh axis long which to shard data, compute replica assignments."""
@@ -204,6 +213,10 @@ def device_mesh(spec):
yield
_mesh_spec = prev_spec

# axis environments are tiny, so we don't worry about the cost of copying keys
def new_axis_env(d): return d
def extend_axis_env(d1, d2): return dict(d1, **d2)


### xla_pcall

@@ -281,12 +294,14 @@ def xla_pcall_impl(fun, *args, **params):
flat_args, in_trees = unzip2(map(xla.tree_flatten, args))
flat_args = concatenate(flat_args)
fun, out_tree = xla.flatten_fun(fun, in_trees)

import ipdb; ipdb.set_trace() # TODO canonicalize in_axes
in_axes = tuple(concatenate(map(tree_flatten_axes, args, in_axes)))
assert len(flat_args) == len(in_axes)

compiled_fun = xla_parallel_callable(fun, axis_name, in_axes, mesh_axis,
mesh_spec(), *map(abstractify, flat_args))
flat_ans = compiled_fun(out_axes(), *flat_args)
flat_ans = compiled_fun(out_tree(), out_axes(), *flat_args)

if out_tree() is xla.leaf:
return flat_ans
@@ -311,26 +326,18 @@ def xla_parallel_callable(fun, axis_name, in_axes, mesh_axis, mesh_spec,
return partial(execute_replicated, in_axes, mesh_axis, mesh_spec, compiled, pval)

def execute_replicated(in_axes, mesh_axis, mesh_spec, compiled, pval,
out_axes, *args):
out_tree, out_axes, *args):
input_bufs = map(partial(shard_arg, mesh_spec, mesh_axis), in_axes, args)
out_bufs = compiled.ExecutePerReplica(zip(*input_bufs))
out_shards = [merge_pvals(buf.to_py(), pval) for buf in out_bufs] # TODO
if type(out_axes) is int:
input_bufs = zip(*input_bufs) if input_bufs else [[]] * xb.get_replica_count()
out_bufs = compiled.ExecutePerReplica(input_bufs)
out_shards = [merge_pvals(buf.to_py(), pval) for buf in out_bufs]
if out_tree is xla.leaf:
return unshard_output(mesh_spec, mesh_axis, out_axes, out_shards)
else:
raise NotImplementedError
return map(partial(unshard_output, mesh_spec, mesh_axis), out_axes,
zip(*out_shards))

def tree_flatten_axes(maybe_tree, axes):
if type(maybe_tree) is core.JaxTuple:
if maybe_tree:
flat_children = map(tree_flatten_axes, maybe_tree, axes)
return it.chain.from_iterable(flat_children)
else:
return []
else:
return [axes]


xla_pcall_p = core.Primitive('xla_pcall')
xla_pcall = partial(core.call_bind, xla_pcall_p)
@@ -339,6 +346,4 @@ xla_pcall_p.def_impl(xla_pcall_impl)
xla.translations[xla_pcall_p] = xla.xla_call_translation_rule


# axis environments are tiny, so we don't worry about the cost of copying keys
def new_axis_env(d): return d
def extend_axis_env(d1, d2): return dict(d1, **d2)
parallel_translation_rules = {}
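For intuition, a numpy-only sketch of the shard, execute-per-replica, unshard flow implemented by shard_arg / execute_replicated / unshard_output above; per_replica_fn is a hypothetical stand-in for the compiled executable, and no XLA or device logic is involved.

import numpy as np

def shard_arg(num_replicas, in_axis, arg):
  # split the mapped input axis into one shard per replica
  return np.split(arg, num_replicas, axis=in_axis)

def unshard_output(out_axis, out_shards):
  # concatenate per-replica results back along the mapped output axis
  return np.concatenate(out_shards, axis=out_axis)

def execute_replicated(num_replicas, in_axis, out_axis, per_replica_fn, arg):
  shards = shard_arg(num_replicas, in_axis, arg)
  out_shards = [per_replica_fn(s) for s in shards]   # one "replica" per shard
  return unshard_output(out_axis, out_shards)

x = np.arange(8.0).reshape(4, 2)
print(execute_replicated(4, 0, 0, lambda s: 2.0 * s, x))  # equals 2.0 * x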
@@ -376,14 +376,16 @@ def xla_shape(x):
def flatten_fun(in_trees, *flat_args):
jtuple_trees = tuple(map(partial(build_tree, iter(flat_args)), in_trees))
ans = yield jtuple_trees
if type(ans) is JaxTuple:
aval = core.get_aval(ans)
if type(aval) is AbstractTuple:
ans_flat, out_tree = tree_flatten(ans)
yield pack(ans_flat), out_tree
else:
yield ans, leaf

def tree_flatten(maybe_tree):
if type(maybe_tree) is JaxTuple:
aval = core.get_aval(maybe_tree)
if type(aval) is AbstractTuple:
flat_children, child_specs = unzip2(map(tree_flatten, maybe_tree))
return it.chain.from_iterable(flat_children), JTupleTreeDef(child_specs)
elif core.skip_checks or valid_jaxtype(maybe_tree):
jax/lax.py
@@ -39,6 +39,7 @@ from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
from .api_util import pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
@@ -521,8 +522,8 @@ def broadcasted_eye(dtype, shape, axes):
return EyeConstant(shape, axes, dtype)


def stop_gradient(x):
return stop_gradient_p.bind(x)
def psum(x, axis_name):
return psum_p.bind(x, axis_name=axis_name)


### convenience wrappers around traceables
@@ -2296,7 +2297,6 @@ reduce_sum_p = standard_primitive(_reduce_sum_shape_rule, _input_dtype,
'reduce_sum', _reduce_sum_translation_rule)
ad.deflinear(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
parallel.defreducer(reduce_sum_p, parallel.psum_p)


def _reduce_chooser_shape_rule(operand, axes):
@@ -2902,25 +2902,38 @@ for t in [FilledConstant, IotaConstant, EyeConstant]:
ad_util.jaxval_zeros_likers[t] = zeros_like_array


### stop_gradient
### parallel

def PmapPrimitive(name):
prim = Primitive(name)
prim.def_impl(partial(unbound_name_error, name))
prim.def_abstract_eval(lambda x, *args, **kwargs: x) # default
return prim

def unbound_name_error(primitive_name, *args, **kwargs):
axis_name = kwargs['axis_name']
msg = "axis name '{}' is unbound for primitive {}."
raise NameError(msg.format(axis_name, primitive_name))


def _stop_gradient_jvp_rule(primals, tangents):
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
x, = primals
return stop_gradient(x), ad_util.zero
def psum_transpose_rule(t, axis_name):
return [t]

def _stop_gradient_batch_rule(batched_args, batch_dims):
x, = batched_args
dim, = batch_dims
return stop_gradient(x), dim
def psum_parallel_translation_rule(c, val, device_groups):
if len(device_groups) > 1:
return c.CrossReplicaSum(val, device_groups)
else:
return c.CrossReplicaSum(val)

stop_gradient_p = Primitive('stop_gradient')
stop_gradient_p.def_impl(identity)
stop_gradient_p.def_abstract_eval(identity)
xla.translations[stop_gradient_p] = lambda c, x: x
ad.primitive_jvps[stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[stop_gradient_p] = _stop_gradient_batch_rule
def psum_pmap_rule(val, axis):
return _reduce_sum(val, [axis]), None

psum_p = PmapPrimitive('psum')
parallel.pmap_primitive_rules[psum_p] = psum_pmap_rule
pxla.parallel_translation_rules[psum_p] = psum_parallel_translation_rule
ad.deflinear(psum_p, psum_transpose_rule)

parallel.defreducer(reduce_sum_p, psum_p)


### util
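To make the PmapPrimitive pattern above concrete, here is a standalone sketch; the Primitive class below is a hypothetical minimal stand-in for jax.core.Primitive (no tracing), included only to show why the default implementation raises when the axis name is not bound by a surrounding map.

from functools import partial

class Primitive(object):
  """Hypothetical minimal stand-in for jax.core.Primitive (no tracing)."""
  def __init__(self, name):
    self.name = name
  def def_impl(self, impl):
    self.impl = impl
  def def_abstract_eval(self, rule):
    self.abstract_eval = rule
  def bind(self, *args, **kwargs):
    return self.impl(*args, **kwargs)

def unbound_name_error(primitive_name, *args, **kwargs):
  axis_name = kwargs['axis_name']
  msg = "axis name '{}' is unbound for primitive {}."
  raise NameError(msg.format(axis_name, primitive_name))

def PmapPrimitive(name):
  prim = Primitive(name)
  prim.def_impl(partial(unbound_name_error, name))
  prim.def_abstract_eval(lambda x, *args, **kwargs: x)  # shape/dtype unchanged
  return prim

psum_p = PmapPrimitive('psum')
try:
  psum_p.bind(1.0, axis_name='i')   # no surrounding pmap binds 'i'
except NameError as e:
  print(e)                          # axis name 'i' is unbound for primitive psum.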
@@ -133,6 +133,15 @@ def tree_structure(tree):
return spec


def prune(treedef, tuple_tree):
if treedef is leaf:
return tuple_tree
elif treedef.children:
return tuple(map(prune, treedef.children, tuple_tree))
else:
return ()


class PyTreeDef(object):
def __init__(self, node_type, node_data, children):
self.node_type = node_type
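A standalone sketch of the new prune helper, useful for seeing how the autodiff changes above restrict an axis spec to the tree structure that remains (for example, the nonzero tangents); TreeDef and leaf below are minimal stand-ins for the real tree_util objects, and the example values are assumptions.

class TreeDef(object):
  """Hypothetical minimal stand-in for tree_util.PyTreeDef."""
  def __init__(self, children):
    self.children = children      # tuple of child TreeDefs, or () if empty

leaf = object()                   # sentinel marking a leaf position

def prune(treedef, tuple_tree):
  if treedef is leaf:
    return tuple_tree             # keep the spec entry at this leaf
  elif treedef.children:
    return tuple(map(prune, treedef.children, tuple_tree))
  else:
    return ()                     # empty node: nothing to keep here

# A pair whose second element has been emptied out:
treedef = TreeDef((leaf, TreeDef(())))
in_axes = (0, (0, None))
print(prune(treedef, in_axes))    # (0, ())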
@@ -1,8 +1,8 @@
import numpy as onp

import jax.numpy as np
from jax import grad, pjit, pmap, make_jaxpr
from jax.interpreters.parallel import psum
from jax import jvp, grad, pjit, pmap, make_jaxpr
from jax.lax import psum


# def f(x, y):
@@ -15,11 +15,21 @@ from jax.interpreters.parallel import psum


def f(x):
return x - psum(x, 'i')
return np.cos(x - psum(np.sin(x), 'i'))

x = np.zeros(4)
print grad(lambda x: np.sum(pmap(f, 'i')(x)))(x)
print grad(lambda x: np.sum(x - np.sum(x)))(x)
x = np.ones(4)
print jvp(pmap(f, 'i'), (x,), (x,))

g = pjit(f, axis_name='i')
print grad(lambda x: np.sum(g(x)))(x)
print jvp(g, (x,), (x,))


# def f(x):
# return x - psum(x, 'i')

# x = np.ones(4)
# print grad(lambda x: np.sum(pmap(f, 'i')(x)))(x)
# print grad(lambda x: np.sum(x - np.sum(x)))(x)

# g = pjit(f, axis_name='i')
# print grad(lambda x: np.sum(g(x)))(x)
@@ -25,7 +25,6 @@ from jax import test_util as jtu
from jax import lax
from jax.api import pmap, papply, jit, make_jaxpr, axisvar_split
from jax.linear_util import wrap_init
from jax.interpreters.parallel import psum, scatter_like

from jax.config import config
config.parse_flags_with_absl()
@@ -40,20 +39,20 @@ class PmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)

def testReduceSum(self):
f = lambda x: psum(x, 'i')
f = lambda x: lax.psum(x, 'i')
ans = pmap(f, axis_name='i')(onp.ones(4))
expected = 4 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)

def testLogSoftmax(self):
f = lambda x: x - np.log(psum(np.exp(x), 'i'))
f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
x = onp.log(onp.arange(1., 10., dtype=onp.float32))
ans = pmap(f, axis_name='i')(x)
expected = x - onp.log(onp.sum(onp.exp(x)))
self.assertAllClose(ans, expected, check_dtypes=False)

def testNested(self):
f = lambda x: psum(psum(x, 'i'), 'j')
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
x = onp.ones((2, 2))
ans1 = pmap(pmap(f, 'i'), 'j')(x)
ans2 = pmap(pmap(f, 'j'), 'i')(x)
@@ -80,7 +79,7 @@ class PapplyTest(jtu.JaxTestCase):
pfun, axis_name = papply(np.sum)

jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(lambda x: psum(x, axis_name))(onp.zeros(5))
expected_jaxpr = make_jaxpr(lambda x: lax.psum(x, axis_name))(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)

ans = pmap(pfun, axis_name)(onp.arange(3.))
@@ -95,7 +94,7 @@ class PapplyTest(jtu.JaxTestCase):
pfun, axis_name = papply(fun)

jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(lambda x: x - np.log(psum(np.exp(x), axis_name))
expected_jaxpr = make_jaxpr(lambda x: x - np.log(lax.psum(np.exp(x), axis_name))
)(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)

@@ -127,7 +126,7 @@ class PapplyTest(jtu.JaxTestCase):
class SplitTest(jtu.JaxTestCase):

def testSplitBasic(self):
f = lambda x: psum(np.sin(x), 'i')
f = lambda x: lax.psum(np.sin(x), 'i')
x = onp.ones((2, 2))
fsplit = axisvar_split(f, 'i', ('j', 'k'))
ans = pmap(pmap(fsplit, 'j'), 'k')(x)
@@ -23,7 +23,7 @@ from absl.testing import parameterized
import jax.numpy as np
from jax import test_util as jtu
from jax.api import pjit
from jax.interpreters.parallel import psum
from jax.lax import psum

from jax.config import config
config.parse_flags_with_absl()
@@ -31,32 +31,32 @@ config.parse_flags_with_absl()

class PmapTest(jtu.JaxTestCase):

# @jtu.skip_on_devices("gpu")
# def testBasic(self):
# f = lambda x: x - psum(x, 'i')
# x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
# f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
# ans = f(x)
# expected = x - x.sum(0)
# self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testBasic(self):
f = lambda x: x - psum(x, 'i')
x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
ans = f(x)
expected = x - x.sum(0)
self.assertAllClose(ans, expected, check_dtypes=False)

# @jtu.skip_on_devices("gpu")
# def testTupleOutput(self):
# f = lambda x: (x - psum(x, 'i'),)
# x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
# f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
# ans = f(x)
# expected = (x - x.sum(0),)
# self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testTupleOutput(self):
f = lambda x: (x - psum(x, 'i'),)
x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
ans = f(x)
expected = (x - x.sum(0),)
self.assertAllClose(ans, expected, check_dtypes=False)

# @jtu.skip_on_devices("gpu")
# def testTupleInput(self):
# f = lambda x: x[0] - psum(x[0], 'i')
# x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
# f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
# ans = f((x,))
# expected = x - x.sum(0)
# self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("gpu")
def testTupleInput(self):
f = lambda x: x[0] - psum(x[0], 'i')
x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
f = pjit(f, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
ans = f((x,))
expected = x - x.sum(0)
self.assertAllClose(ans, expected, check_dtypes=False)

@jtu.skip_on_devices("gpu")
def testNested(self):