Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 04:46:06 +00:00
Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:

    rule(ctx, avals_in, avals_out, *args, **params)

where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations`, which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
This commit is contained in:
parent 69d7a813e7
commit 2bd010ae88
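For orientation before the diff itself, here is a minimal sketch of what a rule written against the new signature looks like. It is not part of this commit: my_prim_p, the exponent parameter, and the lowering body are hypothetical; the ctx fields (builder, platform, axis_env, name_stack) and the list-of-XlaOps return convention are taken from the TranslationContext definition and dispatcher further down in this diff.

    from jax import core
    from jax.interpreters import xla

    my_prim_p = core.Primitive('my_prim')   # hypothetical primitive

    def _my_prim_translation(ctx, avals_in, avals_out, x, *, exponent):
      # ctx bundles what the old rule signatures passed separately:
      # ctx.builder (the XlaBuilder), ctx.platform, ctx.axis_env, ctx.name_stack.
      del avals_in, avals_out  # unused in this simple lowering
      out = x
      for _ in range(exponent - 1):
        out = xla.xops.Mul(out, x)
      return [out]  # new-style rules return a sequence of XlaOps, one per output

    xla.register_translation(my_prim_p, _my_prim_translation)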
@@ -281,20 +281,17 @@ def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
   del args, prevent_cse, differentiated, policy  # Unused.
   return [v.aval for v in jaxpr.outvars]
 
-def remat_translation(c, axis_env, name_stack, avals, backend, *in_nodes,
+def remat_translation(ctx, avals_in, avals_out, *in_nodes,
                       jaxpr, prevent_cse, differentiated, policy):
   del policy  # Unused.
   if differentiated and prevent_cse:
-    if backend == "gpu":
-      return xla._remat_using_while(
-          c, axis_env, in_nodes, name_stack, backend, "checkpoint", jaxpr)
+    if ctx.platform == "gpu":
+      return xla._remat_using_while(ctx, in_nodes, "checkpoint", jaxpr)
     else:
-      return xla._remat_using_cond(
-          c, axis_env, in_nodes, name_stack, backend, "checkpoint", jaxpr)
+      return xla._remat_using_cond(ctx, in_nodes, "checkpoint", jaxpr)
   else:
-    outs = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, (), "", *in_nodes)
-    return xla.xops.Tuple(c, outs)
-xla.initial_style_translations[remat_p] = remat_translation
+    return xla.jaxpr_subcomp(ctx, jaxpr, (), *in_nodes)
+xla.register_translation(remat_p, remat_translation)
 
 def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
   assert not jaxpr.constvars
@@ -822,9 +822,10 @@ def xla_computation(fun: Callable,
     should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
     xla_args, donated_invars = xla._xla_callable_args(
         c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
-    out_nodes = xla.jaxpr_subcomp(
-        c, jaxpr, backend, axis_env_, xla_consts,
-        extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
+    ctx = xla.TranslationContext(
+        c, backend, axis_env_,
+        extend_name_stack(wrap_name(fun_name, "xla_computation")))
+    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
     build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
     if out_parts is not None:
       out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)
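The pattern in this hunk (construct a TranslationContext once, then hand it to jaxpr_subcomp) recurs throughout the change. A self-contained sketch of that call sequence, assembled from helpers that appear elsewhere in this diff (xla.xb, xla._xla_callable_args, xla._xla_consts, xla.AxisEnv); it is illustrative only, not code from the commit:

    import numpy as np
    from jax import make_jaxpr
    from jax.interpreters import xla

    # Lower a tiny jaxpr by hand with the new context-based API.
    closed = make_jaxpr(lambda x: x * 2.0)(np.float32(1.0))
    c = xla.xb.make_computation_builder("example")
    xla_args, _ = xla._xla_callable_args(c, closed.in_avals, tuple_args=False)
    ctx = xla.TranslationContext(c, 'cpu', xla.AxisEnv(1, (), ()), 'example')
    out_nodes = xla.jaxpr_subcomp(ctx, closed.jaxpr,
                                  xla._xla_consts(c, closed.consts), *xla_args)
    built = c.build(xla.xops.Tuple(c, out_nodes))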
@@ -366,8 +366,11 @@ def _custom_jvp_call_jaxpr_vmap(
   return batched_outs, out_dims
 batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
 
-xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
-    xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl)
+xla.register_translation(
+    custom_jvp_call_jaxpr_p,
+    xla.lower_fun(_custom_jvp_call_jaxpr_impl, new_style=True,
+                  multiple_results=True),
+    initial_style=True)
 
 # If a (multi)linear function is defined with a custom jvp, then
 # custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
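Primitives that were lowered by tracing a Python implementation (the old lower_fun_initial_style and translations_with_avals paths) now funnel through xla.lower_fun with new_style=True, as above. A hedged sketch of that pattern for a hypothetical primitive; whether initial_style=True is needed depends on the primitive, and here it simply mirrors the registration in this hunk:

    import jax.numpy as jnp
    from jax import core
    from jax.interpreters import xla

    my_norm_p = core.Primitive('my_norm')   # hypothetical primitive
    my_norm_p.multiple_results = True

    def _my_norm_impl(x):
      # Ordinary JAX code; lower_fun traces it and lowers the resulting jaxpr.
      return [x / jnp.maximum(jnp.max(jnp.abs(x)), 1e-6)]

    xla.register_translation(
        my_norm_p,
        xla.lower_fun(_my_norm_impl, new_style=True, multiple_results=True),
        initial_style=True)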
@ -646,7 +649,7 @@ def _custom_vjp_call_jaxpr_jvp(
|
||||
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
|
||||
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
|
||||
tangents_out = ad.custom_lin_p.bind(
|
||||
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
|
||||
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out)
|
||||
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
|
||||
return primals_out, tangents_out
|
||||
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
|
||||
@ -686,8 +689,11 @@ def _custom_vjp_call_jaxpr_vmap(
|
||||
return batched_outs, out_dims
|
||||
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
|
||||
|
||||
xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
|
||||
xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
|
||||
xla.register_translation(
|
||||
custom_vjp_call_jaxpr_p,
|
||||
xla.lower_fun(_custom_vjp_call_jaxpr_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
|
||||
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
@ -1067,4 +1073,7 @@ linear_call_p.multiple_results = True
|
||||
linear_call_p.def_impl(_linear_call_impl)
|
||||
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
|
||||
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
|
||||
xla.initial_style_translations[linear_call_p] = xla.lower_fun_initial_style(_linear_call_impl)
|
||||
xla.register_translation(linear_call_p,
|
||||
xla.lower_fun(_linear_call_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
|
@ -322,8 +322,9 @@ def while_loop(cond_fun: Callable[[T], bool],
|
||||
def _while_loop_abstract_eval(*args, **kwargs):
|
||||
return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)
|
||||
|
||||
def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
|
||||
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
|
||||
body_jaxpr, cond_nconsts, body_nconsts):
|
||||
c = ctx.builder
|
||||
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
|
||||
batched = bool(cond_jaxpr.out_avals[0].shape)
|
||||
|
||||
@ -339,9 +340,11 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
|
||||
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
|
||||
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
|
||||
pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, cond_c), cond_jaxpr.consts),
|
||||
extend_name_stack(name_stack, 'cond'), *(x + z))
|
||||
cond_ctx = ctx.replace(builder=cond_c,
|
||||
name_stack=extend_name_stack(ctx.name_stack, 'cond'))
|
||||
pred, = xla.jaxpr_subcomp(
|
||||
cond_ctx, cond_jaxpr.jaxpr,
|
||||
_map(partial(xb.constant, cond_c), cond_jaxpr.consts), *(x + z))
|
||||
if batched:
|
||||
scalar = ShapedArray((), np.bool_)
|
||||
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
|
||||
@ -352,13 +355,18 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
|
||||
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
|
||||
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
|
||||
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, body_c), body_jaxpr.consts),
|
||||
extend_name_stack(name_stack, 'body'), *(y + z))
|
||||
body_ctx = ctx.replace(builder=body_c,
|
||||
name_stack=extend_name_stack(ctx.name_stack, 'body'))
|
||||
new_z = xla.jaxpr_subcomp(
|
||||
body_ctx, body_jaxpr.jaxpr,
|
||||
_map(partial(xb.constant, body_c), body_jaxpr.consts),
|
||||
*(y + z))
|
||||
if batched:
|
||||
body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, body_c), cond_jaxpr.consts),
|
||||
extend_name_stack(name_stack, 'body_pred'), *(x + z))
|
||||
body_pred_ctx = body_ctx.replace(
|
||||
name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
|
||||
body_pred, = xla.jaxpr_subcomp(
|
||||
body_pred_ctx, cond_jaxpr.jaxpr,
|
||||
_map(partial(xb.constant, body_c), cond_jaxpr.consts), *(x + z))
|
||||
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
|
||||
assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z) # no broadcast
|
||||
new_carry = xops.Tuple(body_c, [*x, *y, *new_z])
|
||||
@ -366,7 +374,7 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
ans = xops.While(cond_c.build(pred), body_c.build(new_carry), init_carry)
|
||||
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
|
||||
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
|
||||
return xops.Tuple(c, z)
|
||||
return z
|
||||
|
||||
def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
|
||||
pred_shape = c.get_shape(pred).dimensions()
|
||||
@ -594,7 +602,8 @@ while_p.def_impl(partial(xla.apply_primitive, while_p))
|
||||
while_p.def_abstract_eval(_while_loop_abstract_eval)
|
||||
ad.primitive_jvps[while_p] = _while_loop_jvp
|
||||
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
xla.initial_style_translations[while_p] = _while_loop_translation_rule
|
||||
xla.register_translation(while_p, _while_loop_translation_rule,
|
||||
initial_style=True)
|
||||
ad.primitive_transposes[while_p] = _while_transpose_error
|
||||
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
pe.partial_eval_jaxpr_custom_rules[while_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented
|
||||
@ -786,25 +795,29 @@ def _cond_with_per_branch_args(pred,
|
||||
def _cond_abstract_eval(*args, **kwargs):
|
||||
return _map(raise_to_shaped, kwargs["branches"][0].out_avals)
|
||||
|
||||
def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
|
||||
index, *args, branches, linear):
|
||||
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
|
||||
linear):
|
||||
del linear # Unused.
|
||||
|
||||
name_stack = extend_name_stack(ctx.name_stack, "cond")
|
||||
def make_computation(name, jaxpr, op_shape):
|
||||
c = xb.make_computation_builder(name + '_comp')
|
||||
op = xb.parameter(c, 0, op_shape)
|
||||
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
|
||||
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, c), jaxpr.consts),
|
||||
extend_name_stack(name_stack, name + '_fun'), *ops)
|
||||
subctx = ctx.replace(
|
||||
builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))
|
||||
outs = xla.jaxpr_subcomp(subctx, jaxpr.jaxpr,
|
||||
_map(partial(xb.constant, c), jaxpr.consts), *ops)
|
||||
return c.build(xops.Tuple(c, outs))
|
||||
|
||||
c = ctx.builder
|
||||
op = xops.Tuple(c, args)
|
||||
op_shape = c.get_shape(op)
|
||||
branch_computations = [
|
||||
make_computation(f'branch_{i}', jaxpr, op_shape)
|
||||
for i, jaxpr in enumerate(branches)]
|
||||
return xops.Conditional(index, branch_computations, [op] * len(branches))
|
||||
return xla.xla_destructure(
|
||||
c, xops.Conditional(index, branch_computations, [op] * len(branches)))
|
||||
|
||||
def _select_tree(indices, branch_vals):
|
||||
assert len(branch_vals) > 0
|
||||
@ -1171,7 +1184,7 @@ ad.primitive_jvps[cond_p] = _cond_jvp
|
||||
ad.reducing_transposes[cond_p] = _cond_transpose
|
||||
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
|
||||
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
|
||||
xla.initial_style_translations[cond_p] = _cond_translation_rule
|
||||
xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
|
||||
core.custom_typechecks[cond_p] = _cond_typecheck
|
||||
pe.partial_eval_jaxpr_custom_rules[cond_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented
|
||||
|
||||
@ -1923,7 +1936,9 @@ scan_p.def_abstract_eval(_scan_abstract_eval)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.reducing_transposes[scan_p] = _scan_transpose
|
||||
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
|
||||
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
|
||||
xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
|
||||
masking.masking_rules[scan_p] = _scan_masking_rule
|
||||
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
|
||||
@ -2482,8 +2497,10 @@ linear_solve_p.multiple_results = True
|
||||
linear_solve_p.def_impl(_custom_linear_solve_impl)
|
||||
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
|
||||
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
|
||||
xla.initial_style_translations[linear_solve_p] = \
|
||||
xla.lower_fun_initial_style(_custom_linear_solve_impl)
|
||||
xla.register_translation(
|
||||
linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
|
||||
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
|
||||
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = pe.partial_eval_jaxpr_custom_rule_not_implemented
|
||||
|
@ -5552,7 +5552,8 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
|
||||
subc = xla_bridge.make_computation_builder("reduction_computation")
|
||||
assert len(consts) == 0, "Reduction computations cannot have constants"
|
||||
args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
|
||||
out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
|
||||
ctx = xla.TranslationContext(subc, None, axis_env, '')
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
|
||||
if singleton:
|
||||
return subc.build(out_nodes[0])
|
||||
out_nodes = xops.Tuple(subc, out_nodes)
|
||||
|
@ -643,9 +643,10 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
|
||||
arg.dtype, named_shape=named_shape)
|
||||
for arg, named_shape in zip(args, named_shapes)]
|
||||
|
||||
def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_groups,
|
||||
axis_env, platform):
|
||||
if axis_index_groups is not None and platform == "tpu":
|
||||
def _allreduce_translation_rule(prim, pos_fn, ctx, avals_in, avals_out, *args,
|
||||
axes, axis_index_groups):
|
||||
c = ctx.builder
|
||||
if axis_index_groups is not None and ctx.platform == "tpu":
|
||||
len_0 = len(axis_index_groups[0])
|
||||
if any(len(g) != len_0 for g in axis_index_groups):
|
||||
raise ValueError("axis_index_groups must all be the same size")
|
||||
@ -654,13 +655,14 @@ def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_group
|
||||
axes_partition[isinstance(axis, int)].append(axis)
|
||||
|
||||
if positional_axes:
|
||||
args = map(partial(xla.translations[pos_prim], c, axes=tuple(positional_axes)), args)
|
||||
reducer = xla.lower_fun(pos_fn, multiple_results=False)
|
||||
args = map(partial(reducer, c, axes=tuple(positional_axes)), args)
|
||||
if not named_axes:
|
||||
return xops.Tuple(c, args)
|
||||
return args
|
||||
|
||||
def all_reduce(x):
|
||||
replica_groups_protos = xc.make_replica_groups(
|
||||
_replica_groups(axis_env, named_axes, axis_index_groups))
|
||||
_replica_groups(ctx.axis_env, named_axes, axis_index_groups))
|
||||
scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
|
||||
computation = xla.primitive_subcomputation(prim, scalar, scalar)
|
||||
return xops.AllReduce(x, computation, replica_groups_protos, None, None)
|
||||
@ -673,7 +675,7 @@ def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_group
|
||||
outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
|
||||
if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating)
|
||||
else all_reduce(x) for x in args]
|
||||
return xops.Tuple(c, outs)
|
||||
return outs
|
||||
|
||||
def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
|
||||
named_axes, pos_axes = axes_partition = [], []
|
||||
@@ -698,8 +700,9 @@ psum_p = core.AxisPrimitive('psum')
 psum_p.multiple_results = True
 psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
 psum_p.def_abstract_eval(_allreduce_abstract_eval)
-xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule,
-                                            lax.add_p, lax.reduce_sum_p)  # type: ignore
+xla.register_translation(
+    psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum),
+    is_collective=True)
 ad.deflinear2(psum_p, _psum_transpose_rule)
 pxla.multi_host_supported_collectives.add(psum_p)
 batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p)
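Collective rules (previously stored in xla.parallel_translations) use the same new signature and are registered with is_collective=True; instead of receiving axis_env and platform as keyword arguments they read them from ctx. A simplified sketch of the shape such a rule now takes, loosely modeled on the ppermute rule later in this diff; my_shift_p, the rule name, and the trimmed-down permutation logic are illustrative, while _replica_groups and xops are the helpers this module already uses:

    def _my_shift_translation(ctx, avals_in, avals_out, x, *, axis_name, perm):
      # Replica groups come from ctx.axis_env rather than an axis_env kwarg.
      replica_groups = _replica_groups(ctx.axis_env, axis_name, None)
      group_size = len(replica_groups[0])
      full_perm = [(src % group_size, dst % group_size) for src, dst in perm]
      return [xops.CollectivePermute(x, full_perm)]

    xla.register_translation(my_shift_p, _my_shift_translation,
                             is_collective=True)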
@ -734,8 +737,9 @@ pmax_p = core.AxisPrimitive('pmax')
|
||||
pmax_p.multiple_results = True
|
||||
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
|
||||
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule,
|
||||
lax.max_p, lax.reduce_max_p) # type: ignore
|
||||
xla.register_translation(
|
||||
pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max),
|
||||
is_collective=True)
|
||||
pxla.multi_host_supported_collectives.add(pmax_p)
|
||||
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
|
||||
batching.axis_primitive_batchers[pmax_p] = \
|
||||
@ -747,8 +751,9 @@ pmin_p = core.AxisPrimitive('pmin')
|
||||
pmin_p.multiple_results = True
|
||||
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
|
||||
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule,
|
||||
lax.min_p, lax.reduce_min_p) # type: ignore
|
||||
xla.register_translation(
|
||||
pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min),
|
||||
is_collective=True)
|
||||
pxla.multi_host_supported_collectives.add(pmin_p)
|
||||
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
|
||||
batching.axis_primitive_batchers[pmin_p] = \
|
||||
@ -756,8 +761,8 @@ batching.axis_primitive_batchers[pmin_p] = \
|
||||
core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes')
|
||||
|
||||
|
||||
def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
|
||||
replica_groups = _replica_groups(axis_env, axis_name, None)
|
||||
def _ppermute_translation_rule(ctx, avals_in, avals_out, x, *, axis_name, perm):
|
||||
replica_groups = _replica_groups(ctx.axis_env, axis_name, None)
|
||||
group_size = len(replica_groups[0])
|
||||
srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
|
||||
if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
|
||||
@ -768,7 +773,7 @@ def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
|
||||
for grp in replica_groups:
|
||||
grp = list(sorted(grp))
|
||||
full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
|
||||
return xops.CollectivePermute(x, full_perm)
|
||||
return [xops.CollectivePermute(x, full_perm)]
|
||||
|
||||
def _ppermute_transpose_rule(t, x, perm, axis_name):
|
||||
srcs, dsts = unzip2(perm)
|
||||
@ -798,7 +803,8 @@ def _collective_batcher(prim, args, dims, **params):
|
||||
ppermute_p = core.AxisPrimitive('ppermute')
|
||||
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
|
||||
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
|
||||
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
|
||||
xla.register_translation(ppermute_p, _ppermute_translation_rule,
|
||||
is_collective=True)
|
||||
pxla.multi_host_supported_collectives.add(ppermute_p)
|
||||
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
|
||||
batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher
|
||||
@ -839,19 +845,20 @@ def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis, axis_in
|
||||
sliced = lax.dynamic_slice_in_dim(full, tile_base_idx, tile_size, split_axis + 1)
|
||||
return _foldaxis(concat_axis, _moveaxis(0, concat_axis, sliced))
|
||||
|
||||
def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
|
||||
axis_index_groups, axis_env, platform):
|
||||
def _all_to_all_translation_rule(ctx, avals_in, avals_out, x, *, split_axis,
|
||||
concat_axis, axis_name, axis_index_groups):
|
||||
# Workaround for AllToAll not being implemented on CPU.
|
||||
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
|
||||
replica_groups = _replica_groups(ctx.axis_env, axis_name, axis_index_groups)
|
||||
if len(replica_groups[0]) == 1:
|
||||
return x
|
||||
elif (platform == "tpu") or ((platform == "gpu") and (split_axis == 0) and
|
||||
return [x]
|
||||
elif (ctx.platform == "tpu") or ((ctx.platform == "gpu") and (split_axis == 0) and
|
||||
(concat_axis == 0)):
|
||||
split_count = len(replica_groups[0])
|
||||
if not all(split_count == len(g) for g in replica_groups):
|
||||
raise ValueError('Replica groups must be equally sized')
|
||||
replica_groups_protos = xc.make_replica_groups(replica_groups)
|
||||
return xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos)
|
||||
return [xops.AllToAll(x, split_axis, concat_axis, split_count,
|
||||
replica_groups_protos)]
|
||||
else:
|
||||
warnings.warn(
|
||||
"all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "
|
||||
@ -859,16 +866,13 @@ def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
|
||||
"very slow and memory intensive algorithm, so expect significant slowdowns."
|
||||
)
|
||||
lowering = xla.lower_fun(
|
||||
_all_to_all_via_all_gather, multiple_results=False, parallel=True)
|
||||
_all_to_all_via_all_gather, multiple_results=False, new_style=True)
|
||||
return lowering(
|
||||
c,
|
||||
x,
|
||||
ctx, avals_in, avals_out, x,
|
||||
axis_name=axis_name,
|
||||
split_axis=split_axis,
|
||||
concat_axis=concat_axis,
|
||||
axis_name=axis_name,
|
||||
axis_index_groups=axis_index_groups,
|
||||
axis_env=axis_env,
|
||||
platform=platform)
|
||||
axis_index_groups=axis_index_groups)
|
||||
|
||||
def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
|
||||
return (all_to_all(
|
||||
@ -950,7 +954,8 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
|
||||
|
||||
all_to_all_p = core.AxisPrimitive('all_to_all')
|
||||
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
|
||||
xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
|
||||
xla.register_translation(all_to_all_p, _all_to_all_translation_rule,
|
||||
is_collective=True)
|
||||
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(all_to_all_p)
|
||||
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
|
||||
@ -1048,22 +1053,29 @@ def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_group
|
||||
def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
||||
raise AssertionError("Unexpected call to _all_gather_impl")
|
||||
|
||||
def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled, axis_env, platform):
|
||||
def _all_gather_translation_rule(
|
||||
ctx, avals_in, avals_out, x, *, all_gather_dimension, axis_name,
|
||||
axis_index_groups, axis_size, tiled):
|
||||
# TODO(jekbradbury): enable for all_gather_dimension > 0
|
||||
if platform == 'tpu' or platform == 'gpu' and all_gather_dimension == 0:
|
||||
c = ctx.builder
|
||||
if ctx.platform == 'tpu' or ctx.platform == 'gpu' and all_gather_dimension == 0:
|
||||
if not tiled:
|
||||
new_shape = list(c.get_shape(x).dimensions())
|
||||
new_shape.insert(all_gather_dimension, 1)
|
||||
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
|
||||
x = xops.BroadcastInDim(x, new_shape, broadcast_dimensions)
|
||||
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
|
||||
return xops.AllGather(x, all_gather_dimension=all_gather_dimension, shard_count=axis_size,
|
||||
replica_groups=xc.make_replica_groups(replica_groups))
|
||||
replica_groups = _replica_groups(ctx.axis_env, axis_name, axis_index_groups)
|
||||
return [
|
||||
xops.AllGather(x, all_gather_dimension=all_gather_dimension,
|
||||
shard_count=axis_size,
|
||||
replica_groups=xc.make_replica_groups(replica_groups))]
|
||||
else:
|
||||
lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
|
||||
return lowering(c, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
|
||||
axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled,
|
||||
axis_env=axis_env, platform=platform)
|
||||
lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False,
|
||||
new_style=True)
|
||||
return lowering(
|
||||
ctx, avals_in, avals_out, x, all_gather_dimension=all_gather_dimension,
|
||||
axis_name=axis_name, axis_index_groups=axis_index_groups,
|
||||
axis_size=axis_size, tiled=tiled)
|
||||
|
||||
def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
||||
if not isinstance(axis_name, (list, tuple)):
|
||||
@ -1103,7 +1115,8 @@ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, ax
|
||||
tiled=tiled)
|
||||
return result, d
|
||||
|
||||
def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, all_gather_dimension, axis_name,
|
||||
def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
|
||||
all_gather_dimension, axis_name,
|
||||
axis_index_groups, axis_size, tiled):
|
||||
assert axis_index_groups is None, "axis_index_groups not supported in vmap"
|
||||
assert axis_size == frame_size, "axis size doesn't match"
|
||||
@ -1127,7 +1140,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
|
||||
all_gather_p = core.AxisPrimitive('all_gather')
|
||||
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
|
||||
all_gather_p.def_impl(_all_gather_impl)
|
||||
xla.parallel_translations[all_gather_p] = _all_gather_translation_rule
|
||||
xla.register_translation(all_gather_p, _all_gather_translation_rule,
|
||||
is_collective=True)
|
||||
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(all_gather_p)
|
||||
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
|
||||
@ -1135,7 +1149,8 @@ batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
|
||||
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
|
||||
def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
||||
def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
|
||||
axis_index_groups, axis_size, tiled):
|
||||
index = _index_in_group(axis_name, axis_index_groups)
|
||||
scatter_dim_input_size = x.shape[scatter_dimension]
|
||||
if tiled and scatter_dim_input_size % axis_size != 0:
|
||||
@ -1159,13 +1174,14 @@ def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name, axi
|
||||
return outs
|
||||
|
||||
|
||||
def _reduce_scatter_translation_rule(prim, reducer, c, x, *, scatter_dimension,
|
||||
axis_name, axis_index_groups, axis_size,
|
||||
tiled, axis_env, platform):
|
||||
if platform in ("tpu", "gpu"):
|
||||
def _reduce_scatter_translation_rule(prim, reducer, ctx, avals_in, avals_out, x,
|
||||
*, scatter_dimension, axis_name,
|
||||
axis_index_groups, axis_size, tiled):
|
||||
c = ctx.builder
|
||||
if ctx.platform in ("tpu", "gpu"):
|
||||
scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
|
||||
computation = xla.primitive_subcomputation(prim, scalar, scalar)
|
||||
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
|
||||
replica_groups = _replica_groups(ctx.axis_env, axis_name, axis_index_groups)
|
||||
x = xops.ReduceScatter(
|
||||
x,
|
||||
computation,
|
||||
@ -1176,20 +1192,17 @@ def _reduce_scatter_translation_rule(prim, reducer, c, x, *, scatter_dimension,
|
||||
new_shape = list(c.get_shape(x).dimensions())
|
||||
del new_shape[scatter_dimension]
|
||||
x = xops.Reshape(x, new_shape)
|
||||
return x
|
||||
return [x]
|
||||
else:
|
||||
return xla.lower_fun(
|
||||
_reduce_scatter_via_reducer, multiple_results=False, parallel=True)(
|
||||
c,
|
||||
x,
|
||||
_reduce_scatter_via_reducer, multiple_results=False, new_style=True)(
|
||||
ctx, avals_in, avals_out, x,
|
||||
reducer=reducer,
|
||||
scatter_dimension=scatter_dimension,
|
||||
axis_name=axis_name,
|
||||
axis_index_groups=axis_index_groups,
|
||||
axis_size=axis_size,
|
||||
tiled=tiled,
|
||||
axis_env=axis_env,
|
||||
platform=platform)
|
||||
tiled=tiled)
|
||||
|
||||
|
||||
def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
|
||||
@ -1222,8 +1235,10 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
|
||||
|
||||
reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
|
||||
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
|
||||
xla.parallel_translations[reduce_scatter_p] = partial(
|
||||
_reduce_scatter_translation_rule, lax.add_p, psum) # type: ignore
|
||||
xla.register_translation(
|
||||
reduce_scatter_p,
|
||||
partial(_reduce_scatter_translation_rule, lax.add_p, psum),
|
||||
is_collective=True)
|
||||
pxla.multi_host_supported_collectives.add(reduce_scatter_p)
|
||||
|
||||
|
||||
@ -1297,7 +1312,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t
|
||||
return tree_util.tree_map(bind, x)
|
||||
|
||||
|
||||
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
def _build_axis_index_lowering(c, axis_name, axis_env):
|
||||
if isinstance(axis_name, tuple):
|
||||
assert axis_name, 'empty axis name'
|
||||
if len(axis_name) > 1:
|
||||
@ -1312,12 +1327,17 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
def _axis_index_translation_rule(ctx, avals_in, avals_out, *, axis_name):
|
||||
return [_build_axis_index_lowering(ctx.builder, axis_name, ctx.axis_env)]
|
||||
|
||||
|
||||
def _axis_index_abstract_eval(*, axis_name):
|
||||
frame = core.axis_frame(axis_name)
|
||||
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
|
||||
|
||||
axis_index_p = core.Primitive('axis_index')
|
||||
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
|
||||
xla.register_translation(axis_index_p, _axis_index_translation_rule,
|
||||
is_collective=True)
|
||||
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
|
||||
pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
@ -1402,20 +1422,16 @@ def _pdot_vmap_batching_rule(vals_in, dims_in, *, axis_name, pos_contract,
|
||||
return out, result_batch_dim
|
||||
batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule
|
||||
|
||||
def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch, precision,
|
||||
axis_env, platform):
|
||||
local_out = lax._dot_general_translation_rule(
|
||||
c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=precision,
|
||||
preferred_element_type=None)
|
||||
if axis_name:
|
||||
out_tup = xla.parallel_translations[psum_p](
|
||||
c, local_out, axes=axis_name, axis_index_groups=None,
|
||||
axis_env=axis_env, platform=platform)
|
||||
out, = xla.xla_destructure(c, out_tup)
|
||||
else:
|
||||
out = local_out
|
||||
return out
|
||||
xla.parallel_translations[pdot_p] = _pdot_translation_rule
|
||||
|
||||
def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
local_out = lax.dot_general(x, y, dimension_numbers=[pos_contract, pos_batch],
|
||||
precision=precision, preferred_element_type=None)
|
||||
return psum(local_out, axis_name) if axis_name is not None else local_out
|
||||
|
||||
xla.register_translation(
|
||||
pdot_p,
|
||||
xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True),
|
||||
is_collective=True)
|
||||
|
||||
def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
# TODO: avals with names, call pbroadcast with axis_name
|
||||
@ -1456,11 +1472,12 @@ def _pgather_abstract_eval(src, idx, *, axes):
|
||||
shape = idx.shape + tuple(shape)
|
||||
return ShapedArray(shape, src.dtype)
|
||||
|
||||
def _pgather_parallel_translation(c, src, idx, *, axes, axis_env, platform):
|
||||
def _pgather_parallel_translation(ctx, avals_in, avals_out, src, idx, *, axes):
|
||||
if any(not isinstance(axis, int) for axis in axes):
|
||||
raise NotImplementedError("pgather only supported in the SPMD lowering."
|
||||
"Please open a feature request!")
|
||||
return xla.lower_fun(_pgather_impl, multiple_results=False)(c, src, idx, axes=axes)
|
||||
return xla.lower_fun(_pgather_impl, multiple_results=False, new_style=True)(
|
||||
ctx, avals_in, avals_out, src, idx, axes=axes)
|
||||
|
||||
def _pgather_batcher(vals_in, dims_in, *, axes):
|
||||
src, idx = vals_in
|
||||
@ -1503,7 +1520,8 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
|
||||
pgather_p = core.AxisPrimitive('pgather')
|
||||
pgather_p.def_impl(_pgather_impl)
|
||||
pgather_p.def_abstract_eval(_pgather_abstract_eval)
|
||||
xla.parallel_translations[pgather_p] = _pgather_parallel_translation
|
||||
xla.register_translation(pgather_p, _pgather_parallel_translation,
|
||||
is_collective=True)
|
||||
# TODO: Transpose? That requires adding pscatter...
|
||||
batching.primitive_batchers[pgather_p] = _pgather_batcher
|
||||
batching.axis_primitive_batchers[pgather_p] = _pgather_collective_batcher
|
||||
|
@@ -373,12 +373,12 @@ threefry2x32_p.multiple_results = True
 threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
 threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
 batching.defbroadcasting(threefry2x32_p)
-xla.translations_with_avals[threefry2x32_p] = xla.lower_fun(
+xla.register_translation(threefry2x32_p, xla.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=False),
-    multiple_results=True, with_avals=True)
-xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
+    multiple_results=True, new_style=True))
+xla.register_translation(threefry2x32_p, xla.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=True),
-    multiple_results=True)
+    multiple_results=True, new_style=True), platform='cpu')
 if cuda_prng:
   xla.backend_specific_translations['gpu'][threefry2x32_p] = \
     _threefry2x32_gpu_translation_rule
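Backend-specific rules (the old xla.backend_specific_translations['cpu'][...] table) now go through the same entry point with a platform argument, as the CPU branch above shows. Sketched usage for a hypothetical primitive with a generic lowering plus a CPU override; the precedence comment reflects the dispatcher change later in this diff:

    # Illustrative only: my_prim_p and both rule functions are hypothetical.
    xla.register_translation(my_prim_p, _my_prim_generic_translation)
    # CPU-only override; the platform-specific table is consulted before the
    # generic one when lowering for that platform.
    xla.register_translation(my_prim_p, _my_prim_cpu_translation,
                             platform='cpu')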
@ -959,12 +959,12 @@ random_gamma_p = core.Primitive('random_gamma')
|
||||
random_gamma_p.def_impl(_gamma_impl)
|
||||
random_gamma_p.def_abstract_eval(lambda key, a: core.raise_to_shaped(a))
|
||||
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: tangent * _gamma_grad(ans, a))
|
||||
xla.translations_with_avals[random_gamma_p] = xla.lower_fun(
|
||||
xla.register_translation(random_gamma_p, xla.lower_fun(
|
||||
partial(_gamma_impl, use_vmap=True),
|
||||
multiple_results=False, with_avals=True)
|
||||
xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
|
||||
multiple_results=False, new_style=True))
|
||||
xla.register_translation(random_gamma_p, xla.lower_fun(
|
||||
partial(_gamma_impl, use_vmap=False),
|
||||
multiple_results=False)
|
||||
multiple_results=False, new_style=True), platform='cpu')
|
||||
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
|
||||
|
||||
def gamma(key: KeyArray,
|
||||
|
@ -801,8 +801,8 @@ def traceable_to_padded_translation(traceable):
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
|
||||
operands_ = it.chain.from_iterable([*dims.values(), *operands])
|
||||
outs = xla.jaxpr_subcomp(c, jaxpr, None, xla.AxisEnv(1, (), ()),
|
||||
xla._xla_consts(c, consts), '', *operands_)
|
||||
ctx = xla.TranslationContext(c, None, xla.AxisEnv(1, (), ()), '')
|
||||
outs = xla.jaxpr_subcomp(ctx, jaxpr, xla._xla_consts(c, consts), *operands_)
|
||||
return xla._partition_outputs(
|
||||
[aval_to_num_buffers(aval) for aval in out_avals], outs)
|
||||
return translation
|
||||
|
@ -42,6 +42,7 @@ import jax._src.random
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import sharded_jit
|
||||
from jax.interpreters import xla
|
||||
@ -940,9 +941,10 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
|
||||
assert False, f"Encountered unexpected primitive {p}"
|
||||
|
||||
|
||||
for unexpected in xla.call_translations: # Call primitives are inlined
|
||||
if unexpected is pjit.pjit_p:
|
||||
continue
|
||||
# Call primitives are inlined
|
||||
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p,
|
||||
partial_eval.remat_call_p, sharded_jit.sharded_call_p,
|
||||
maps.xmap_p]:
|
||||
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
|
||||
|
||||
# Primitives that are not yet implemented must be explicitly declared here.
|
||||
|
@ -118,12 +118,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
"""Fail if there are JAX primitives that are not implemented."""
|
||||
# Harvest primitives from XLA translation tables
|
||||
all_primitives = (
|
||||
set(xla.translations)
|
||||
| set(xla.backend_specific_translations["cpu"])
|
||||
| set(xla.backend_specific_translations["gpu"])
|
||||
| set(xla.backend_specific_translations["tpu"])
|
||||
| set(xla.initial_style_translations)
|
||||
| set(xla.parallel_translations))
|
||||
set(xla._translations)
|
||||
| set(xla._backend_specific_translations["cpu"])
|
||||
| set(xla._backend_specific_translations["gpu"])
|
||||
| set(xla._backend_specific_translations["tpu"]))
|
||||
|
||||
tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(
|
||||
jax.experimental.jax2tf.jax2tf.tf_impl_with_avals)
|
||||
|
@ -46,7 +46,7 @@ from jax._src.lib import xla_client as xc
|
||||
from .._src.util import (safe_map, safe_zip, HashableFunction,
|
||||
as_hashable_function, unzip2, distributed_debug_log,
|
||||
tuple_insert, moveaxis, split_list, wrap_name)
|
||||
from .._src.lax.parallel import _axis_index_translation_rule
|
||||
from .._src.lax.parallel import _build_axis_index_lowering
|
||||
from .. import lax
|
||||
|
||||
class _PositionalSemantics(Enum):
|
||||
@ -1347,9 +1347,10 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
# them!
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
tiled_outs = xla.jaxpr_subcomp(
|
||||
c, vectorized_jaxpr, backend, axis_env, (),
|
||||
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')), *tiled_ins)
|
||||
ctx = xla.TranslationContext(
|
||||
c, backend, axis_env,
|
||||
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')))
|
||||
tiled_outs = xla.jaxpr_subcomp(ctx, vectorized_jaxpr, (), *tiled_ins)
|
||||
|
||||
outs = [_xla_untile(c, axis_env, tiled_out, ans_out_axes, local_mesh_shape, backend)
|
||||
if v.aval is not core.abstract_unit else tiled_out
|
||||
@ -1363,8 +1364,8 @@ def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes):
|
||||
linear_idxs = [zero] * len(tile_shape)
|
||||
strides = [1] * len(tile_shape)
|
||||
for name, axis in reversed(axes.items()):
|
||||
axis_index = _axis_index_translation_rule(
|
||||
c, axis_name=name, axis_env=axis_env, platform=None)
|
||||
axis_index = _build_axis_index_lowering(
|
||||
c, axis_name=name, axis_env=axis_env)
|
||||
stride_c = xb.constant(c, np.array(strides[axis], np.int32))
|
||||
if linear_idxs[axis] is zero and strides[axis] == 1:
|
||||
linear_idxs[axis] = axis_index
|
||||
@ -1464,10 +1465,11 @@ def _xmap_translation_rule_spmd(c, axis_env,
|
||||
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
global_out_nodes = xla.jaxpr_subcomp(
|
||||
c, vectorized_jaxpr, backend, axis_env, (),
|
||||
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')),
|
||||
*sharded_global_in_nodes)
|
||||
ctx = xla.TranslationContext(
|
||||
c, backend, axis_env,
|
||||
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')))
|
||||
global_out_nodes = xla.jaxpr_subcomp(ctx, vectorized_jaxpr, (),
|
||||
*sharded_global_in_nodes)
|
||||
|
||||
sharded_global_out_nodes = [
|
||||
xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
|
||||
|
@ -476,9 +476,11 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
|
||||
get_sharding_proto(c, n, axis_resources, mesh)))
|
||||
|
||||
# TODO: Think about how to avoid duplicating constants with the outer jaxpr
|
||||
ctx = xla.TranslationContext(
|
||||
subc, backend, axis_env,
|
||||
extend_name_stack(name_stack, wrap_name(name, "pjit")))
|
||||
out_nodes = xla.jaxpr_subcomp(
|
||||
subc, jaxpr.jaxpr, backend, axis_env, xla._xla_consts(subc, jaxpr.consts),
|
||||
extend_name_stack(name_stack, wrap_name(name, "pjit")), *args)
|
||||
ctx, jaxpr.jaxpr, xla._xla_consts(subc, jaxpr.consts), *args)
|
||||
out_nodes = [
|
||||
xb.set_sharding_proto(subc, out,
|
||||
get_sharding_proto(subc, out, axis_resources, mesh))
|
||||
|
@ -370,7 +370,7 @@ class JVPTrace(Trace):
|
||||
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
|
||||
tangents_out = custom_lin_p.bind(
|
||||
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
|
||||
avals_out=avals_out)
|
||||
out_avals=avals_out)
|
||||
tangents_out = map(recast_to_float0, primals_out, tangents_out)
|
||||
return map(partial(JVPTracer, self), primals_out, tangents_out)
|
||||
|
||||
@ -676,7 +676,7 @@ def _interleave(xs, ys):
|
||||
|
||||
|
||||
custom_lin_p = core.Primitive('custom_lin')
|
||||
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
|
||||
custom_lin_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
|
||||
custom_lin_p.multiple_results = True
|
||||
|
||||
def _raise_custom_vjp_error_on_jvp(*_, **__):
|
||||
@ -684,9 +684,9 @@ def _raise_custom_vjp_error_on_jvp(*_, **__):
|
||||
"function.")
|
||||
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
|
||||
|
||||
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
|
||||
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
|
||||
res, _ = split_list(invals, [num_res])
|
||||
cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
|
||||
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
|
||||
cts_in = bwd.call_wrapped(*res, *cts_out)
|
||||
return [None] * num_res + list(cts_in)
|
||||
primitive_transposes[custom_lin_p] = _custom_lin_transpose
|
||||
|
@ -883,8 +883,9 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
partitions=arg_parts,
|
||||
donated_invars=donated_invars)
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts,
|
||||
extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
|
||||
ctx = xla.TranslationContext(c, backend.platform, axis_env,
|
||||
extend_name_stack(wrap_name(name, 'pmap')))
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
build_out_tuple = partial(xops.Tuple, c, out_nodes)
|
||||
if out_parts is not None:
|
||||
out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)
|
||||
@ -1297,9 +1298,10 @@ def _pmap_translation_rule(c, axis_env,
|
||||
for aval, in_node, in_axis in safe_zip(in_avals, in_nodes, in_axes))
|
||||
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
sharded_outs = xla.jaxpr_subcomp(
|
||||
c, call_jaxpr, backend, new_env, (),
|
||||
extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded)
|
||||
ctx = xla.TranslationContext(
|
||||
c, backend, new_env,
|
||||
extend_name_stack(name_stack, wrap_name(name, 'pmap')))
|
||||
sharded_outs = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *in_nodes_sharded)
|
||||
out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
|
||||
for aval, out_axis, shard in safe_zip(out_avals, out_axes, sharded_outs)]
|
||||
@ -1620,9 +1622,10 @@ def lower_mesh_computation(
|
||||
partitions_proto=partitions_proto,
|
||||
donated_invars=donated_invars)
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
out_nodes = xla.jaxpr_subcomp(
|
||||
c, jaxpr, backend.platform, axis_env, xla_consts,
|
||||
extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
|
||||
ctx = xla.TranslationContext(
|
||||
c, backend.platform, axis_env,
|
||||
extend_name_stack(wrap_name(transformed_name, 'xmap')))
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
if spmd_lowering:
|
||||
out_partitions_t = xb.tuple_sharding_proto(out_partitions)
|
||||
out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
|
||||
|
@ -87,10 +87,10 @@ def _sharded_callable(
|
||||
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)
|
||||
|
||||
if xb.get_backend().platform not in ["tpu", "gpu"]:
|
||||
platform = xb.get_backend().platform
|
||||
if platform not in ["tpu", "gpu"]:
|
||||
# TODO(skye): fall back to regular jit?
|
||||
raise ValueError("sharded_jit not supported for " +
|
||||
xb.get_backend().platform)
|
||||
raise ValueError(f"sharded_jit not supported for {platform}")
|
||||
|
||||
nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
|
||||
assert nparts is not None
|
||||
@ -142,9 +142,9 @@ def _sharded_callable(
|
||||
xla_consts = _map(partial(xb.constant, c), consts)
|
||||
xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
|
||||
axis_env = xla.AxisEnv(nrep, (), ())
|
||||
out_nodes = xla.jaxpr_subcomp(
|
||||
c, jaxpr, None, axis_env, xla_consts,
|
||||
extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
|
||||
ctx = xla.TranslationContext(
|
||||
c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
|
||||
built = c.Build(out_tuple)
|
||||
|
||||
@ -194,9 +194,10 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
|
||||
arg = xb.parameter(subc, i, c.GetShape(n))
|
||||
args.append(xb.set_sharding(subc, arg, sharding))
|
||||
|
||||
out_nodes = xla.jaxpr_subcomp(
|
||||
subc, call_jaxpr, backend, axis_env, (),
|
||||
extend_name_stack(name_stack, wrap_name(name, "sharded_jit")), *args)
|
||||
ctx = xla.TranslationContext(
|
||||
subc, backend, axis_env,
|
||||
extend_name_stack(wrap_name(name, "sharded_jit")))
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *args)
|
||||
out_parts = out_parts_thunk()
|
||||
assert len(out_parts) == len(out_nodes)
|
||||
out_nodes = [xb.set_sharding(subc, out, sharding)
|
||||
|
@ -14,12 +14,15 @@
|
||||
|
||||
|
||||
from collections import defaultdict, deque
|
||||
import dataclasses
|
||||
import functools
|
||||
from functools import partial, partialmethod
|
||||
import itertools as it
|
||||
import operator as op
|
||||
import re
|
||||
from typing import (Any, Callable, Deque, Dict, List, Optional, Sequence, Set,
|
||||
Type, Tuple, Union, NamedTuple)
|
||||
from typing_extensions import Protocol
|
||||
from warnings import warn
|
||||
import weakref
|
||||
|
||||
@ -326,9 +329,9 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
|
||||
def primitive_subcomputation(prim: core.Primitive, *avals: core.AbstractValue,
|
||||
**params):
|
||||
c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
|
||||
f = lower_fun(prim.bind, prim.multiple_results, with_avals=True)
|
||||
f = lower_fun(prim.bind, multiple_results=prim.multiple_results)
|
||||
xla_args, _ = _xla_callable_args(c, avals, tuple_args=False)
|
||||
ans = f(c, avals, xla_args, params)
|
||||
ans = f(c, *xla_args, **params)
|
||||
return c.build(ans)
|
||||
|
||||
def backend_compile(backend, built_c, options):
|
||||
@ -378,15 +381,31 @@ def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
|
||||
_partition_outputs([len(aval_to_xla_shapes(v.aval)) for v in vars],
|
||||
nodes))
|
||||
|
||||
def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
if backend not in ('cpu', 'gpu', 'tpu'):
|
||||
platform = xb.get_backend(backend).platform # canonicalize
|
||||
else:
|
||||
platform = backend
|
||||
class AxisEnv(NamedTuple):
|
||||
"""Represents a pmap mesh (only along the replica axes)."""
|
||||
nreps: int
|
||||
names: Tuple[Any, ...]
|
||||
sizes: Tuple[int, ...]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TranslationContext:
|
||||
builder: xc.XlaBuilder
|
||||
# TODO(phawkins): make platform non-optional. We should always be translating
|
||||
# with a specific platform in mind.
|
||||
platform: Optional[str]
|
||||
axis_env: AxisEnv
|
||||
name_stack: str
|
||||
|
||||
def replace(self, **kw): return dataclasses.replace(self, **kw)
|
||||
|
||||
|
||||
def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
|
||||
consts: Sequence[XlaOp], *args: XlaOp) -> Sequence[XlaOp]:
|
||||
# TODO(phawkins): make platform non-optional.
|
||||
# assert ctx.platform is not None
|
||||
def read(v):
|
||||
if type(v) is Literal:
|
||||
return xb.constant_general(c, canonicalize_dtype(v.val))
|
||||
return xb.constant_general(ctx.builder, canonicalize_dtype(v.val))
|
||||
else:
|
||||
return env[v]
|
||||
|
||||
@ -400,52 +419,33 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
assert node is not None
|
||||
env[v] = node
|
||||
|
||||
env = {}
|
||||
_partitionmap(write, [core.unitvar], _make_unit_constant(c))
|
||||
env: Dict[core.Var, Sequence[XlaOp]] = {}
|
||||
_partitionmap(write, [core.unitvar], _make_unit_constant(ctx.builder))
|
||||
_partitionmap(write, jaxpr.constvars, consts)
|
||||
_partitionmap(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
op_metadata = make_op_metadata(
|
||||
eqn.primitive, eqn.params, name_stack=name_stack,
|
||||
eqn.primitive, eqn.params, name_stack=ctx.name_stack,
|
||||
source_info=eqn.source_info)
|
||||
c.set_op_metadata(op_metadata)
|
||||
ctx.builder.set_op_metadata(op_metadata)
|
||||
in_nodes = _flatmap(read, eqn.invars)
|
||||
with source_info_util.user_context(eqn.source_info):
|
||||
# TODO(jakevdp): migrate `translations` table to `translations_with_avals`
|
||||
if eqn.primitive in backend_specific_translations[platform]:
|
||||
rule = backend_specific_translations[platform][eqn.primitive]
|
||||
ans = rule(c, *in_nodes, **eqn.params)
|
||||
elif eqn.primitive in translations:
|
||||
ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
|
||||
elif eqn.primitive in translations_with_avals:
|
||||
rule = translations_with_avals[eqn.primitive]
|
||||
ans = rule(c, map(aval, eqn.invars), in_nodes, eqn.params)
|
||||
elif eqn.primitive in initial_style_translations:
|
||||
new_params = check_backend_params(eqn.params, backend)
|
||||
rule = initial_style_translations[eqn.primitive]
|
||||
ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
|
||||
map(aval, eqn.invars), backend, *in_nodes, **new_params)
|
||||
elif eqn.primitive in parallel_translations:
|
||||
rule = parallel_translations[eqn.primitive]
|
||||
ans = rule(c, *in_nodes, axis_env=axis_env, platform=platform, **eqn.params)
|
||||
elif eqn.primitive in call_translations:
|
||||
new_params = check_backend_params(eqn.params, backend)
|
||||
rule = call_translations[eqn.primitive]
|
||||
ans = rule(c, axis_env, in_nodes,
|
||||
name_stack, backend=backend, **new_params)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
|
||||
|
||||
assert isinstance(ans, xe.XlaOp)
|
||||
c.get_shape(ans) # force xla to do shape error checking
|
||||
if (eqn.primitive.multiple_results or
|
||||
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in eqn.outvars)):
|
||||
out_nodes = xla_destructure(c, ans)
|
||||
if (ctx.platform is not None and
|
||||
eqn.primitive in _backend_specific_translations[ctx.platform]):
|
||||
rule = _backend_specific_translations[ctx.platform][eqn.primitive]
|
||||
elif eqn.primitive in _translations:
|
||||
rule = _translations[eqn.primitive]
|
||||
else:
|
||||
out_nodes = [ans]
|
||||
c.clear_op_metadata()
|
||||
_partitionmap(write, eqn.outvars, out_nodes)
|
||||
raise NotImplementedError(
|
||||
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
|
||||
|
||||
with source_info_util.user_context(eqn.source_info):
|
||||
ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
|
||||
*in_nodes, **eqn.params)
|
||||
|
||||
assert all(isinstance(x, xe.XlaOp) for x in ans), ans
|
||||
map(ctx.builder.get_shape, ans) # force xla to do shape error checking
|
||||
ctx.builder.clear_op_metadata()
|
||||
_partitionmap(write, eqn.outvars, ans)
|
||||
return _flatmap(read, jaxpr.outvars)
|
||||
|
||||
|
||||
@@ -453,23 +453,15 @@ def xla_destructure(c, ans):
   num_elements = len(c.get_shape(ans).tuple_shapes())
   return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
 
-def check_backend_params(params, outer_backend):
+def check_backend_matches(inner_backend, outer_backend):
   # For nested calls, the outermost call sets the backend for all inner calls;
   # it's an error if the inner call has a conflicting explicit backend spec.
-  inner_backend = params.get('backend', None)
   if inner_backend and inner_backend != outer_backend:
     raise ValueError(
         f"Outer-jit backend specification {outer_backend} must match explicit "
         f"inner-jit backend specification {inner_backend}.")
-  return {k: params[k] for k in params if k != 'backend'}
 
 
-class AxisEnv(NamedTuple):
-  """Represents a pmap mesh (only along the replica axes)."""
-  nreps: int
-  names: Tuple[Any, ...]
-  sizes: Tuple[int, ...]
-
 def extend_axis_env(env: AxisEnv, name, size: int):
   return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
 
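check_backend_matches above is what implements the relaxed nested-jit backend semantics described in the commit message: the outermost jit canonicalizes the backend, and an inner jit with an explicit backend is only rejected if it conflicts with that choice. A small illustration, assuming a CPU-only installation where the default backend canonicalizes to 'cpu':

    import functools
    import jax

    @functools.partial(jax.jit, backend='cpu')
    def inner(x):
      return x + 1

    @jax.jit
    def outer(x):
      # The inner jit's explicit backend='cpu' matches the canonicalized outer
      # choice, so it is accepted instead of raising the "must match" error.
      return inner(x) * 2

    outer(1)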
@@ -520,7 +512,7 @@ def eqn_replicas(eqn):
   call_jaxpr = eqn.params.get("call_jaxpr")
   if call_jaxpr:
     return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
-  elif eqn.primitive in initial_style_translations:
+  elif eqn.primitive in _initial_style_primitives:
     return initial_style_primitive_replicas(eqn.params)
   else:
     return 1
@@ -545,7 +537,7 @@ def jaxpr_has_pmap(jaxpr):
 def jaxpr_collectives(jaxpr):
   """Generates all the collective primitives anywhere inside a Jaxpr."""
   for eqn in jaxpr.eqns:
-    if eqn.primitive in parallel_translations:
+    if eqn.primitive in _collective_primitives:
       yield eqn.primitive
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_collectives(subjaxpr)
@ -656,8 +648,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
|
||||
nreps = jaxpr_replicas(jaxpr)
|
||||
device = _xla_callable_device(nreps, backend, device, arg_devices)
|
||||
backend = xb.get_device_backend(device) if device else (
|
||||
xb.get_backend(backend) if backend is not None else None)
|
||||
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
|
||||
|
||||
# Computations that only produce constants and/or only rearrange their inputs,
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
@ -696,15 +687,15 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
xla_consts = _xla_consts(c, consts)
|
||||
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args,
|
||||
donated_invars=donated_invars)
|
||||
out_nodes = jaxpr_subcomp(
|
||||
c, jaxpr, backend.platform if backend is not None else None,
|
||||
AxisEnv(nreps, (), ()), xla_consts,
|
||||
extend_name_stack(wrap_name(name, 'jit')), *xla_args)
|
||||
platform = backend.platform
|
||||
ctx = TranslationContext(c, platform, AxisEnv(nreps, (), ()),
|
||||
extend_name_stack(wrap_name(name, 'jit')))
|
||||
out_nodes = jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
backend = xb.get_backend(backend)
|
||||
# There is a non-zero cost to building an output tuple, particularly on TPU.
|
||||
# Avoid it if the output arity is 1.
|
||||
output = out_nodes[0] if len(out_nodes) == 1 else xops.Tuple(c, out_nodes)
|
||||
if backend.platform in ("gpu", "tpu"):
|
||||
if platform in ("gpu", "tpu"):
|
||||
donated_invars = set_up_aliases(
|
||||
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
|
||||
if any(donated_invars):
|
||||
@ -1029,15 +1020,20 @@ def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
|
||||
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
|
||||
|
||||
|
||||
def _xla_call_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
call_jaxpr, donated_invars, inline=None, device=None):
def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
backend=None, call_jaxpr, donated_invars,
inline=None, device=None):
del device, donated_invars, inline # Ignored.
c = ctx.builder
check_backend_matches(backend, ctx.platform)
subc = xb.make_computation_builder(f"jit_{name}")
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
sub_ctx = ctx.replace(
builder=subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
subc = subc.build(xops.Tuple(subc, out_nodes))
return xops.Call(c, subc, list(in_nodes))
return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)

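For orientation, the TranslationContext threaded through these rules bundles the equation-independent lowering state. The sketch below is an approximation inferred from how it is used here (builder, platform, axis_env, name_stack, plus a replace helper); it is not the exact definition introduced by this change.

import dataclasses
from typing import Any, Optional

@dataclasses.dataclass
class TranslationContext:
  builder: Any              # xla_client.XlaBuilder being populated
  platform: Optional[str]   # canonical platform name, e.g. "cpu", "gpu", "tpu"
  axis_env: Any             # AxisEnv describing the replica axes
  name_stack: str           # name prefix attached to op metadata

  def replace(self, **kw):
    # Used by call-like rules to lower a sub-jaxpr into a fresh sub-builder
    # while keeping the rest of the context unchanged.
    return dataclasses.replace(self, **kw)
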
@ -1059,14 +1055,100 @@ pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule

### translation tables

translations: Dict[core.Primitive, Callable] = {}
translations_with_avals: Dict[core.Primitive, Callable] = {}
parallel_translations: Dict[core.Primitive, Callable] = {}
initial_style_translations: Dict[core.Primitive, Callable] = {}
call_translations: Dict[core.Primitive, Callable] = {}
backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict)
MYPY = False
if not MYPY:
class TranslationRule(Protocol):
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw
) -> Sequence[XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any

call_translations[xla_call_p] = _xla_call_translation_rule
_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)

_collective_primitives: Set[core.Primitive] = set()
_initial_style_primitives: Set[core.Primitive] = set()

def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None,
is_collective: bool = False,
initial_style: bool = False) -> None:
ts = (_translations if platform is None
else _backend_specific_translations[platform])
ts[prim] = rule
if is_collective:
_collective_primitives.add(prim)
if initial_style:
_initial_style_primitives.add(prim)

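A hedged sketch of how a rule looks under the new signature and how it is registered; the primitive (my_sin_p) and its lowering are illustrative stand-ins, not part of this change.

from jax import core
from jax.interpreters import xla

my_sin_p = core.Primitive("my_sin")        # hypothetical primitive
my_sin_p.def_abstract_eval(lambda x: x)    # output aval mirrors the input

def _my_sin_translation(ctx, avals_in, avals_out, x):
  # New-style rules receive the TranslationContext plus input/output avals,
  # take XlaOps as positional args, and return a sequence of XlaOps.
  return [xla.xops.Sin(x)]

xla.register_translation(my_sin_p, _my_sin_translation)
# A platform-specific variant goes through the same entry point:
# xla.register_translation(my_sin_p, _my_sin_gpu_rule, platform="gpu")
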
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
def __init__(self, translations,
wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
self._translations = translations
self._wrap_fn = wrap_fn

def __setitem__(self, key: core.Primitive, value: Callable):
self._translations[key] = self._wrap_fn(key, value)


def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped


def _wrap_old_call_translation(prim: core.Primitive,
f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
platform = kw.pop("backend", None)
check_backend_matches(platform, ctx.platform)
ans = f(ctx.builder, ctx.axis_env, args, ctx.name_stack,
backend=ctx.platform, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped

translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter(_translations, _wrap_old_translation)

class _BackendSpecificTranslationsAdapter(defaultdict):
def __missing__(self, key):
ret = self[key] = _TranslationRuleAdapter(
_backend_specific_translations[key], _wrap_old_translation)
return ret

backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
call_translations : _TranslationRuleAdapter
call_translations = _TranslationRuleAdapter(
_translations, _wrap_old_call_translation)


register_translation(xla_call_p, _xla_call_translation_rule)

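What the shims preserve, in a small illustrative sketch: downstream code that still assigns into the old dictionaries keeps working, because the adapter wraps the old-style callable and stores the result in the new table. The primitive and rule below are hypothetical.

from jax import core
from jax.interpreters import xla

my_abs_p = core.Primitive("my_abs")        # hypothetical primitive
my_abs_p.def_abstract_eval(lambda x: x)

# Old-style rule: takes the XlaBuilder and input XlaOps, returns one XlaOp.
def _my_abs_old_style(c, x):
  return xla.xops.Abs(x)

# Still accepted: the adapter applies _wrap_old_translation and the wrapped
# rule ends up in xla._translations alongside new-style rules.
xla.translations[my_abs_p] = _my_abs_old_style
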
def zeros_like_translation_rule(c, x):
shape = c.get_shape(x)

@ -1089,8 +1171,19 @@ def _tuple_output(*args, **kwargs):
ans = yield args, kwargs
yield (ans,)

def lower_fun(fun, multiple_results, parallel=False, with_avals=False, backend=None):
# TODO(jakevdp): migrate dependent code & always use the with_avals=True.
def lower_fun(fun: Callable, *, multiple_results: bool, parallel: bool = False,
backend=None, new_style: bool = False):
if new_style:
def f_new(ctx, avals_in, avals_out, *xla_args, **params):
wrapped_fun = lu.wrap_init(fun, params)
if not multiple_results:
wrapped_fun = _tuple_output(wrapped_fun)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts),
*xla_args)
return f_new

# TODO(phawkins): migrate dependent code & always use new_style=True.
def f(c, *xla_args, **params):
avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args]
return f_with_avals(c, avals, xla_args, params)

@ -1106,8 +1199,8 @@ def lower_fun(fun, multiple_results, parallel=False, with_avals=False, backend=N
wrapped_fun = _tuple_output(wrapped_fun)
with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
'', *xla_args)
ctx = TranslationContext(c, backend, axis_env, '')
outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
if (multiple_results or
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):
return xops.Tuple(c, outs)

@ -1115,7 +1208,7 @@ def lower_fun(fun, multiple_results, parallel=False, with_avals=False, backend=N
assert len(outs) == 1, outs
return outs[0]

return f_with_avals if with_avals else f
return f

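A sketch of the new_style path in use: a hypothetical primitive whose lowering is expressed as ordinary traceable JAX code, registered via lower_fun plus register_translation. This mirrors the sparse-array registrations further down in this change; the primitive here is illustrative.

import jax.numpy as jnp
from jax import core
from jax.interpreters import xla

double_p = core.Primitive("double")        # hypothetical primitive
double_p.def_abstract_eval(lambda x: x)

def _double_impl(x):
  # Plain traceable JAX code; lower_fun traces it to a jaxpr and lowers that
  # jaxpr under the caller's TranslationContext via jaxpr_subcomp.
  return 2 * jnp.asarray(x)

xla.register_translation(
    double_p,
    xla.lower_fun(_double_impl, multiple_results=False, new_style=True))
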
def _array_aval_from_xla_shape(xla_shape):
# This function instantiates the assumption that we can map from XLA array

@ -1124,15 +1217,6 @@ def _array_aval_from_xla_shape(xla_shape):
assert not xla_shape.is_tuple()
return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())

def lower_fun_initial_style(fun):
def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
name_stack, *xla_args)
return xops.Tuple(c, outs)
return f


### device-persistent data

class Token(object): pass

@ -1481,14 +1565,14 @@ def _zeros(c, xla_shape):
return xops.CreateToken(c)


def _remat_using_cond(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr):
def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
"""Lower remat to a Conditional which always returns true. This:
1. Circumvents common subexpression elimination.
2. In common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
occur after the primal blocks, because cotangent is an input to the
Conditional."""
# Fake condition which always selects True branch.
c = ctx.builder
rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
xb.constant(c, np.array(1, dtype=np.float32)),
xc.Shape.array_shape(xc.PrimitiveType.F32, []))

@ -1498,9 +1582,10 @@
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
args = xla_destructure(remat_subc, input_op)
out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'remat')),
*args)
sub_ctx = ctx.replace(
builder=remat_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes]
remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))

@ -1510,25 +1595,27 @@
out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

return xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc)
return xla_destructure(
c, xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc))


def _remat_using_while(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr):
def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
"""Lower remat to a single iteration while loop."""
c = ctx.builder
# Dummy subc for getting subcomp shapes.
dummy_inputs = xops.Tuple(c, in_nodes)
dummy_subc = xb.make_computation_builder("remat_dummy_subcomputation")
dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
dummy_args = xla_destructure(dummy_subc, dummy_input_op)
dummy_subcomp_outs = jaxpr_subcomp(
dummy_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, "remat")), *dummy_args)
dummy_ctx = ctx.replace(
builder=dummy_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]

i_init = xb.constant(c, np.array(0, dtype=np.int32))
zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
inputs = xops.Tuple(c, [i_init] + in_nodes + zeros_like_outs)
inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

cond_subc = xb.make_computation_builder("remat_cond_subcomputation")
input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])

@ -1542,52 +1629,54 @@
input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32)))
subcomp_outs = jaxpr_subcomp(
body_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, "remat")), *args)
out_nodes = [i_next] + args + subcomp_outs
body_ctx = ctx.replace(
builder=body_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
subcomp_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (), *args)
out_nodes = [i_next] + args + list(subcomp_outs)
body_subc = body_subc.build(xops.Tuple(body_subc, out_nodes))
outs = xops.While(cond_subc, body_subc, inputs)
return xops.Tuple(c, xla_destructure(c, outs)[len(in_nodes)+1:])
return xla_destructure(c, outs)[len(in_nodes)+1:]

def _remat_translation_rule(c, axis_env, in_nodes,
name_stack, backend, name, call_jaxpr,
def _remat_translation_rule(ctx, avals_in, avals_out, *in_nodes,
name, call_jaxpr,
prevent_cse, differentiated, concrete,
policy, device=None):
del device, concrete, policy # Unused.
if differentiated and prevent_cse:
if backend == "gpu":
return _remat_using_while(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr)
if ctx.platform == "gpu":
return _remat_using_while(ctx, in_nodes, name, call_jaxpr)
else:
return _remat_using_cond(
c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr)
return _remat_using_cond(ctx, in_nodes, name, call_jaxpr)
else:
outs = jaxpr_subcomp(c, call_jaxpr, backend, axis_env, (), "", *in_nodes)
return xops.Tuple(c, outs)
return jaxpr_subcomp(ctx, call_jaxpr, (), *in_nodes)

call_translations[pe.remat_call_p] = _remat_translation_rule # type: ignore
register_translation(pe.remat_call_p, _remat_translation_rule)

ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
core.named_call_p)


def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *,
name="core_call", backend, call_jaxpr):
def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
name="core_call", backend=None, call_jaxpr):
check_backend_matches(backend, ctx.platform)
c = ctx.builder
subc = xb.make_computation_builder(name)
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, name), *args)
sub_ctx = ctx.replace(builder=subc,
name_stack=extend_name_stack(ctx.name_stack, name))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
subc = subc.Build(xops.Tuple(subc, out_nodes))
return xops.Call(c, subc, list(in_nodes))
call_translations[core.named_call_p] = _named_call_translation_rule
return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
register_translation(core.named_call_p, _named_call_translation_rule)


def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend,
def _call_translation_rule(ctx, avals_in, avals_out, *in_nodes, backend=None,
call_jaxpr):
return _named_call_translation_rule(
c, axis_env, in_nodes, name_stack, name="core_call",
backend=backend, call_jaxpr=call_jaxpr)
call_translations[core.call_p] = _call_translation_rule
ctx, avals_in, avals_out, *in_nodes, name="core_call", backend=backend,
call_jaxpr=call_jaxpr)
register_translation(core.call_p, _call_translation_rule)

setup.py
@ -40,6 +40,7 @@ setup(
'numpy>=1.18',
'opt_einsum',
'scipy>=1.2.1',
'typing_extensions',
],
extras_require={
# Minimum jaxlib version; used in testing.

@ -182,7 +182,9 @@ def _identity_impl(mat):
def _identity_abstract_eval(mat):
return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)

xla.translations_with_avals[identity_p] = xla.lower_fun(_identity_impl, multiple_results=False, with_avals=True)
xla.register_translation(
identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
new_style=True))

def split(x):
return split_p.bind(x)

@ -199,7 +201,8 @@ def _split_abstract_eval(mat):
m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
return m, m

xla.translations_with_avals[split_p] = xla.lower_fun(_split_impl, multiple_results=True, with_avals=True)
xla.register_translation(
split_p, xla.lower_fun(_split_impl, multiple_results=True, new_style=True))

def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)

@ -164,7 +164,7 @@ def helper_set_hlo_dump():

def helper_print_optimized_hlo(fun, *args):
backend = xla_bridge.get_backend()
c = jax.xla_computation(fun)(*args)
c = jax.xla_computation(fun, backend='cpu')(*args)
print(re.sub(r", metadata.*", "",
backend.compile(c).hlo_modules()[0].to_string()))

@ -175,7 +175,7 @@ def helper_log_ir(name,
num_partitions=None,
strip_metadata=False):
print(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}")
jax_comp = jax.xla_computation(f_jax)(*args)
jax_comp = jax.xla_computation(f_jax, backend='cpu')(*args)
print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")

backend = xla_bridge.get_backend()

@ -418,7 +418,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
logging.info("%s: %s", self._testMethodName,
jax.make_jaxpr(func)(1))
logging.info("%s: %s", self._testMethodName,
jax.xla_computation(func)(1).as_hlo_text())
jax.xla_computation(func, backend='cpu')(1).as_hlo_text())
self.assertEqual(2, jax.jit(func)(1))
hcb.barrier_wait()


@ -91,6 +91,9 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
if inner not in ('cpu', jtu.device_under_test(), None):
raise SkipTest("Backend is not CPU or the device under test")
if outer is None and inner == jtu.device_under_test():
raise SkipTest("(None, device) is allowed")

@partial(jax.jit, backend=outer)
def fun(x, y):
@partial(jax.jit, backend=inner)

@ -355,8 +355,7 @@ class PJitTest(jtu.BufferDonationTestCase):
f = xmap(lambda x: h(x * 2), in_axes=['i', ...], out_axes=['i', ...],
axis_resources={'i': 'y'})
x = jnp.arange(16).reshape((4, 4))
self.assertIn(pjit_p, xla.call_translations)
rule = xla.call_translations[pjit_p]
rule = xla._translations[pjit_p]
test_rule_called = False
def _test_rule(*args, **kwargs):
nonlocal test_rule_called

@ -366,11 +365,11 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIn(('y',), in_axis_resources[0].partitions)
return rule(*args, **kwargs)
try:
xla.call_translations[pjit_p] = _test_rule
xla._translations[pjit_p] = _test_rule
f(x)
self.assertTrue(test_rule_called)
finally:
xla.call_translations[pjit_p] = rule
xla._translations[pjit_p] = rule

@jtu.with_mesh([('x', 2)])
def testLowerWithAbstractArgs(self):

@ -249,10 +249,12 @@ class cuSparseTest(jtu.JaxTestCase):
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
if cuda_version is None or cuda_version < 11000:
self.assertFalse(cusparse and cusparse.is_supported)
self.assertNotIn(sparse.csr_todense_p, xla.backend_specific_translations["gpu"])
self.assertNotIn(sparse.csr_todense_p,
xla._backend_specific_translations["gpu"])
else:
self.assertTrue(cusparse and cusparse.is_supported)
self.assertIn(sparse.csr_todense_p, xla.backend_specific_translations["gpu"])
self.assertIn(sparse.csr_todense_p,
xla._backend_specific_translations["gpu"])

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(