Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
This commit is contained in: commit fbf2a62aa1 (parent a65d3ae0da)
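In a nutshell: code that needs the jaxpr or the function name now goes through specialize() first, while Lowered keeps only the lowering plus argument/output tree info. A minimal usage sketch follows (hedged: specialize, Specialized.jaxpr, Specialized.fun_name, and Specialized.lower() are the internal staging surfaces touched by this diff, not a documented public API; the example function and inputs are made up):

# Hypothetical illustration of the flow after this commit.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

jitted = jax.jit(f)

# The jaxpr and function name now live on the Specialized object...
specialized = jitted.specialize(jnp.ones((8,), jnp.float32))
closed_jaxpr = specialized.jaxpr   # core.ClosedJaxpr, no longer stored on Lowered
fun_name = specialized.fun_name    # the wrapped function's name

# ...while the lowering is derived from the same specialization.
lowered = specialized.lower()      # stages.Lowered, without _jaxpr/_fun_name
print(lowered.as_text())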
@@ -1811,9 +1811,41 @@ def _cpp_pmap(
  pmap_f = wraps(fun)(cpp_mapped_f)

  @api_boundary
  def specialize(*args, **kwargs):
    lowering_parameters = kwargs.pop(
        '_experimental_lowering_parameters', mlir.LoweringParameters())
    p = _prepare_pmap(
        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
        devices, backend, axis_size, args, kwargs)
    abstract_args = list(map(shaped_abstractify, p.flat_args))
    lower_callable = partial(
        pxla.lower_parallel_callable, p.flat_fun, backend, axis_name,
        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
        devices=p.devices,
        name=p.flat_fun.__name__,
        in_axes=p.in_axes_flat,
        out_axes_thunk=p.out_axes_thunk,
        donated_invars=p.donated_invars,
        is_explicit_global_axis_size=p.is_explicit_global_axis_size,
        avals=abstract_args,
        lowering_parameters=lowering_parameters)
    jaxpr, _, _, _, _ = pxla.get_pmap_jaxpr(
        p.flat_fun, backend, axis_name,
        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
        devices=p.devices,
        name=p.flat_fun.__name__,
        in_axes=p.in_axes_flat,
        out_axes_thunk=p.out_axes_thunk,
        avals=abstract_args)
    args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple)
    return stages.Specialized(jaxpr, args_info, p.flat_fun.__name__,
                              p.out_tree(), lower_callable)

  pmap_f.lower = _pmap_lower(
      fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
      backend, axis_size, donate_tuple)
  pmap_f.specialize = specialize

  return pmap_f

@@ -1845,7 +1877,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
        devices, backend, axis_size, args, kwargs)
    abstract_args = list(map(shaped_abstractify, p.flat_args))
    computation, closed_jaxpr = pxla.lower_parallel_callable(
    computation = pxla.lower_parallel_callable(
        p.flat_fun, backend, axis_name,
        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
        devices=p.devices,
@@ -1857,8 +1889,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
        avals=abstract_args,
        lowering_parameters=lowering_parameters)
    return stages.Lowered.from_flat_info(
        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(),
        fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr)
        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

  return lower
@@ -557,7 +557,7 @@ def parallel_callable(fun: lu.WrappedFun,
                      donated_invars: Sequence[bool],
                      is_explicit_global_axis_size: bool,
                      *avals):
  pmap_computation, _ = lower_parallel_callable(
  pmap_computation = lower_parallel_callable(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, avals,
@@ -665,6 +665,31 @@ def stage_parallel_callable(
  return jaxpr, consts, replicas, shards


def get_pmap_jaxpr(
    fun: lu.WrappedFun,
    backend_name: str | None,
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Sequence[xc.Device] | None,
    name: str,
    in_axes: Iterable[int | None],
    out_axes_thunk: Callable[[], Sequence[int | None]],
    avals: Sequence[core.AbstractValue]):
  if devices is not None and backend_name is None:
    backend = xb.get_device_backend(devices[0])
  else:
    backend = xb.get_backend(backend_name)

  pci = ParallelCallableInfo(
      name, backend, axis_name, axis_size, global_axis_size, devices,
      in_axes, out_axes_thunk, avals)
  jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  return closed_jaxpr, backend, replicas, shards, pci


@profiler.annotate_function
def lower_parallel_callable(
    fun: lu.WrappedFun,
@@ -680,7 +705,7 @@ def lower_parallel_callable(
    is_explicit_global_axis_size: bool,
    avals: Sequence[core.AbstractValue],
    *,
    lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
  # Determine global_axis_size for use in AxisEnv.
  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -691,10 +716,10 @@ def lower_parallel_callable(
        f"Specified axis_size {global_axis_size} doesn't match received "
        f"axis_size {axis_size}.")

  if devices is not None and backend_name is None:
    backend = xb.get_device_backend(devices[0])
  else:
    backend = xb.get_backend(backend_name)
  closed_jaxpr, backend, replicas, shards, pci = get_pmap_jaxpr(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, avals)
  jaxpr = closed_jaxpr.jaxpr

  no_nested_sharding = False
  must_run_on_all_devices = False
@@ -711,10 +736,6 @@ def lower_parallel_callable(
    # devices). Nested sharding is ok in this case.
    must_run_on_all_devices = True

  pci = ParallelCallableInfo(
      name, backend, axis_name, axis_size, global_axis_size, devices,
      in_axes, out_axes_thunk, avals)
  jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  if logger.isEnabledFor(logging.DEBUG):
    logger.debug("sharded_avals: %s", shards.sharded_avals)
    logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
@@ -756,8 +777,6 @@ def lower_parallel_callable(
  axis_env = sharding_impls.AxisEnv(
      replicas.num_global_replicas, (axis_name,), (global_axis_size,))
  name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
  jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  replicated_args = [axis is None for axis in in_axes]
  tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                          backend.platform)
@@ -798,7 +817,7 @@ def lower_parallel_callable(
      keepalive=lowering_result.keepalive,
      host_callbacks=lowering_result.host_callbacks,
      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
      shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr
      shape_poly_state=lowering_result.shape_poly_state)


def _pmap_unmap_shaped_array(
@@ -617,7 +617,7 @@ def xmap(fun: Callable,
        '_experimental_lowering_platform', mlir.LoweringParameters())
    fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
    avals_flat = [shaped_abstractify(arg) for arg in args_flat]
    computation, jaxpr = make_xmap_callable(
    computation = make_xmap_callable(
        fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
        params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
        params['resource_env'], params['backend'], params['spmd_in_axes'],
@@ -627,8 +627,7 @@ def xmap(fun: Callable,
    in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
    in_avals = in_tree.unflatten(avals_flat)
    return stages.Lowered.from_flat_info(
        computation, in_tree, in_avals, donate_argnums, out_tree(),
        no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr)
        computation, in_tree, in_avals, donate_argnums, out_tree())

  fun_mapped.lower = lower
  return type_cast(stages.Wrapped, fun_mapped)
@@ -637,7 +636,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
              global_axis_sizes, axis_resources, resource_env, backend,
              spmd_in_axes, spmd_out_axes_thunk):
  in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
  computation, _ = make_xmap_callable(
  computation = make_xmap_callable(
      fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
      axis_resources, resource_env, backend,
      spmd_in_axes, spmd_out_axes_thunk,
@@ -709,7 +708,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
        in_shardings, out_shardings, donated_invars,
        use_spmd_lowering, in_avals,
        tiling_method=tiling_method,
        lowering_parameters=lowering_parameters), jaxpr
        lowering_parameters=lowering_parameters)
  else:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
    return pxla.lower_sharding_computation(
@@ -717,7 +716,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
        (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
        (None,) * len(in_avals), (None,) * len(out_avals),
        donated_invars, keep_unused=True, inline=False,
        devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr
        devices_from_context=None, lowering_parameters=lowering_parameters)


class EvaluationPlan(NamedTuple):
@@ -498,8 +498,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
    donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
    jaxpr = params["jaxpr"]
    return stages.Lowered.from_flat_info(
        lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree,
        fun_name=params["name"], jaxpr=jaxpr)
        lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree)

  @api_boundary
  def eval_shape(*args, **kwargs):
@@ -643,30 +643,23 @@ class Lowered(Stage):
  querying properties of lowered computations across JAX's various
  lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
  """
  __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"]
  __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"]
  _lowering: XlaLowering
  args_info: Any  # PyTree of ArgInfo
  out_tree: tree_util.PyTreeDef
  _no_kwargs: bool
  _fun_name: str
  _jaxpr: core.ClosedJaxpr | None  # Can be None when this class is constructed
                                   # outside of JAX core.

  def __init__(
      self,
      lowering: XlaLowering,
      args_info,  # PyTree of ArgInfo
      out_tree: tree_util.PyTreeDef,
      no_kwargs: bool = False,
      fun_name: str = "<unnamed function>",
      jaxpr: core.ClosedJaxpr | None = None):
      no_kwargs: bool = False):

    self._lowering = lowering
    self.args_info = args_info
    self.out_tree = out_tree
    self._no_kwargs = no_kwargs
    self._fun_name = fun_name
    self._jaxpr = jaxpr

  @classmethod
  def from_flat_info(cls,
@@ -675,9 +668,7 @@ class Lowered(Stage):
                     in_avals,
                     donate_argnums: tuple[int, ...],
                     out_tree: tree_util.PyTreeDef,
                     no_kwargs: bool = False,
                     fun_name: str = "<unnamed function>",
                     jaxpr: core.ClosedJaxpr | None = None):
                     no_kwargs: bool = False):
    """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.

    Args:
@@ -686,15 +677,12 @@ class Lowered(Stage):
      no_kwargs: If ``True`` the transformation, and the
        ``Compiled`` returned from this object will not support keyword
        arguments (an error will be raised if some are provided).
      fun_name: the name of the lowered function.
      jaxpr: the Jaxpr of the lowered function. The value `None` is for
        backwards compatibility, and is used only outside JAX core.
    """
    return cls(
        lowering,
        make_args_info(in_tree, in_avals, donate_argnums),
        out_tree,
        no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr)
        no_kwargs=no_kwargs)

  @property
  def out_info(self):  # PyTree of OutInfo
@@ -433,15 +433,17 @@ def export(fun_jax: Callable,
  """

  def do_export(*args_specs, **kwargs_specs) -> Exported:
    if not hasattr(fun_jax, "lower"):
    if hasattr(fun_jax, "lower"):
      # If we have a pjit or pmap already we do not wrap with another, and we
      # allow shardings.
      wrapped_fun_jax = fun_jax
    else:
      # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
      # convert(f_jax), in which case a "jit" is implied. In that case we raise
      # an error if the lowered function contains non-replicated sharding annotations.
      wrapped_fun_jax = jax.jit(fun_jax)
    else:
      # If we have a pjit or pmap already we do not wrap with another, and we
      # allow shardings.
      wrapped_fun_jax = fun_jax  # type: ignore

    has_specialize = hasattr(wrapped_fun_jax, "specialize")

    if lowering_platforms is not None:
      actual_lowering_platforms = tuple(lowering_platforms)
@@ -464,19 +466,32 @@ def export(fun_jax: Callable,
          self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
          other_descr=shape_poly.args_kwargs_path_to_str(k_path))

    lowered = wrapped_fun_jax.lower(
        *args_specs, **kwargs_specs,
        _experimental_lowering_parameters=mlir.LoweringParameters(
    if has_specialize:
      specialized = wrapped_fun_jax.specialize(
          *args_specs, **kwargs_specs,
          _experimental_lowering_parameters=mlir.LoweringParameters(
          platforms=actual_lowering_platforms,
          for_export=True,
        ))
          ))
      jaxpr, fun_name = specialized.jaxpr, specialized.fun_name
      lowered = specialized.lower()
    else:
      lowered = wrapped_fun_jax.lower(
          *args_specs, **kwargs_specs,
          _experimental_lowering_parameters=mlir.LoweringParameters(
            platforms=actual_lowering_platforms,
            for_export=True,
          ))
      jaxpr, fun_name = None, util.fun_name(wrapped_fun_jax)
    return _export_lowered(
        lowered, disabled_checks=disabled_checks,
        lowered, jaxpr, fun_name,
        disabled_checks=disabled_checks,
        _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
  return do_export


def _export_lowered(
    lowered: stages.Lowered,
    jaxpr: core.ClosedJaxpr, fun_name: str,
    disabled_checks: Sequence[DisabledSafetyCheck] = (),
    _device_assignment_for_internal_jax2tf_use_only = None,
) -> Exported:
@@ -563,8 +578,8 @@ def _export_lowered(
  def _get_exported_vjp(exp_primal: Exported) -> Exported:
    # Turn the primal jaxpr into a function, in preparation for exporting
    # the VJP. Note that jaxpr_as_fun produces a function with flat arguments
    assert(lowered._jaxpr is not None)  # None only when the lowered was created outside JAX
    fun_jax = core.jaxpr_as_fun(lowered._jaxpr)
    assert(jaxpr is not None)  # None only when the lowered was created outside JAX
    fun_jax = core.jaxpr_as_fun(jaxpr)

    fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax,
                                             in_tree=exp_primal.in_tree,
@@ -580,7 +595,7 @@ def _export_lowered(
        disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)

    return Exported(
        fun_name=lowered._fun_name,
        fun_name=fun_name,
        in_tree=lowered.in_tree,
        out_tree=lowered.out_tree,
        in_avals=tuple(args_avals_flat),
@@ -34,7 +34,6 @@ import jax
from jax._src import compiler
from jax._src import config
from jax._src import maps
from jax._src.maps import xmap
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import lax
@@ -56,7 +55,6 @@ config.parse_flags_with_absl()
from jax.experimental.jax2tf.tests import tf_test_util

prev_xla_flags = None
prev_spmd_lowering_flag = None

topology = None

@@ -78,9 +76,6 @@ def setUpModule():
                          " --xla_force_host_platform_device_count=8")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
  global prev_spmd_lowering_flag
  prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
  config.update('experimental_xmap_spmd_lowering', True)


def tearDownModule():
@@ -89,7 +84,6 @@ def tearDownModule():
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  xla_bridge.get_backend.cache_clear()
  config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag)


class ShardingTest(tf_test_util.JaxToTfTestCase):
@@ -536,115 +530,6 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
        "function with sharded arguments or results must be used under a `tf.function` context"):
      jax2tf.convert(f_jax)(a)

  def test_xmap_basic(self):
    devices = np.reshape(self.devices, (1, 2))
    ashape = (16, 8, 5)
    a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
    bshape = (2, 7)
    b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)

    # f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
    # lambda ...: f32[5], f32[7] -> f32[10], f32[28]
    f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2.,
                               jnp.concatenate([b, b, b, b], axis=0) * 4.),
                 in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
                 out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
                 axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})

    @tf.function(autograph=False, jit_compile=True)
    def f_tf(a, b):
      # xmap works only with native serialization
      f_converted = jax2tf.convert(f_jax, native_serialization=True)
      if jtu.test_device_matches(["tpu"]):
        res = tf.compat.v1.tpu.rewrite(
            f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)],
            device_assignment=self.device_assignment(
                computation_shape=[1, 1, 1, 2])
        )
        return (res[0], res[1])
      else:
        return f_converted(a, b)

    with Mesh(devices, ('x', 'y')):
      res_jax = f_jax(a, b)
      self.assertAllClose(res_jax, (jnp.concatenate([a, a], axis=2) * 2.,
                                    jnp.concatenate([b, b, b, b], axis=1) * 4.))
      res_tf = f_tf(a, b)
      self.assertAllClose(res_tf, res_jax)

      self.check_sharding(
          jax2tf.convert(f_jax, native_serialization=True), [a, b],
          checks=[
              (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1),
              # The output sharding
              (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
              (r"f32\[2,28\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
          ])

  def test_xmap_collective_reduce(self):
    devices = np.reshape(self.devices, (1, 2))
    ashape = (16, 8, 5)
    a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
    bshape = (2, 7)
    b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)
    f_jax = xmap(lambda a, b: (lax.psum(a * 2., 'a'), b * 4.),
                 in_axes=(['a', 'b', ...], {0: 'c'}),
                 out_axes=(['b', ...], {0: 'c'}),
                 axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})

    @tf.function(autograph=False, jit_compile=True)
    def f_tf(a, b):
      f_converted = jax2tf.convert(f_jax, native_serialization=True)
      if jtu.test_device_matches(["tpu"]):
        res = tf.compat.v1.tpu.rewrite(
            f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)],
            device_assignment=self.device_assignment(
                computation_shape=[1, 1, 1, 2])
        )
        return (res[0], res[1])
      else:
        return f_converted(a, b)

    with Mesh(devices, ('x', 'y')):
      res_jax = f_jax(a, b)
      self.assertAllClose(res_jax, ((a * 2.).sum(0), b * 4.))
      res_tf = f_tf(a, b)
      self.assertAllClose(res_tf, res_jax)
      self.check_sharding(
          jax2tf.convert(f_jax, native_serialization=True), [a, b],
          checks=[
              (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1),
              (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 2),
              (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1),
          ])

  def test_grad_xmap(self):
    devices = np.reshape(self.devices, (1, 2))
    ashape = (16, 8, 5)
    a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)

    # f_jax: f32[16,8,5]-> f32[16,8,10]
    # lambda ...: f32[5]-> f32[10]
    f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2.,
                 in_axes=({0: 'a', 1: 'b'}),
                 out_axes={0: 'a', 1: 'b'},
                 axis_resources={'a': 'x', 'b': 'y'})

    def f_grad_tf(a, res_ct):
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(a)
        res_tf = jax2tf.convert(f_jax, native_serialization=True)(a)
        return tape.gradient(res_tf, a, output_gradients=res_ct)

    with Mesh(devices, ('x', 'y')):
      self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)],
          checks=[
              # Primal input and grad output
              (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)),
              # Input cotangent
              (r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)),
          ])

  @jtu.ignore_warning(category=UserWarning,
                      message="all_to_all .* are only implemented properly for TPUs and GPUs .*")
  def test_shmap_all_to_all(self):
@@ -556,12 +556,16 @@ class JaxExportTest(jtu.JaxTestCase):
        return jnp.sin(x)

      # This makes it look like a jitted-function
      def lower(self, x,
                _experimental_lowering_parameters=None):
      def lower(self, x, _experimental_lowering_parameters=None):
        return jax.jit(self.__call__).lower(
            x,
            _experimental_lowering_parameters=_experimental_lowering_parameters)

      def specialize(self, x, _experimental_lowering_parameters=None):
        return jax.jit(self.__call__).specialize(
            x,
            _experimental_lowering_parameters=_experimental_lowering_parameters)

    a, = export.symbolic_shape("a,")
    # No error
    _ = get_exported(MyCallable())(