Remove jaxpr and fun_name from Lowered, since specialize already carries them. This keeps the abstraction boundary clear. Adapt export to use specialize.

PiperOrigin-RevId: 640968129
Yash Katariya 2024-06-06 11:36:59 -07:00 committed by jax authors
parent a65d3ae0da
commit fbf2a62aa1
8 changed files with 110 additions and 170 deletions
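A minimal sketch of the intended flow after this change, assuming the experimental specialize API exercised in the diffs below (Specialized exposes jaxpr, fun_name, and lower()); names are illustrative rather than a stable public API:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x)

# Previously Lowered carried fun_name and a ClosedJaxpr so that export could
# read them back; now they live only on the Specialized object.
specialized = jax.jit(f).specialize(jnp.arange(4.0))
closed_jaxpr = specialized.jaxpr    # core.ClosedJaxpr
fun_name = specialized.fun_name     # "f"
lowered = specialized.lower()       # stages.Lowered, without _jaxpr/_fun_name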

View File

@ -1811,9 +1811,41 @@ def _cpp_pmap(
pmap_f = wraps(fun)(cpp_mapped_f)
@api_boundary
def specialize(*args, **kwargs):
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
lower_callable = partial(
pxla.lower_parallel_callable, p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
name=p.flat_fun.__name__,
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_parameters=lowering_parameters)
jaxpr, _, _, _, _ = pxla.get_pmap_jaxpr(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
name=p.flat_fun.__name__,
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
avals=abstract_args)
args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple)
return stages.Specialized(jaxpr, args_info, p.flat_fun.__name__,
p.out_tree(), lower_callable)
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, donate_tuple)
pmap_f.specialize = specialize
return pmap_f
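For reference, a hypothetical use of the pmap specialize path added above; the attribute names mirror how the export change further down consumes them (assumes at least one local device):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
pmapped = jax.pmap(lambda x: x * 2)
specialized = pmapped.specialize(jnp.ones((n, 3), jnp.float32))
print(specialized.fun_name)    # name of the flattened function
print(specialized.jaxpr)       # ClosedJaxpr built via pxla.get_pmap_jaxpr
lowered = specialized.lower()  # invokes the stored lower_parallel_callable partial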
@ -1845,7 +1877,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation, closed_jaxpr = pxla.lower_parallel_callable(
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
@ -1857,8 +1889,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
avals=abstract_args,
lowering_parameters=lowering_parameters)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(),
fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr)
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
return lower

View File

@ -557,7 +557,7 @@ def parallel_callable(fun: lu.WrappedFun,
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
*avals):
pmap_computation, _ = lower_parallel_callable(
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, avals,
@ -665,6 +665,31 @@ def stage_parallel_callable(
return jaxpr, consts, replicas, shards
def get_pmap_jaxpr(
fun: lu.WrappedFun,
backend_name: str | None,
axis_name: core.AxisName,
axis_size: int,
global_axis_size: int,
devices: Sequence[xc.Device] | None,
name: str,
in_axes: Iterable[int | None],
out_axes_thunk: Callable[[], Sequence[int | None]],
avals: Sequence[core.AbstractValue]):
if devices is not None and backend_name is None:
backend = xb.get_device_backend(devices[0])
else:
backend = xb.get_backend(backend_name)
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
return closed_jaxpr, backend, replicas, shards, pci
@profiler.annotate_function
def lower_parallel_callable(
fun: lu.WrappedFun,
@ -680,7 +705,7 @@ def lower_parallel_callable(
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@ -691,10 +716,10 @@ def lower_parallel_callable(
f"Specified axis_size {global_axis_size} doesn't match received "
f"axis_size {axis_size}.")
if devices is not None and backend_name is None:
backend = xb.get_device_backend(devices[0])
else:
backend = xb.get_backend(backend_name)
closed_jaxpr, backend, replicas, shards, pci = get_pmap_jaxpr(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, avals)
jaxpr = closed_jaxpr.jaxpr
no_nested_sharding = False
must_run_on_all_devices = False
@ -711,10 +736,6 @@ def lower_parallel_callable(
# devices). Nested sharding is ok in this case.
must_run_on_all_devices = True
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
@ -756,8 +777,6 @@ def lower_parallel_callable(
axis_env = sharding_impls.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
replicated_args = [axis is None for axis in in_axes]
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
backend.platform)
@ -798,7 +817,7 @@ def lower_parallel_callable(
keepalive=lowering_result.keepalive,
host_callbacks=lowering_result.host_callbacks,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr
shape_poly_state=lowering_result.shape_poly_state)
def _pmap_unmap_shaped_array(

View File

@ -617,7 +617,7 @@ def xmap(fun: Callable,
'_experimental_lowering_platform', mlir.LoweringParameters())
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
computation, jaxpr = make_xmap_callable(
computation = make_xmap_callable(
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
@ -627,8 +627,7 @@ def xmap(fun: Callable,
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree(),
no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr)
computation, in_tree, in_avals, donate_argnums, out_tree())
fun_mapped.lower = lower
return type_cast(stages.Wrapped, fun_mapped)
@ -637,7 +636,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk):
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
computation, _ = make_xmap_callable(
computation = make_xmap_callable(
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk,
@ -709,7 +708,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
in_shardings, out_shardings, donated_invars,
use_spmd_lowering, in_avals,
tiling_method=tiling_method,
lowering_parameters=lowering_parameters), jaxpr
lowering_parameters=lowering_parameters)
else:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
return pxla.lower_sharding_computation(
@ -717,7 +716,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
(None,) * len(in_avals), (None,) * len(out_avals),
donated_invars, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr
devices_from_context=None, lowering_parameters=lowering_parameters)
class EvaluationPlan(NamedTuple):

View File

@ -498,8 +498,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
jaxpr = params["jaxpr"]
return stages.Lowered.from_flat_info(
lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree,
fun_name=params["name"], jaxpr=jaxpr)
lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree)
@api_boundary
def eval_shape(*args, **kwargs):

View File

@ -643,30 +643,23 @@ class Lowered(Stage):
querying properties of lowered computations across JAX's various
lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
"""
__slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"]
__slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"]
_lowering: XlaLowering
args_info: Any # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef
_no_kwargs: bool
_fun_name: str
_jaxpr: core.ClosedJaxpr | None # Can be None when this class is constructed
# outside of JAX core.
def __init__(
self,
lowering: XlaLowering,
args_info, # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False,
fun_name: str = "<unnamed function>",
jaxpr: core.ClosedJaxpr | None = None):
no_kwargs: bool = False):
self._lowering = lowering
self.args_info = args_info
self.out_tree = out_tree
self._no_kwargs = no_kwargs
self._fun_name = fun_name
self._jaxpr = jaxpr
@classmethod
def from_flat_info(cls,
@ -675,9 +668,7 @@ class Lowered(Stage):
in_avals,
donate_argnums: tuple[int, ...],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False,
fun_name: str = "<unnamed function>",
jaxpr: core.ClosedJaxpr | None = None):
no_kwargs: bool = False):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
Args:
@ -686,15 +677,12 @@ class Lowered(Stage):
no_kwargs: If ``True`` the transformation, and the
``Compiled`` returned from this object will not support keyword
arguments (an error will be raised if some are provided).
fun_name: the name of the lowered function.
jaxpr: the Jaxpr of the lowered function. The value `None` is for
backwards compatibility, and is used only outside JAX core.
"""
return cls(
lowering,
make_args_info(in_tree, in_avals, donate_argnums),
out_tree,
no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr)
no_kwargs=no_kwargs)
@property
def out_info(self): # PyTree of OutInfo
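With the slots trimmed above, there is nothing left on Lowered for callers to reach into; a small sketch of the new contract (attribute and method names as in this diff, not a public API):

import jax
import jax.numpy as jnp

lowered = jax.jit(jnp.sin).lower(jnp.float32(1.0))
# Lowered no longer stores the jaxpr or the function name.
assert not hasattr(lowered, "_jaxpr")
assert not hasattr(lowered, "_fun_name")
# Code that still needs them is expected to go through specialize() instead.
specialized = jax.jit(jnp.sin).specialize(jnp.float32(1.0))
print(specialized.fun_name, specialized.jaxpr)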

View File

@ -433,15 +433,17 @@ def export(fun_jax: Callable,
"""
def do_export(*args_specs, **kwargs_specs) -> Exported:
if not hasattr(fun_jax, "lower"):
if hasattr(fun_jax, "lower"):
# If we have a pjit or pmap already we do not wrap with another, and we
# allow shardings.
wrapped_fun_jax = fun_jax
else:
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. In that case we raise
# an error if the lowered function contains non-replicated sharding annotations.
wrapped_fun_jax = jax.jit(fun_jax)
else:
# If we have a pjit or pmap already we do not wrap with another, and we
# allow shardings.
wrapped_fun_jax = fun_jax # type: ignore
has_specialize = hasattr(wrapped_fun_jax, "specialize")
if lowering_platforms is not None:
actual_lowering_platforms = tuple(lowering_platforms)
@ -464,19 +466,32 @@ def export(fun_jax: Callable,
self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=shape_poly.args_kwargs_path_to_str(k_path))
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
if has_specialize:
specialized = wrapped_fun_jax.specialize(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
for_export=True,
))
))
jaxpr, fun_name = specialized.jaxpr, specialized.fun_name
lowered = specialized.lower()
else:
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
for_export=True,
))
jaxpr, fun_name = None, util.fun_name(wrapped_fun_jax)
return _export_lowered(
lowered, disabled_checks=disabled_checks,
lowered, jaxpr, fun_name,
disabled_checks=disabled_checks,
_device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
return do_export
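A hedged end-to-end illustration of the branch above: when the callable exposes specialize (as jax.jit and jax.pmap now do), the jaxpr and fun_name are taken from the Specialized object before lowering; the module path is assumed to be jax.experimental.export at this revision:

import jax
import jax.numpy as jnp
from jax.experimental import export

exp = export.export(jax.jit(jnp.cos))(jnp.ones((4,), jnp.float32))
# has_specialize was True, so jaxpr/fun_name came from Specialized rather
# than from Lowered internals.
print(exp.fun_name)   # expected to be "cos"
print(exp.in_avals)   # a tuple with one ShapedArray(float32[4])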
def _export_lowered(
lowered: stages.Lowered,
jaxpr: core.ClosedJaxpr, fun_name: str,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
_device_assignment_for_internal_jax2tf_use_only = None,
) -> Exported:
@ -563,8 +578,8 @@ def _export_lowered(
def _get_exported_vjp(exp_primal: Exported) -> Exported:
# Turn the primal jaxpr into a function, in preparation for exporting
# the VJP. Note that jaxpr_as_fun produces a function with flat arguments
assert(lowered._jaxpr is not None) # None only when the lowered was created outside JAX
fun_jax = core.jaxpr_as_fun(lowered._jaxpr)
assert(jaxpr is not None) # None only when the lowered was created outside JAX
fun_jax = core.jaxpr_as_fun(jaxpr)
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax,
in_tree=exp_primal.in_tree,
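For context on the assertion change above, core.jaxpr_as_fun turns a ClosedJaxpr back into a flat-argument callable, which is what lets the export path rebuild the primal function for VJP export. A minimal standalone illustration (jax.make_jaxpr stands in for the jaxpr captured by specialize):

import jax
import jax.numpy as jnp
from jax._src import core  # internal module; referred to as `core` above

closed = jax.make_jaxpr(jnp.sin)(1.0)   # a core.ClosedJaxpr
sin_flat = core.jaxpr_as_fun(closed)    # callable taking flat arguments
print(sin_flat(1.0))                    # a list containing one Array (~0.8415)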
@ -580,7 +595,7 @@ def _export_lowered(
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
return Exported(
fun_name=lowered._fun_name,
fun_name=fun_name,
in_tree=lowered.in_tree,
out_tree=lowered.out_tree,
in_avals=tuple(args_avals_flat),

View File

@ -34,7 +34,6 @@ import jax
from jax._src import compiler
from jax._src import config
from jax._src import maps
from jax._src.maps import xmap
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import lax
@ -56,7 +55,6 @@ config.parse_flags_with_absl()
from jax.experimental.jax2tf.tests import tf_test_util
prev_xla_flags = None
prev_spmd_lowering_flag = None
topology = None
@ -78,9 +76,6 @@ def setUpModule():
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
global prev_spmd_lowering_flag
prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
def tearDownModule():
@ -89,7 +84,6 @@ def tearDownModule():
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag)
class ShardingTest(tf_test_util.JaxToTfTestCase):
@ -536,115 +530,6 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
"function with sharded arguments or results must be used under a `tf.function` context"):
jax2tf.convert(f_jax)(a)
def test_xmap_basic(self):
devices = np.reshape(self.devices, (1, 2))
ashape = (16, 8, 5)
a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
bshape = (2, 7)
b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)
# f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
# lambda ...: f32[5], f32[7] -> f32[10], f32[28]
f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2.,
jnp.concatenate([b, b, b, b], axis=0) * 4.),
in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
@tf.function(autograph=False, jit_compile=True)
def f_tf(a, b):
# xmap works only with native serialization
f_converted = jax2tf.convert(f_jax, native_serialization=True)
if jtu.test_device_matches(["tpu"]):
res = tf.compat.v1.tpu.rewrite(
f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)],
device_assignment=self.device_assignment(
computation_shape=[1, 1, 1, 2])
)
return (res[0], res[1])
else:
return f_converted(a, b)
with Mesh(devices, ('x', 'y')):
res_jax = f_jax(a, b)
self.assertAllClose(res_jax, (jnp.concatenate([a, a], axis=2) * 2.,
jnp.concatenate([b, b, b, b], axis=1) * 4.))
res_tf = f_tf(a, b)
self.assertAllClose(res_tf, res_jax)
self.check_sharding(
jax2tf.convert(f_jax, native_serialization=True), [a, b],
checks=[
(r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1),
# The output sharding
(r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
(r"f32\[2,28\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
])
def test_xmap_collective_reduce(self):
devices = np.reshape(self.devices, (1, 2))
ashape = (16, 8, 5)
a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
bshape = (2, 7)
b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)
f_jax = xmap(lambda a, b: (lax.psum(a * 2., 'a'), b * 4.),
in_axes=(['a', 'b', ...], {0: 'c'}),
out_axes=(['b', ...], {0: 'c'}),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
@tf.function(autograph=False, jit_compile=True)
def f_tf(a, b):
f_converted = jax2tf.convert(f_jax, native_serialization=True)
if jtu.test_device_matches(["tpu"]):
res = tf.compat.v1.tpu.rewrite(
f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)],
device_assignment=self.device_assignment(
computation_shape=[1, 1, 1, 2])
)
return (res[0], res[1])
else:
return f_converted(a, b)
with Mesh(devices, ('x', 'y')):
res_jax = f_jax(a, b)
self.assertAllClose(res_jax, ((a * 2.).sum(0), b * 4.))
res_tf = f_tf(a, b)
self.assertAllClose(res_tf, res_jax)
self.check_sharding(
jax2tf.convert(f_jax, native_serialization=True), [a, b],
checks=[
(r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1),
(r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 2),
(r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1),
])
def test_grad_xmap(self):
devices = np.reshape(self.devices, (1, 2))
ashape = (16, 8, 5)
a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
# f_jax: f32[16,8,5]-> f32[16,8,10]
# lambda ...: f32[5]-> f32[10]
f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2.,
in_axes=({0: 'a', 1: 'b'}),
out_axes={0: 'a', 1: 'b'},
axis_resources={'a': 'x', 'b': 'y'})
def f_grad_tf(a, res_ct):
with tf.GradientTape(persistent=True) as tape:
tape.watch(a)
res_tf = jax2tf.convert(f_jax, native_serialization=True)(a)
return tape.gradient(res_tf, a, output_gradients=res_ct)
with Mesh(devices, ('x', 'y')):
self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)],
checks=[
# Primal input and grad output
(r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)),
# Input cotangent
(r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)),
])
@jtu.ignore_warning(category=UserWarning,
message="all_to_all .* are only implemented properly for TPUs and GPUs .*")
def test_shmap_all_to_all(self):

View File

@ -556,12 +556,16 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.sin(x)
# This makes it look like a jitted-function
def lower(self, x,
_experimental_lowering_parameters=None):
def lower(self, x, _experimental_lowering_parameters=None):
return jax.jit(self.__call__).lower(
x,
_experimental_lowering_parameters=_experimental_lowering_parameters)
def specialize(self, x, _experimental_lowering_parameters=None):
return jax.jit(self.__call__).specialize(
x,
_experimental_lowering_parameters=_experimental_lowering_parameters)
a, = export.symbolic_shape("a,")
# No error
_ = get_exported(MyCallable())(