Clean up internal representation of XLA translation rules.

Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a single new `xla.TranslationRule` signature:

    rule(ctx, avals_in, avals_out, *args, **params)

where `ctx` holds the parts of the other signatures that were typically not specific to a particular equation (the builder, platform, axis environment, and name stack).
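
For orientation, here is a minimal sketch of a rule written against the new signature. The `double_p` primitive and its rule are hypothetical, introduced only to illustrate the shape of the API; they are not part of this change:

    from functools import partial

    from jax import core
    from jax.interpreters import xla

    # Hypothetical primitive used only for illustration.
    double_p = core.Primitive("double")
    double_p.def_impl(partial(xla.apply_primitive, double_p))
    double_p.def_abstract_eval(lambda x: x)

    def _double_translation_rule(ctx, avals_in, avals_out, x):
      # ctx carries what the old signatures passed positionally: ctx.builder
      # (the XlaBuilder), ctx.platform, ctx.axis_env, and ctx.name_stack.
      del avals_in, avals_out  # Unused here.
      # New-style rules always return a sequence of XlaOps, one per output.
      return [xla.xops.Add(x, x)]

    xla.register_translation(double_p, _double_translation_rule)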

Since there are many JAX rules to migrate, and a number of translation rules also live in projects downstream of JAX, we leave backwards-compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations`, which appear to be the only tables used outside JAX itself.
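
Roughly, the shims wrap old-style rules on assignment. Continuing the hypothetical `double_p` above, an old-style rule assigned through `xla.translations` keeps working; the adapter converts it to the new signature under the hood (this assignment would simply replace the registration from the previous sketch):

    # Old-style rule: takes the XlaBuilder directly and returns a single XlaOp.
    def _double_translation_old_style(c, x, **params):
      return xla.xops.Add(x, x)

    # Still supported: the adapter wraps the rule into the new-style table.
    xla.translations[double_p] = _double_translation_old_style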

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks: the backend is now always canonicalized to a specific platform at the outermost `jit`, and an inner `jit` with an explicit `backend` that matches the current default choice no longer raises an error.
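
As a hedged example of the new behavior, assuming a machine whose default backend canonicalizes to "cpu":

    import jax

    inner = jax.jit(lambda x: x + 1, backend="cpu")
    outer = jax.jit(lambda x: inner(x) * 2)  # no explicit backend on the outer jit

    # Previously the inner, explicit backend="cpu" conflicted with the outer
    # jit's unspecified backend; now the outer jit canonicalizes to the default
    # platform first, so a matching inner specification is accepted. A genuinely
    # conflicting inner backend still raises an error.
    outer(1.0)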

PiperOrigin-RevId: 403607667
Peter Hawkins 2021-10-16 07:52:57 -07:00 committed by jax authors
parent 69d7a813e7
commit 2bd010ae88
23 changed files with 457 additions and 309 deletions


@ -281,20 +281,17 @@ def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
del args, prevent_cse, differentiated, policy # Unused.
return [v.aval for v in jaxpr.outvars]
def remat_translation(c, axis_env, name_stack, avals, backend, *in_nodes,
def remat_translation(ctx, avals_in, avals_out, *in_nodes,
jaxpr, prevent_cse, differentiated, policy):
del policy # Unused.
if differentiated and prevent_cse:
if backend == "gpu":
return xla._remat_using_while(
c, axis_env, in_nodes, name_stack, backend, "checkpoint", jaxpr)
if ctx.platform == "gpu":
return xla._remat_using_while(ctx, in_nodes, "checkpoint", jaxpr)
else:
return xla._remat_using_cond(
c, axis_env, in_nodes, name_stack, backend, "checkpoint", jaxpr)
return xla._remat_using_cond(ctx, in_nodes, "checkpoint", jaxpr)
else:
outs = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, (), "", *in_nodes)
return xla.xops.Tuple(c, outs)
xla.initial_style_translations[remat_p] = remat_translation
return xla.jaxpr_subcomp(ctx, jaxpr, (), *in_nodes)
xla.register_translation(remat_p, remat_translation)
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
assert not jaxpr.constvars


@ -822,9 +822,10 @@ def xla_computation(fun: Callable,
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
xla_args, donated_invars = xla._xla_callable_args(
c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
out_nodes = xla.jaxpr_subcomp(
c, jaxpr, backend, axis_env_, xla_consts,
extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
ctx = xla.TranslationContext(
c, backend, axis_env_,
extend_name_stack(wrap_name(fun_name, "xla_computation")))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
if out_parts is not None:
out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)


@ -366,8 +366,11 @@ def _custom_jvp_call_jaxpr_vmap(
return batched_outs, out_dims
batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl)
xla.register_translation(
custom_jvp_call_jaxpr_p,
xla.lower_fun(_custom_jvp_call_jaxpr_impl, new_style=True,
multiple_results=True),
initial_style=True)
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
@ -646,7 +649,7 @@ def _custom_vjp_call_jaxpr_jvp(
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
@ -686,8 +689,11 @@ def _custom_vjp_call_jaxpr_vmap(
return batched_outs, out_dims
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
xla.register_translation(
custom_vjp_call_jaxpr_p,
xla.lower_fun(_custom_vjp_call_jaxpr_impl, new_style=True,
multiple_results=True),
initial_style=True)
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
@ -1067,4 +1073,7 @@ linear_call_p.multiple_results = True
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
xla.initial_style_translations[linear_call_p] = xla.lower_fun_initial_style(_linear_call_impl)
xla.register_translation(linear_call_p,
xla.lower_fun(_linear_call_impl, new_style=True,
multiple_results=True),
initial_style=True)


@ -322,8 +322,9 @@ def while_loop(cond_fun: Callable[[T], bool],
def _while_loop_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)
def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
body_jaxpr, cond_nconsts, body_nconsts):
c = ctx.builder
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
batched = bool(cond_jaxpr.out_avals[0].shape)
@ -339,9 +340,11 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, cond_c), cond_jaxpr.consts),
extend_name_stack(name_stack, 'cond'), *(x + z))
cond_ctx = ctx.replace(builder=cond_c,
name_stack=extend_name_stack(ctx.name_stack, 'cond'))
pred, = xla.jaxpr_subcomp(
cond_ctx, cond_jaxpr.jaxpr,
_map(partial(xb.constant, cond_c), cond_jaxpr.consts), *(x + z))
if batched:
scalar = ShapedArray((), np.bool_)
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
@ -352,13 +355,18 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, body_c), body_jaxpr.consts),
extend_name_stack(name_stack, 'body'), *(y + z))
body_ctx = ctx.replace(builder=body_c,
name_stack=extend_name_stack(ctx.name_stack, 'body'))
new_z = xla.jaxpr_subcomp(
body_ctx, body_jaxpr.jaxpr,
_map(partial(xb.constant, body_c), body_jaxpr.consts),
*(y + z))
if batched:
body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, body_c), cond_jaxpr.consts),
extend_name_stack(name_stack, 'body_pred'), *(x + z))
body_pred_ctx = body_ctx.replace(
name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
body_pred, = xla.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr,
_map(partial(xb.constant, body_c), cond_jaxpr.consts), *(x + z))
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z) # no broadcast
new_carry = xops.Tuple(body_c, [*x, *y, *new_z])
@ -366,7 +374,7 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
ans = xops.While(cond_c.build(pred), body_c.build(new_carry), init_carry)
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
return xops.Tuple(c, z)
return z
def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
pred_shape = c.get_shape(pred).dimensions()
@ -594,7 +602,8 @@ while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
xla.register_translation(while_p, _while_loop_translation_rule,
initial_style=True)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented
@ -786,25 +795,29 @@ def _cond_with_per_branch_args(pred,
def _cond_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, kwargs["branches"][0].out_avals)
def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
index, *args, branches, linear):
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
linear):
del linear # Unused.
name_stack = extend_name_stack(ctx.name_stack, "cond")
def make_computation(name, jaxpr, op_shape):
c = xb.make_computation_builder(name + '_comp')
op = xb.parameter(c, 0, op_shape)
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, c), jaxpr.consts),
extend_name_stack(name_stack, name + '_fun'), *ops)
subctx = ctx.replace(
builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))
outs = xla.jaxpr_subcomp(subctx, jaxpr.jaxpr,
_map(partial(xb.constant, c), jaxpr.consts), *ops)
return c.build(xops.Tuple(c, outs))
c = ctx.builder
op = xops.Tuple(c, args)
op_shape = c.get_shape(op)
branch_computations = [
make_computation(f'branch_{i}', jaxpr, op_shape)
for i, jaxpr in enumerate(branches)]
return xops.Conditional(index, branch_computations, [op] * len(branches))
return xla.xla_destructure(
c, xops.Conditional(index, branch_computations, [op] * len(branches)))
def _select_tree(indices, branch_vals):
assert len(branch_vals) > 0
@ -1171,7 +1184,7 @@ ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented
@ -1923,7 +1936,9 @@ scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
multiple_results=True),
initial_style=True)
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
@ -2482,8 +2497,10 @@ linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[linear_solve_p] = \
xla.lower_fun_initial_style(_custom_linear_solve_impl)
xla.register_translation(
linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
multiple_results=True),
initial_style=True)
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented


@ -5552,7 +5552,8 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
subc = xla_bridge.make_computation_builder("reduction_computation")
assert len(consts) == 0, "Reduction computations cannot have constants"
args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
ctx = xla.TranslationContext(subc, None, axis_env, '')
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
if singleton:
return subc.build(out_nodes[0])
out_nodes = xops.Tuple(subc, out_nodes)


@ -643,9 +643,10 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
arg.dtype, named_shape=named_shape)
for arg, named_shape in zip(args, named_shapes)]
def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_groups,
axis_env, platform):
if axis_index_groups is not None and platform == "tpu":
def _allreduce_translation_rule(prim, pos_fn, ctx, avals_in, avals_out, *args,
axes, axis_index_groups):
c = ctx.builder
if axis_index_groups is not None and ctx.platform == "tpu":
len_0 = len(axis_index_groups[0])
if any(len(g) != len_0 for g in axis_index_groups):
raise ValueError("axis_index_groups must all be the same size")
@ -654,13 +655,14 @@ def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_group
axes_partition[isinstance(axis, int)].append(axis)
if positional_axes:
args = map(partial(xla.translations[pos_prim], c, axes=tuple(positional_axes)), args)
reducer = xla.lower_fun(pos_fn, multiple_results=False)
args = map(partial(reducer, c, axes=tuple(positional_axes)), args)
if not named_axes:
return xops.Tuple(c, args)
return args
def all_reduce(x):
replica_groups_protos = xc.make_replica_groups(
_replica_groups(axis_env, named_axes, axis_index_groups))
_replica_groups(ctx.axis_env, named_axes, axis_index_groups))
scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
computation = xla.primitive_subcomputation(prim, scalar, scalar)
return xops.AllReduce(x, computation, replica_groups_protos, None, None)
@ -673,7 +675,7 @@ def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_group
outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating)
else all_reduce(x) for x in args]
return xops.Tuple(c, outs)
return outs
def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
named_axes, pos_axes = axes_partition = [], []
@ -698,8 +700,9 @@ psum_p = core.AxisPrimitive('psum')
psum_p.multiple_results = True
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
psum_p.def_abstract_eval(_allreduce_abstract_eval)
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule,
lax.add_p, lax.reduce_sum_p) # type: ignore
xla.register_translation(
psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum),
is_collective=True)
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p)
@ -734,8 +737,9 @@ pmax_p = core.AxisPrimitive('pmax')
pmax_p.multiple_results = True
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule,
lax.max_p, lax.reduce_max_p) # type: ignore
xla.register_translation(
pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max),
is_collective=True)
pxla.multi_host_supported_collectives.add(pmax_p)
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
batching.axis_primitive_batchers[pmax_p] = \
@ -747,8 +751,9 @@ pmin_p = core.AxisPrimitive('pmin')
pmin_p.multiple_results = True
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule,
lax.min_p, lax.reduce_min_p) # type: ignore
xla.register_translation(
pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min),
is_collective=True)
pxla.multi_host_supported_collectives.add(pmin_p)
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
batching.axis_primitive_batchers[pmin_p] = \
@ -756,8 +761,8 @@ batching.axis_primitive_batchers[pmin_p] = \
core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes')
def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
replica_groups = _replica_groups(axis_env, axis_name, None)
def _ppermute_translation_rule(ctx, avals_in, avals_out, x, *, axis_name, perm):
replica_groups = _replica_groups(ctx.axis_env, axis_name, None)
group_size = len(replica_groups[0])
srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
@ -768,7 +773,7 @@ def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
for grp in replica_groups:
grp = list(sorted(grp))
full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
return xops.CollectivePermute(x, full_perm)
return [xops.CollectivePermute(x, full_perm)]
def _ppermute_transpose_rule(t, x, perm, axis_name):
srcs, dsts = unzip2(perm)
@ -798,7 +803,8 @@ def _collective_batcher(prim, args, dims, **params):
ppermute_p = core.AxisPrimitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
xla.register_translation(ppermute_p, _ppermute_translation_rule,
is_collective=True)
pxla.multi_host_supported_collectives.add(ppermute_p)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher
@ -839,19 +845,20 @@ def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis, axis_in
sliced = lax.dynamic_slice_in_dim(full, tile_base_idx, tile_size, split_axis + 1)
return _foldaxis(concat_axis, _moveaxis(0, concat_axis, sliced))
def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
axis_index_groups, axis_env, platform):
def _all_to_all_translation_rule(ctx, avals_in, avals_out, x, *, split_axis,
concat_axis, axis_name, axis_index_groups):
# Workaround for AllToAll not being implemented on CPU.
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
replica_groups = _replica_groups(ctx.axis_env, axis_name, axis_index_groups)
if len(replica_groups[0]) == 1:
return x
elif (platform == "tpu") or ((platform == "gpu") and (split_axis == 0) and
return [x]
elif (ctx.platform == "tpu") or ((ctx.platform == "gpu") and (split_axis == 0) and
(concat_axis == 0)):
split_count = len(replica_groups[0])
if not all(split_count == len(g) for g in replica_groups):
raise ValueError('Replica groups must be equally sized')
replica_groups_protos = xc.make_replica_groups(replica_groups)
return xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos)
return [xops.AllToAll(x, split_axis, concat_axis, split_count,
replica_groups_protos)]
else:
warnings.warn(
"all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "
@ -859,16 +866,13 @@ def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
"very slow and memory intensive algorithm, so expect significant slowdowns."
)
lowering = xla.lower_fun(
_all_to_all_via_all_gather, multiple_results=False, parallel=True)
_all_to_all_via_all_gather, multiple_results=False, new_style=True)
return lowering(
c,
x,
ctx, avals_in, avals_out, x,
axis_name=axis_name,
split_axis=split_axis,
concat_axis=concat_axis,
axis_name=axis_name,
axis_index_groups=axis_index_groups,
axis_env=axis_env,
platform=platform)
axis_index_groups=axis_index_groups)
def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
return (all_to_all(
@ -950,7 +954,8 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
all_to_all_p = core.AxisPrimitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
xla.register_translation(all_to_all_p, _all_to_all_translation_rule,
is_collective=True)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
pxla.multi_host_supported_collectives.add(all_to_all_p)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
@ -1048,22 +1053,29 @@ def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_group
def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
raise AssertionError("Unexpected call to _all_gather_impl")
def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled, axis_env, platform):
def _all_gather_translation_rule(
ctx, avals_in, avals_out, x, *, all_gather_dimension, axis_name,
axis_index_groups, axis_size, tiled):
# TODO(jekbradbury): enable for all_gather_dimension > 0
if platform == 'tpu' or platform == 'gpu' and all_gather_dimension == 0:
c = ctx.builder
if ctx.platform == 'tpu' or ctx.platform == 'gpu' and all_gather_dimension == 0:
if not tiled:
new_shape = list(c.get_shape(x).dimensions())
new_shape.insert(all_gather_dimension, 1)
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
x = xops.BroadcastInDim(x, new_shape, broadcast_dimensions)
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
return xops.AllGather(x, all_gather_dimension=all_gather_dimension, shard_count=axis_size,
replica_groups=xc.make_replica_groups(replica_groups))
replica_groups = _replica_groups(ctx.axis_env, axis_name, axis_index_groups)
return [
xops.AllGather(x, all_gather_dimension=all_gather_dimension,
shard_count=axis_size,
replica_groups=xc.make_replica_groups(replica_groups))]
else:
lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
return lowering(c, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled,
axis_env=axis_env, platform=platform)
lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False,
new_style=True)
return lowering(
ctx, avals_in, avals_out, x, all_gather_dimension=all_gather_dimension,
axis_name=axis_name, axis_index_groups=axis_index_groups,
axis_size=axis_size, tiled=tiled)
def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
if not isinstance(axis_name, (list, tuple)):
@ -1103,7 +1115,8 @@ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, ax
tiled=tiled)
return result, d
def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, all_gather_dimension, axis_name,
def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
all_gather_dimension, axis_name,
axis_index_groups, axis_size, tiled):
assert axis_index_groups is None, "axis_index_groups not supported in vmap"
assert axis_size == frame_size, "axis size doesn't match"
@ -1127,7 +1140,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
all_gather_p = core.AxisPrimitive('all_gather')
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
all_gather_p.def_impl(_all_gather_impl)
xla.parallel_translations[all_gather_p] = _all_gather_translation_rule
xla.register_translation(all_gather_p, _all_gather_translation_rule,
is_collective=True)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
pxla.multi_host_supported_collectives.add(all_gather_p)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
@ -1135,7 +1149,8 @@ batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name, axis_index_groups, axis_size, tiled):
def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
axis_index_groups, axis_size, tiled):
index = _index_in_group(axis_name, axis_index_groups)
scatter_dim_input_size = x.shape[scatter_dimension]
if tiled and scatter_dim_input_size % axis_size != 0:
@ -1159,13 +1174,14 @@ def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name, axi
return outs
def _reduce_scatter_translation_rule(prim, reducer, c, x, *, scatter_dimension,
axis_name, axis_index_groups, axis_size,
tiled, axis_env, platform):
if platform in ("tpu", "gpu"):
def _reduce_scatter_translation_rule(prim, reducer, ctx, avals_in, avals_out, x,
*, scatter_dimension, axis_name,
axis_index_groups, axis_size, tiled):
c = ctx.builder
if ctx.platform in ("tpu", "gpu"):
scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
computation = xla.primitive_subcomputation(prim, scalar, scalar)
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
replica_groups = _replica_groups(ctx.axis_env, axis_name, axis_index_groups)
x = xops.ReduceScatter(
x,
computation,
@ -1176,20 +1192,17 @@ def _reduce_scatter_translation_rule(prim, reducer, c, x, *, scatter_dimension,
new_shape = list(c.get_shape(x).dimensions())
del new_shape[scatter_dimension]
x = xops.Reshape(x, new_shape)
return x
return [x]
else:
return xla.lower_fun(
_reduce_scatter_via_reducer, multiple_results=False, parallel=True)(
c,
x,
_reduce_scatter_via_reducer, multiple_results=False, new_style=True)(
ctx, avals_in, avals_out, x,
reducer=reducer,
scatter_dimension=scatter_dimension,
axis_name=axis_name,
axis_index_groups=axis_index_groups,
axis_size=axis_size,
tiled=tiled,
axis_env=axis_env,
platform=platform)
tiled=tiled)
def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
@ -1222,8 +1235,10 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
xla.parallel_translations[reduce_scatter_p] = partial(
_reduce_scatter_translation_rule, lax.add_p, psum) # type: ignore
xla.register_translation(
reduce_scatter_p,
partial(_reduce_scatter_translation_rule, lax.add_p, psum),
is_collective=True)
pxla.multi_host_supported_collectives.add(reduce_scatter_p)
@ -1297,7 +1312,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t
return tree_util.tree_map(bind, x)
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
def _build_axis_index_lowering(c, axis_name, axis_env):
if isinstance(axis_name, tuple):
assert axis_name, 'empty axis name'
if len(axis_name) > 1:
@ -1312,12 +1327,17 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
def _axis_index_translation_rule(ctx, avals_in, avals_out, *, axis_name):
return [_build_axis_index_lowering(ctx.builder, axis_name, ctx.axis_env)]
def _axis_index_abstract_eval(*, axis_name):
frame = core.axis_frame(axis_name)
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
xla.register_translation(axis_index_p, _axis_index_translation_rule,
is_collective=True)
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
pxla.multi_host_supported_collectives.add(axis_index_p)
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
@ -1402,20 +1422,16 @@ def _pdot_vmap_batching_rule(vals_in, dims_in, *, axis_name, pos_contract,
return out, result_batch_dim
batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule
def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch, precision,
axis_env, platform):
local_out = lax._dot_general_translation_rule(
c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=precision,
preferred_element_type=None)
if axis_name:
out_tup = xla.parallel_translations[psum_p](
c, local_out, axes=axis_name, axis_index_groups=None,
axis_env=axis_env, platform=platform)
out, = xla.xla_destructure(c, out_tup)
else:
out = local_out
return out
xla.parallel_translations[pdot_p] = _pdot_translation_rule
def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
local_out = lax.dot_general(x, y, dimension_numbers=[pos_contract, pos_batch],
precision=precision, preferred_element_type=None)
return psum(local_out, axis_name) if axis_name is not None else local_out
xla.register_translation(
pdot_p,
xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True),
is_collective=True)
def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch, precision):
# TODO: avals with names, call pbroadcast with axis_name
@ -1456,11 +1472,12 @@ def _pgather_abstract_eval(src, idx, *, axes):
shape = idx.shape + tuple(shape)
return ShapedArray(shape, src.dtype)
def _pgather_parallel_translation(c, src, idx, *, axes, axis_env, platform):
def _pgather_parallel_translation(ctx, avals_in, avals_out, src, idx, *, axes):
if any(not isinstance(axis, int) for axis in axes):
raise NotImplementedError("pgather only supported in the SPMD lowering."
"Please open a feature request!")
return xla.lower_fun(_pgather_impl, multiple_results=False)(c, src, idx, axes=axes)
return xla.lower_fun(_pgather_impl, multiple_results=False, new_style=True)(
ctx, avals_in, avals_out, src, idx, axes=axes)
def _pgather_batcher(vals_in, dims_in, *, axes):
src, idx = vals_in
@ -1503,7 +1520,8 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
pgather_p = core.AxisPrimitive('pgather')
pgather_p.def_impl(_pgather_impl)
pgather_p.def_abstract_eval(_pgather_abstract_eval)
xla.parallel_translations[pgather_p] = _pgather_parallel_translation
xla.register_translation(pgather_p, _pgather_parallel_translation,
is_collective=True)
# TODO: Transpose? That requires adding pscatter...
batching.primitive_batchers[pgather_p] = _pgather_batcher
batching.axis_primitive_batchers[pgather_p] = _pgather_collective_batcher


@ -373,12 +373,12 @@ threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.translations_with_avals[threefry2x32_p] = xla.lower_fun(
xla.register_translation(threefry2x32_p, xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True, with_avals=True)
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
multiple_results=True, new_style=True))
xla.register_translation(threefry2x32_p, xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True)
multiple_results=True, new_style=True), platform='cpu')
if cuda_prng:
xla.backend_specific_translations['gpu'][threefry2x32_p] = \
_threefry2x32_gpu_translation_rule


@ -959,12 +959,12 @@ random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a: core.raise_to_shaped(a))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: tangent * _gamma_grad(ans, a))
xla.translations_with_avals[random_gamma_p] = xla.lower_fun(
xla.register_translation(random_gamma_p, xla.lower_fun(
partial(_gamma_impl, use_vmap=True),
multiple_results=False, with_avals=True)
xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
multiple_results=False, new_style=True))
xla.register_translation(random_gamma_p, xla.lower_fun(
partial(_gamma_impl, use_vmap=False),
multiple_results=False)
multiple_results=False, new_style=True), platform='cpu')
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
def gamma(key: KeyArray,


@ -801,8 +801,8 @@ def traceable_to_padded_translation(traceable):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
operands_ = it.chain.from_iterable([*dims.values(), *operands])
outs = xla.jaxpr_subcomp(c, jaxpr, None, xla.AxisEnv(1, (), ()),
xla._xla_consts(c, consts), '', *operands_)
ctx = xla.TranslationContext(c, None, xla.AxisEnv(1, (), ()), '')
outs = xla.jaxpr_subcomp(ctx, jaxpr, xla._xla_consts(c, consts), *operands_)
return xla._partition_outputs(
[aval_to_num_buffers(aval) for aval in out_avals], outs)
return translation


@ -42,6 +42,7 @@ import jax._src.random
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import partial_eval
from jax.interpreters import pxla
from jax.interpreters import sharded_jit
from jax.interpreters import xla
@ -940,9 +941,10 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
assert False, f"Encountered unexpected primitive {p}"
for unexpected in xla.call_translations: # Call primitives are inlined
if unexpected is pjit.pjit_p:
continue
# Call primitives are inlined
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p,
partial_eval.remat_call_p, sharded_jit.sharded_call_p,
maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.


@ -118,12 +118,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
"""Fail if there are JAX primitives that are not implemented."""
# Harvest primitives from XLA translation tables
all_primitives = (
set(xla.translations)
| set(xla.backend_specific_translations["cpu"])
| set(xla.backend_specific_translations["gpu"])
| set(xla.backend_specific_translations["tpu"])
| set(xla.initial_style_translations)
| set(xla.parallel_translations))
set(xla._translations)
| set(xla._backend_specific_translations["cpu"])
| set(xla._backend_specific_translations["gpu"])
| set(xla._backend_specific_translations["tpu"]))
tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(
jax.experimental.jax2tf.jax2tf.tf_impl_with_avals)


@ -46,7 +46,7 @@ from jax._src.lib import xla_client as xc
from .._src.util import (safe_map, safe_zip, HashableFunction,
as_hashable_function, unzip2, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name)
from .._src.lax.parallel import _axis_index_translation_rule
from .._src.lax.parallel import _build_axis_index_lowering
from .. import lax
class _PositionalSemantics(Enum):
@ -1347,9 +1347,10 @@ def _xmap_translation_rule_replica(c, axis_env,
# them!
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
tiled_outs = xla.jaxpr_subcomp(
c, vectorized_jaxpr, backend, axis_env, (),
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')), *tiled_ins)
ctx = xla.TranslationContext(
c, backend, axis_env,
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')))
tiled_outs = xla.jaxpr_subcomp(ctx, vectorized_jaxpr, (), *tiled_ins)
outs = [_xla_untile(c, axis_env, tiled_out, ans_out_axes, local_mesh_shape, backend)
if v.aval is not core.abstract_unit else tiled_out
@ -1363,8 +1364,8 @@ def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes):
linear_idxs = [zero] * len(tile_shape)
strides = [1] * len(tile_shape)
for name, axis in reversed(axes.items()):
axis_index = _axis_index_translation_rule(
c, axis_name=name, axis_env=axis_env, platform=None)
axis_index = _build_axis_index_lowering(
c, axis_name=name, axis_env=axis_env)
stride_c = xb.constant(c, np.array(strides[axis], np.int32))
if linear_idxs[axis] is zero and strides[axis] == 1:
linear_idxs[axis] = axis_index
@ -1464,10 +1465,11 @@ def _xmap_translation_rule_spmd(c, axis_env,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
global_out_nodes = xla.jaxpr_subcomp(
c, vectorized_jaxpr, backend, axis_env, (),
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')),
*sharded_global_in_nodes)
ctx = xla.TranslationContext(
c, backend, axis_env,
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')))
global_out_nodes = xla.jaxpr_subcomp(ctx, vectorized_jaxpr, (),
*sharded_global_in_nodes)
sharded_global_out_nodes = [
xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())


@ -476,9 +476,11 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
get_sharding_proto(c, n, axis_resources, mesh)))
# TODO: Think about how to avoid duplicating constants with the outer jaxpr
ctx = xla.TranslationContext(
subc, backend, axis_env,
extend_name_stack(name_stack, wrap_name(name, "pjit")))
out_nodes = xla.jaxpr_subcomp(
subc, jaxpr.jaxpr, backend, axis_env, xla._xla_consts(subc, jaxpr.consts),
extend_name_stack(name_stack, wrap_name(name, "pjit")), *args)
ctx, jaxpr.jaxpr, xla._xla_consts(subc, jaxpr.consts), *args)
out_nodes = [
xb.set_sharding_proto(subc, out,
get_sharding_proto(subc, out, axis_resources, mesh))


@ -370,7 +370,7 @@ class JVPTrace(Trace):
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
avals_out=avals_out)
out_avals=avals_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
@ -676,7 +676,7 @@ def _interleave(xs, ys):
custom_lin_p = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
custom_lin_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
custom_lin_p.multiple_results = True
def _raise_custom_vjp_error_on_jvp(*_, **__):
@ -684,9 +684,9 @@ def _raise_custom_vjp_error_on_jvp(*_, **__):
"function.")
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
res, _ = split_list(invals, [num_res])
cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
cts_in = bwd.call_wrapped(*res, *cts_out)
return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose


@ -883,8 +883,9 @@ def parallel_callable(fun: lu.WrappedFun,
partitions=arg_parts,
donated_invars=donated_invars)
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts,
extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
ctx = xla.TranslationContext(c, backend.platform, axis_env,
extend_name_stack(wrap_name(name, 'pmap')))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xops.Tuple, c, out_nodes)
if out_parts is not None:
out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)
@ -1297,9 +1298,10 @@ def _pmap_translation_rule(c, axis_env,
for aval, in_node, in_axis in safe_zip(in_avals, in_nodes, in_axes))
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
sharded_outs = xla.jaxpr_subcomp(
c, call_jaxpr, backend, new_env, (),
extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded)
ctx = xla.TranslationContext(
c, backend, new_env,
extend_name_stack(name_stack, wrap_name(name, 'pmap')))
sharded_outs = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *in_nodes_sharded)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
for aval, out_axis, shard in safe_zip(out_avals, out_axes, sharded_outs)]
@ -1620,9 +1622,10 @@ def lower_mesh_computation(
partitions_proto=partitions_proto,
donated_invars=donated_invars)
with core.extend_axis_env_nd(mesh.shape.items()):
out_nodes = xla.jaxpr_subcomp(
c, jaxpr, backend.platform, axis_env, xla_consts,
extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
ctx = xla.TranslationContext(
c, backend.platform, axis_env,
extend_name_stack(wrap_name(transformed_name, 'xmap')))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
if spmd_lowering:
out_partitions_t = xb.tuple_sharding_proto(out_partitions)
out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)


@ -87,10 +87,10 @@ def _sharded_callable(
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)
if xb.get_backend().platform not in ["tpu", "gpu"]:
platform = xb.get_backend().platform
if platform not in ["tpu", "gpu"]:
# TODO(skye): fall back to regular jit?
raise ValueError("sharded_jit not supported for " +
xb.get_backend().platform)
raise ValueError(f"sharded_jit not supported for {platform}")
nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
assert nparts is not None
@ -142,9 +142,9 @@ def _sharded_callable(
xla_consts = _map(partial(xb.constant, c), consts)
xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
axis_env = xla.AxisEnv(nrep, (), ())
out_nodes = xla.jaxpr_subcomp(
c, jaxpr, None, axis_env, xla_consts,
extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
ctx = xla.TranslationContext(
c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
built = c.Build(out_tuple)
@ -194,9 +194,10 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
arg = xb.parameter(subc, i, c.GetShape(n))
args.append(xb.set_sharding(subc, arg, sharding))
out_nodes = xla.jaxpr_subcomp(
subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, "sharded_jit")), *args)
ctx = xla.TranslationContext(
subc, backend, axis_env,
extend_name_stack(wrap_name(name, "sharded_jit")))
out_nodes = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *args)
out_parts = out_parts_thunk()
assert len(out_parts) == len(out_nodes)
out_nodes = [xb.set_sharding(subc, out, sharding)


@ -14,12 +14,15 @@
from collections import defaultdict, deque
import dataclasses
import functools
from functools import partial, partialmethod
import itertools as it
import operator as op
import re
from typing import (Any, Callable, Deque, Dict, List, Optional, Sequence, Set,
Type, Tuple, Union, NamedTuple)
from typing_extensions import Protocol
from warnings import warn
import weakref
@ -326,9 +329,9 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
def primitive_subcomputation(prim: core.Primitive, *avals: core.AbstractValue,
**params):
c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
f = lower_fun(prim.bind, prim.multiple_results, with_avals=True)
f = lower_fun(prim.bind, multiple_results=prim.multiple_results)
xla_args, _ = _xla_callable_args(c, avals, tuple_args=False)
ans = f(c, avals, xla_args, params)
ans = f(c, *xla_args, **params)
return c.build(ans)
def backend_compile(backend, built_c, options):
@ -378,15 +381,31 @@ def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
_partition_outputs([len(aval_to_xla_shapes(v.aval)) for v in vars],
nodes))
def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
if backend not in ('cpu', 'gpu', 'tpu'):
platform = xb.get_backend(backend).platform # canonicalize
else:
platform = backend
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
@dataclasses.dataclass
class TranslationContext:
builder: xc.XlaBuilder
# TODO(phawkins): make platform non-optional. We should always be translating
# with a specific platform in mind.
platform: Optional[str]
axis_env: AxisEnv
name_stack: str
def replace(self, **kw): return dataclasses.replace(self, **kw)
def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
consts: Sequence[XlaOp], *args: XlaOp) -> Sequence[XlaOp]:
# TODO(phawkins): make platform non-optional.
# assert ctx.platform is not None
def read(v):
if type(v) is Literal:
return xb.constant_general(c, canonicalize_dtype(v.val))
return xb.constant_general(ctx.builder, canonicalize_dtype(v.val))
else:
return env[v]
@ -400,52 +419,33 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
assert node is not None
env[v] = node
env = {}
_partitionmap(write, [core.unitvar], _make_unit_constant(c))
env: Dict[core.Var, Sequence[XlaOp]] = {}
_partitionmap(write, [core.unitvar], _make_unit_constant(ctx.builder))
_partitionmap(write, jaxpr.constvars, consts)
_partitionmap(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
op_metadata = make_op_metadata(
eqn.primitive, eqn.params, name_stack=name_stack,
eqn.primitive, eqn.params, name_stack=ctx.name_stack,
source_info=eqn.source_info)
c.set_op_metadata(op_metadata)
ctx.builder.set_op_metadata(op_metadata)
in_nodes = _flatmap(read, eqn.invars)
with source_info_util.user_context(eqn.source_info):
# TODO(jakevdp): migrate `translations` table to `translations_with_avals`
if eqn.primitive in backend_specific_translations[platform]:
rule = backend_specific_translations[platform][eqn.primitive]
ans = rule(c, *in_nodes, **eqn.params)
elif eqn.primitive in translations:
ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
elif eqn.primitive in translations_with_avals:
rule = translations_with_avals[eqn.primitive]
ans = rule(c, map(aval, eqn.invars), in_nodes, eqn.params)
elif eqn.primitive in initial_style_translations:
new_params = check_backend_params(eqn.params, backend)
rule = initial_style_translations[eqn.primitive]
ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
map(aval, eqn.invars), backend, *in_nodes, **new_params)
elif eqn.primitive in parallel_translations:
rule = parallel_translations[eqn.primitive]
ans = rule(c, *in_nodes, axis_env=axis_env, platform=platform, **eqn.params)
elif eqn.primitive in call_translations:
new_params = check_backend_params(eqn.params, backend)
rule = call_translations[eqn.primitive]
ans = rule(c, axis_env, in_nodes,
name_stack, backend=backend, **new_params)
else:
raise NotImplementedError(
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
assert isinstance(ans, xe.XlaOp)
c.get_shape(ans) # force xla to do shape error checking
if (eqn.primitive.multiple_results or
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in eqn.outvars)):
out_nodes = xla_destructure(c, ans)
if (ctx.platform is not None and
eqn.primitive in _backend_specific_translations[ctx.platform]):
rule = _backend_specific_translations[ctx.platform][eqn.primitive]
elif eqn.primitive in _translations:
rule = _translations[eqn.primitive]
else:
out_nodes = [ans]
c.clear_op_metadata()
_partitionmap(write, eqn.outvars, out_nodes)
raise NotImplementedError(
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
with source_info_util.user_context(eqn.source_info):
ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
*in_nodes, **eqn.params)
assert all(isinstance(x, xe.XlaOp) for x in ans), ans
map(ctx.builder.get_shape, ans) # force xla to do shape error checking
ctx.builder.clear_op_metadata()
_partitionmap(write, eqn.outvars, ans)
return _flatmap(read, jaxpr.outvars)
@ -453,23 +453,15 @@ def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
def check_backend_params(params, outer_backend):
def check_backend_matches(inner_backend, outer_backend):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
inner_backend = params.get('backend', None)
if inner_backend and inner_backend != outer_backend:
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")
return {k: params[k] for k in params if k != 'backend'}
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
def extend_axis_env(env: AxisEnv, name, size: int):
return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
@ -520,7 +512,7 @@ def eqn_replicas(eqn):
call_jaxpr = eqn.params.get("call_jaxpr")
if call_jaxpr:
return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
elif eqn.primitive in initial_style_translations:
elif eqn.primitive in _initial_style_primitives:
return initial_style_primitive_replicas(eqn.params)
else:
return 1
@ -545,7 +537,7 @@ def jaxpr_has_pmap(jaxpr):
def jaxpr_collectives(jaxpr):
"""Generates all the collective primitives anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if eqn.primitive in parallel_translations:
if eqn.primitive in _collective_primitives:
yield eqn.primitive
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_collectives(subjaxpr)
@ -656,8 +648,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
nreps = jaxpr_replicas(jaxpr)
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else (
xb.get_backend(backend) if backend is not None else None)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
@ -696,15 +687,15 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
xla_consts = _xla_consts(c, consts)
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args,
donated_invars=donated_invars)
out_nodes = jaxpr_subcomp(
c, jaxpr, backend.platform if backend is not None else None,
AxisEnv(nreps, (), ()), xla_consts,
extend_name_stack(wrap_name(name, 'jit')), *xla_args)
platform = backend.platform
ctx = TranslationContext(c, platform, AxisEnv(nreps, (), ()),
extend_name_stack(wrap_name(name, 'jit')))
out_nodes = jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
backend = xb.get_backend(backend)
# There is a non-zero cost to building an output tuple, particularly on TPU.
# Avoid it if the output arity is 1.
output = out_nodes[0] if len(out_nodes) == 1 else xops.Tuple(c, out_nodes)
if backend.platform in ("gpu", "tpu"):
if platform in ("gpu", "tpu"):
donated_invars = set_up_aliases(
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
if any(donated_invars):
@ -1029,15 +1020,20 @@ def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
def _xla_call_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
call_jaxpr, donated_invars, inline=None, device=None):
def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
backend=None, call_jaxpr, donated_invars,
inline=None, device=None):
del device, donated_invars, inline # Ignored.
c = ctx.builder
check_backend_matches(backend, ctx.platform)
subc = xb.make_computation_builder(f"jit_{name}")
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
sub_ctx = ctx.replace(
builder=subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
subc = subc.build(xops.Tuple(subc, out_nodes))
return xops.Call(c, subc, list(in_nodes))
return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
@ -1059,14 +1055,100 @@ pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
### translation tables
translations: Dict[core.Primitive, Callable] = {}
translations_with_avals: Dict[core.Primitive, Callable] = {}
parallel_translations: Dict[core.Primitive, Callable] = {}
initial_style_translations: Dict[core.Primitive, Callable] = {}
call_translations: Dict[core.Primitive, Callable] = {}
backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict)
MYPY = False
if not MYPY:
class TranslationRule(Protocol):
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw
) -> Sequence[XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any
call_translations[xla_call_p] = _xla_call_translation_rule
_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
_collective_primitives: Set[core.Primitive] = set()
_initial_style_primitives: Set[core.Primitive] = set()
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None,
is_collective: bool = False,
initial_style: bool = False) -> None:
ts = (_translations if platform is None
else _backend_specific_translations[platform])
ts[prim] = rule
if is_collective:
_collective_primitives.add(prim)
if initial_style:
_initial_style_primitives.add(prim)
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
def __init__(self, translations,
wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
self._translations = translations
self._wrap_fn = wrap_fn
def __setitem__(self, key: core.Primitive, value: Callable):
self._translations[key] = self._wrap_fn(key, value)
def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
def _wrap_old_call_translation(prim: core.Primitive,
f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
platform = kw.pop("backend", None)
check_backend_matches(platform, ctx.platform)
ans = f(ctx.builder, ctx.axis_env, args, ctx.name_stack,
backend=ctx.platform, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter(_translations, _wrap_old_translation)
class _BackendSpecificTranslationsAdapter(defaultdict):
def __missing__(self, key):
ret = self[key] = _TranslationRuleAdapter(
_backend_specific_translations[key], _wrap_old_translation)
return ret
backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
call_translations : _TranslationRuleAdapter
call_translations = _TranslationRuleAdapter(
_translations, _wrap_old_call_translation)
register_translation(xla_call_p, _xla_call_translation_rule)
def zeros_like_translation_rule(c, x):
shape = c.get_shape(x)
@ -1089,8 +1171,19 @@ def _tuple_output(*args, **kwargs):
ans = yield args, kwargs
yield (ans,)
def lower_fun(fun, multiple_results, parallel=False, with_avals=False, backend=None):
# TODO(jakevdp): migrate dependent code & always use the with_avals=True.
def lower_fun(fun: Callable, *, multiple_results: bool, parallel: bool = False,
backend=None, new_style: bool = False):
if new_style:
def f_new(ctx, avals_in, avals_out, *xla_args, **params):
wrapped_fun = lu.wrap_init(fun, params)
if not multiple_results:
wrapped_fun = _tuple_output(wrapped_fun)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts),
*xla_args)
return f_new
# TODO(phawkins): migrate dependent code & always use new_style=True.
def f(c, *xla_args, **params):
avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args]
return f_with_avals(c, avals, xla_args, params)
@ -1106,8 +1199,8 @@ def lower_fun(fun, multiple_results, parallel=False, with_avals=False, backend=N
wrapped_fun = _tuple_output(wrapped_fun)
with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
'', *xla_args)
ctx = TranslationContext(c, backend, axis_env, '')
outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
if (multiple_results or
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):
return xops.Tuple(c, outs)
@ -1115,7 +1208,7 @@ def lower_fun(fun, multiple_results, parallel=False, with_avals=False, backend=N
assert len(outs) == 1, outs
return outs[0]
return f_with_avals if with_avals else f
return f
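# Editor's illustrative sketch (not part of this change): the new-style path is
# meant to be combined with register_translation, e.g. for a hypothetical
# primitive with a traceable Python implementation `_my_prim_impl`:
#
#   register_translation(
#       my_prim_p,
#       lower_fun(_my_prim_impl, multiple_results=False, new_style=True))
#
# Under new_style=True, lower_fun traces the implementation to a jaxpr at the
# incoming avals and emits it with jaxpr_subcomp under the caller's ctx.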
def _array_aval_from_xla_shape(xla_shape):
# This function instantiates the assumption that we can map from XLA array
@ -1124,15 +1217,6 @@ def _array_aval_from_xla_shape(xla_shape):
assert not xla_shape.is_tuple()
return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
def lower_fun_initial_style(fun):
def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
name_stack, *xla_args)
return xops.Tuple(c, outs)
return f
### device-persistent data
class Token(object): pass
@ -1481,14 +1565,14 @@ def _zeros(c, xla_shape):
return xops.CreateToken(c)
def _remat_using_cond(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr):
def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
"""Lower remat to a Conditional which always returns true. This:
1. Circumvents common subexpression elimination.
2. In common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
occur after the primal blocks, because cotangent is an input to the
Conditional."""
# Fake condition which always selects True branch.
c = ctx.builder
rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
xb.constant(c, np.array(1, dtype=np.float32)),
xc.Shape.array_shape(xc.PrimitiveType.F32, []))
@ -1498,9 +1582,10 @@ def _remat_using_cond(
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
args = xla_destructure(remat_subc, input_op)
out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'remat')),
*args)
sub_ctx = ctx.replace(
builder=remat_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes]
remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))
@ -1510,25 +1595,27 @@ def _remat_using_cond(
out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))
return xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc)
return xla_destructure(
c, xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc))
def _remat_using_while(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr):
def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
"""Lower remat to a single iteration while loop."""
c = ctx.builder
# Dummy subc for getting subcomp shapes.
dummy_inputs = xops.Tuple(c, in_nodes)
dummy_subc = xb.make_computation_builder("remat_dummy_subcomputation")
dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
dummy_args = xla_destructure(dummy_subc, dummy_input_op)
dummy_subcomp_outs = jaxpr_subcomp(
dummy_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, "remat")), *dummy_args)
dummy_ctx = ctx.replace(
builder=dummy_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]
i_init = xb.constant(c, np.array(0, dtype=np.int32))
zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
inputs = xops.Tuple(c, [i_init] + in_nodes + zeros_like_outs)
inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)
cond_subc = xb.make_computation_builder("remat_cond_subcomputation")
input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
@ -1542,52 +1629,54 @@ def _remat_using_while(
input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32)))
subcomp_outs = jaxpr_subcomp(
body_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, "remat")), *args)
out_nodes = [i_next] + args + subcomp_outs
body_ctx = ctx.replace(
builder=body_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
subcomp_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (), *args)
out_nodes = [i_next] + args + list(subcomp_outs)
body_subc = body_subc.build(xops.Tuple(body_subc, out_nodes))
outs = xops.While(cond_subc, body_subc, inputs)
return xops.Tuple(c, xla_destructure(c, outs)[len(in_nodes)+1:])
return xla_destructure(c, outs)[len(in_nodes)+1:]
def _remat_translation_rule(c, axis_env, in_nodes,
name_stack, backend, name, call_jaxpr,
def _remat_translation_rule(ctx, avals_in, avals_out, *in_nodes,
name, call_jaxpr,
prevent_cse, differentiated, concrete,
policy, device=None):
del device, concrete, policy # Unused.
if differentiated and prevent_cse:
if backend == "gpu":
return _remat_using_while(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr)
if ctx.platform == "gpu":
return _remat_using_while(ctx, in_nodes, name, call_jaxpr)
else:
return _remat_using_cond(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr)
return _remat_using_cond(ctx, in_nodes, name, call_jaxpr)
else:
outs = jaxpr_subcomp(c, call_jaxpr, backend, axis_env, (), "", *in_nodes)
return xops.Tuple(c, outs)
return jaxpr_subcomp(ctx, call_jaxpr, (), *in_nodes)
call_translations[pe.remat_call_p] = _remat_translation_rule # type: ignore
register_translation(pe.remat_call_p, _remat_translation_rule)
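# Editor's illustrative note (not part of this change): the prevent_cse
# branches above are taken once a remat call has been differentiated, as in
#
#   jax.grad(jax.remat(f))(x)   # f a scalar-to-scalar function, for example
#
# where the rematerialized forward pass is wrapped in a single-iteration While
# on GPU and in a Conditional elsewhere, so that XLA cannot CSE it with the
# primal computation.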
ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
core.named_call_p)
def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *,
name="core_call", backend, call_jaxpr):
def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
name="core_call", backend=None, call_jaxpr):
check_backend_matches(backend, ctx.platform)
c = ctx.builder
subc = xb.make_computation_builder(name)
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, name), *args)
sub_ctx = ctx.replace(builder=subc,
name_stack=extend_name_stack(ctx.name_stack, name))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
subc = subc.Build(xops.Tuple(subc, out_nodes))
return xops.Call(c, subc, list(in_nodes))
call_translations[core.named_call_p] = _named_call_translation_rule
return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
register_translation(core.named_call_p, _named_call_translation_rule)
def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend,
def _call_translation_rule(ctx, avals_in, avals_out, *in_nodes, backend=None,
call_jaxpr):
return _named_call_translation_rule(
c, axis_env, in_nodes, name_stack, name="core_call",
backend=backend, call_jaxpr=call_jaxpr)
call_translations[core.call_p] = _call_translation_rule
ctx, avals_in, avals_out, *in_nodes, name="core_call", backend=backend,
call_jaxpr=call_jaxpr)
register_translation(core.call_p, _call_translation_rule)
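# Editor's illustrative note (not part of this change): the call rules above
# now return a flat list of per-output XlaOps (hence the xla_destructure around
# xops.Call) rather than a single tuple op, matching the Sequence[XlaOp] return
# convention of the new-style translation rules.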

View File

@ -40,6 +40,7 @@ setup(
'numpy>=1.18',
'opt_einsum',
'scipy>=1.2.1',
'typing_extensions',
],
extras_require={
# Minimum jaxlib version; used in testing.

View File

@ -182,7 +182,9 @@ def _identity_impl(mat):
def _identity_abstract_eval(mat):
return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
xla.translations_with_avals[identity_p] = xla.lower_fun(_identity_impl, multiple_results=False, with_avals=True)
xla.register_translation(
identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
new_style=True))
def split(x):
return split_p.bind(x)
@ -199,7 +201,8 @@ def _split_abstract_eval(mat):
m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
return m, m
xla.translations_with_avals[split_p] = xla.lower_fun(_split_impl, multiple_results=True, with_avals=True)
xla.register_translation(
split_p, xla.lower_fun(_split_impl, multiple_results=True, new_style=True))
def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)

View File

@ -164,7 +164,7 @@ def helper_set_hlo_dump():
def helper_print_optimized_hlo(fun, *args):
backend = xla_bridge.get_backend()
c = jax.xla_computation(fun)(*args)
c = jax.xla_computation(fun, backend='cpu')(*args)
print(re.sub(r", metadata.*", "",
backend.compile(c).hlo_modules()[0].to_string()))
@ -175,7 +175,7 @@ def helper_log_ir(name,
num_partitions=None,
strip_metadata=False):
print(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}")
jax_comp = jax.xla_computation(f_jax)(*args)
jax_comp = jax.xla_computation(f_jax, backend='cpu')(*args)
print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")
backend = xla_bridge.get_backend()
@ -418,7 +418,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
logging.info("%s: %s", self._testMethodName,
jax.make_jaxpr(func)(1))
logging.info("%s: %s", self._testMethodName,
jax.xla_computation(func)(1).as_hlo_text())
jax.xla_computation(func, backend='cpu')(1).as_hlo_text())
self.assertEqual(2, jax.jit(func)(1))
hcb.barrier_wait()

View File

@ -91,6 +91,9 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
if inner not in ('cpu', jtu.device_under_test(), None):
raise SkipTest("Backend is not CPU or the device under test")
if outer is None and inner == jtu.device_under_test():
raise SkipTest("(None, device) is allowed")
@partial(jax.jit, backend=outer)
def fun(x, y):
@partial(jax.jit, backend=inner)
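# Editor's illustrative note (not part of this change): the new skip suggests
# that an inner jit whose explicit backend matches the current default is now
# accepted rather than rejected, e.g. (placeholder sketch):
#
#   inner = jax.devices()[0].platform        # the default backend's platform
#   f = jax.jit(lambda x: jax.jit(lambda y: y * 2, backend=inner)(x))
#   f(1.0)                                   # no backend-mismatch error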

View File

@ -355,8 +355,7 @@ class PJitTest(jtu.BufferDonationTestCase):
f = xmap(lambda x: h(x * 2), in_axes=['i', ...], out_axes=['i', ...],
axis_resources={'i': 'y'})
x = jnp.arange(16).reshape((4, 4))
self.assertIn(pjit_p, xla.call_translations)
rule = xla.call_translations[pjit_p]
rule = xla._translations[pjit_p]
test_rule_called = False
def _test_rule(*args, **kwargs):
nonlocal test_rule_called
@ -366,11 +365,11 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIn(('y',), in_axis_resources[0].partitions)
return rule(*args, **kwargs)
try:
xla.call_translations[pjit_p] = _test_rule
xla._translations[pjit_p] = _test_rule
f(x)
self.assertTrue(test_rule_called)
finally:
xla.call_translations[pjit_p] = rule
xla._translations[pjit_p] = rule
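# Editor's illustrative note (not part of this change): the test now reads and
# patches xla._translations[pjit_p] directly, presumably because pjit's rule is
# registered via register_translation into _translations, while
# call_translations is only a backward-compatibility adapter over that dict.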
@jtu.with_mesh([('x', 2)])
def testLowerWithAbstractArgs(self):

View File

@ -249,10 +249,12 @@ class cuSparseTest(jtu.JaxTestCase):
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
if cuda_version is None or cuda_version < 11000:
self.assertFalse(cusparse and cusparse.is_supported)
self.assertNotIn(sparse.csr_todense_p, xla.backend_specific_translations["gpu"])
self.assertNotIn(sparse.csr_todense_p,
xla._backend_specific_translations["gpu"])
else:
self.assertTrue(cusparse and cusparse.is_supported)
self.assertIn(sparse.csr_todense_p, xla.backend_specific_translations["gpu"])
self.assertIn(sparse.csr_todense_p,
xla._backend_specific_translations["gpu"])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(