Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Move functions out of xla.py closer to their users.
Refactoring only; no functional changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role and to remove things that aren't related to HLO compatibility. Also removes an unused top_k translation rule.

PiperOrigin-RevId: 554946059
This commit is contained in:
parent d01695c746
commit ca17b6c08f
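At a glance, the hunks below move each helper next to its main consumer. A minimal sketch of the before/after access paths, inferred from the import statements in this diff (not an exhaustive list, and assuming a JAX checkout at this revision):

```python
# Inferred from the hunks below:
# xla.get_canonical_source_file  -> mlir.get_canonical_source_file
# xla.check_backend_matches      -> mlir.check_backend_matches
# xla.axis_groups                -> pxla.axis_groups
# xla.xla_call_jvp_update_params -> pxla.xla_call_jvp_update_params
# xla.with_sharding_proto        -> host_callback-private _with_sharding_proto
from jax._src.interpreters import mlir, pxla  # the new homes
```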
@@ -333,6 +333,13 @@ register_constant_handler(core.Token, _token_constant_handler)
 
 
+# Source locations
+def get_canonical_source_file(frame: source_info_util.Frame) -> str:
+  source_file = frame.file_name
+  if config.jax_hlo_source_file_canonicalization_regex:
+    source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
+                         '', source_file)
+  return source_file
 
 def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
   """Converts a full traceback to a callsite() MLIR location."""
   frame_locs = []
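For intuition, get_canonical_source_file only strips the first match of the configured canonicalization regex from a frame's file name. A standalone sketch of that logic; the regex value and the sample path are invented for the demo, not taken from the diff:

```python
import re

# Stand-in value for the jax_hlo_source_file_canonicalization_regex option.
canonicalization_regex = r".*/site-packages/"  # hypothetical setting
source_file = "/opt/venv/lib/python3.11/site-packages/mypkg/model.py"

# Core of get_canonical_source_file: drop the matched prefix if configured.
print(re.sub(canonicalization_regex, "", source_file))  # mypkg/model.py
```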
@@ -340,7 +347,7 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
     frame = source_info_util.raw_frame_to_frame(code, lasti)
     if source_info_util.is_user_filename(frame.file_name):
       file_loc = ir.Location.file(
-          xla.get_canonical_source_file(frame),
+          get_canonical_source_file(frame),
           frame.start_line,
           frame.start_column,
       )
@@ -371,7 +378,7 @@ def _source_info_to_location(
   if frame is None:
     loc = ir.Location.unknown()
   else:
-    loc = ir.Location.file(xla.get_canonical_source_file(frame),
+    loc = ir.Location.file(get_canonical_source_file(frame),
                            frame.start_line, frame.start_column)
   loc = ir.Location.name(eqn_str, childLoc=loc)
   # TODO(phawkins): also include primitive.name as the operator type.
@@ -1383,13 +1390,25 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
   return func_op
 
 
+def check_backend_matches(inner_backend, outer_backend):
+  # For nested calls, the outermost call sets the backend for all inner calls;
+  # it's an error if the inner call has a conflicting explicit backend spec.
+  if inner_backend is None:
+    return
+  if (inner_backend != outer_backend and
+      outer_backend not in xb.expand_platform_alias(inner_backend)):
+    raise ValueError(
+        f"Outer-jit backend specification {outer_backend} must match explicit "
+        f"inner-jit backend specification {inner_backend}.")
+
+
 def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                    avals_out, tokens_in, *args,
                    dim_var_values: Sequence[ir.Value],
                    arg_names=None, result_names=None):
   if isinstance(call_jaxpr, core.Jaxpr):
     call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
-  xla.check_backend_matches(backend, ctx.platform)
+  check_backend_matches(backend, ctx.platform)
   effects = list(tokens_in.effects())
   output_types = map(aval_to_ir_types, avals_out)
   output_types = [token_type()] * len(effects) + output_types
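The rule check_backend_matches enforces is small: an inner call may omit a backend and inherit the outer one, but an explicit inner backend must agree with the outer one up to platform aliases. A self-contained re-creation, with a plain aliases tuple standing in for xb.expand_platform_alias:

```python
def check_backend_matches(inner_backend, outer_backend, aliases=()):
    # No explicit inner backend: inherit the outer one.
    if inner_backend is None:
        return
    # An explicit inner backend must match, modulo platform aliases.
    if inner_backend != outer_backend and outer_backend not in aliases:
        raise ValueError(
            f"Outer-jit backend specification {outer_backend} must match "
            f"explicit inner-jit backend specification {inner_backend}.")

check_backend_matches(None, "gpu")                            # ok: nothing explicit
check_backend_matches("gpu", "cuda", aliases=("cuda", "rocm"))  # ok: alias match
# check_backend_matches("tpu", "cpu")                         # raises ValueError
```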
@@ -1235,15 +1235,43 @@ def _pmap_dce_rule(used_outputs, eqn):
   return used_inputs, new_eqn
 
 
+def _xla_call_partial_eval_update_params(
+    params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
+  ) -> core.ParamDict:
+  donated_invars = params['donated_invars']
+  if not kept_inputs and donated_invars:
+    # JaxprTrace.post_process_call creates a call with no input tracers
+    donated_invars = (False,) * num_new_inputs
+  else:
+    assert len(kept_inputs) == len(donated_invars)
+    # JaxprTrace.process_call drops known input tracers
+    donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
+    # Any new inputs are prepended to the left, so mark those as not donated.
+    donated_invars = [False] * num_new_inputs + donated_invars
+  return dict(params, donated_invars=tuple(donated_invars))
+
+def xla_call_jvp_update_params(params, nz_tangents):
+  donated_invars = params['donated_invars']
+  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
+  new_donated_invars = (*donated_invars, *donated_tangents)
+  return dict(params, donated_invars=new_donated_invars)
+
+def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
+  donated_invars = params['donated_invars']
+  donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
+  donated_cotangents = [False for nz in nonzero_cts if nz]
+  return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
+
+
 # Set param update handlers to update `donated_invars` just like xla_call_p
-pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
+pe.call_param_updaters[xla_pmap_p] = _xla_call_partial_eval_update_params
 pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
     partial(pe.call_partial_eval_custom_rule,
             'call_jaxpr', _pmap_partial_eval_custom_params_updater,
             res_aval=_pmap_partial_eval_custom_res_maker)
 pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
-ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
-ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
+ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params
+ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params
 
 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
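The three updaters above all do the same bookkeeping: keep the donated_invars flags aligned with a call primitive's inputs as transforms drop, add, or reorder them. A toy run of the partial-eval case, with made-up flag values:

```python
# Toy values: three original inputs, the second eliminated by partial
# evaluation, and two residual inputs prepended on the left.
donated_invars = (True, False, True)
kept_inputs = [True, False, True]
num_new_inputs = 2

# Mirror of _xla_call_partial_eval_update_params' else-branch:
donated = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
donated = [False] * num_new_inputs + donated
print(tuple(donated))  # (False, False, True, True)
```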
@@ -1289,6 +1317,38 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
     raise TypeError(aval)
 
 
+def _axis_read(axis_env, axis_name):
+  try:
+    return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
+  except ValueError:
+    raise NameError(f"unbound axis name: {axis_name}") from None
+
+def axis_groups(axis_env: sharding_impls.AxisEnv, name) -> tuple[tuple[int, ...]]:
+  if not isinstance(name, (list, tuple)):
+    name = (name,)
+  mesh_axes = tuple(unsafe_map(partial(_axis_read, axis_env), name))
+  trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
+  assert not ragged
+  mesh_spec = axis_env.sizes + (trailing_size,)
+  return _axis_groups(mesh_spec, mesh_axes)
+
+def _axis_groups(mesh_spec, mesh_axes):
+  """Computes replica group ids for a collective performed over a subset of the mesh.
+
+  Args:
+    mesh_spec: A sequence of integers representing the mesh shape.
+    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
+      indicating over which axes the collective is performed.
+  Returns:
+    A tuple of replica groups (i.e. tuples containing replica ids).
+  """
+  iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
+  groups = np.reshape(
+      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
+      (math.prod(np.take(mesh_spec, mesh_axes)), -1))
+  return tuple(unsafe_map(tuple, groups.T))
+
+
 # TODO(b/110096942): more efficient gather
 def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
   if aval is core.abstract_token:
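A worked example helps here. With the AxisEnv from the pmap test further down (8 replicas, axes 'i' and 'j' of sizes 4 and 2, trailing replica factor 1), grouping over one mesh axis reproduces the groups that test asserts. A standalone re-creation of _axis_groups, differing only in casting to Python ints for clean printing:

```python
import math
import numpy as np

def _axis_groups(mesh_spec, mesh_axes):
    # Label every replica, then pull the grouped axes to the front so each
    # column of the reshaped array is one replica group.
    iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
    groups = np.reshape(
        np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
        (math.prod(np.take(mesh_spec, mesh_axes)), -1))
    return tuple(map(tuple, groups.T.tolist()))

# mesh_spec = axis sizes (4, 2) plus the trailing replica factor 1.
print(_axis_groups((4, 2, 1), (0,)))  # ((0, 2, 4, 6), (1, 3, 5, 7))    -- axis 'i'
print(_axis_groups((4, 2, 1), (1,)))  # ((0, 1), (2, 3), (4, 5), (6, 7)) -- axis 'j'
```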
@@ -1311,7 +1371,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
           x, mlir.dense_int_elements([1])).result
       padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
     replica_groups = mlir.dense_int_elements(
-        xla.axis_groups(axis_env, axis_env.names[-1]))
+        axis_groups(axis_env, axis_env.names[-1]))
     out = hlo.CrossReplicaSumOp(padded, replica_groups).result
     if out_axis != 0:
       # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
@@ -1335,17 +1395,22 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
     raise TypeError(aval)
 
 
+def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
+  return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
+                                env.sizes + (size,))
+
+
 def _pmap_lowering(ctx, *in_nodes, axis_name,
                    axis_size, global_axis_size, devices, name,
                    call_jaxpr, backend=None, in_axes, out_axes,
                    donated_invars, is_explicit_global_axis_size):
   del donated_invars  # Unused.
-  xla.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platform)
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
   if ctx.module_context.axis_env.names and devices is not None:
     raise ValueError("Nested pmap with explicit devices argument.")
-  new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
+  new_env = _extend_axis_env(ctx.module_context.axis_env, axis_name,
                                 global_axis_size)
   # Shard the in_nodes that are mapped
   in_avals = [v.aval for v in call_jaxpr.invars]
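_extend_axis_env is pure bookkeeping: AxisEnv is immutable, so extending it means building a new value with one more name/size pair. A tuple-based sketch; the simplified AxisEnv here stands in for the real one in jax._src.sharding_impls:

```python
from typing import NamedTuple

class AxisEnv(NamedTuple):  # simplified stand-in for sharding_impls.AxisEnv
    nreps: int
    names: tuple
    sizes: tuple

def extend_axis_env(env: AxisEnv, name, size: int) -> AxisEnv:
    # Append one named axis; the original env is left untouched.
    return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))

env = AxisEnv(8, ('i',), (4,))
print(extend_axis_env(env, 'j', 2))  # AxisEnv(nreps=8, names=('i', 'j'), sizes=(4, 2))
```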
@@ -20,15 +20,11 @@ import dataclasses
 import functools
 from functools import partial
 import itertools as it
 import math
 import operator
-import re
 from typing import Any, Callable, Optional, Protocol, Union
 
 import numpy as np
 
-from jax._src.config import config
-
 from jax._src import core
 from jax._src import dtypes
-from jax._src import source_info_util
@@ -59,13 +55,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
   dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
   return (xc.Shape.array_shape(dtype, aval.shape),)
 
-def get_canonical_source_file(frame: source_info_util.Frame):
-  source_file = frame.file_name
-  if config.jax_hlo_source_file_canonicalization_regex:
-    source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
-                         '', source_file)
-  return source_file
-
 # Utilities
 
 def parameter(builder, num, shape, name=None, replicated=None):
@@ -121,18 +110,6 @@ def tuple_sharding_proto(elems):
   return proto
 
 
-def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
-  """Builds op_fn(*args, **kwargs) with sharding annotation."""
-  builder.set_sharding(sharding_proto)
-  try:
-    return op_fn(*args, **kwargs)
-  finally:
-    builder.clear_sharding()
-
-def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
-  """Builds op_fn(*args, **kwargs) with sharding annotation."""
-  return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
-                             **kwargs)
-
 
 ### handlers
@@ -141,16 +118,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
 
 def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
   try:
-    return xla_shape_handlers[type(aval)](aval)
+    return _xla_shape_handlers[type(aval)](aval)
   except KeyError as err:
     raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
 
-xla_shape_handlers: dict[type[core.AbstractValue],
+_xla_shape_handlers: dict[type[core.AbstractValue],
                          Callable[[Any], Sequence[xc.Shape]]] = {
     ShapedArray: _make_array_shape,
     ConcreteArray: _make_array_shape,
 }
-xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
+_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
 
 
 # IR constants
@@ -270,52 +247,6 @@ def xla_destructure(c, ans):
   num_elements = len(c.get_shape(ans).tuple_shapes())
   return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
 
-def check_backend_matches(inner_backend, outer_backend):
-  # For nested calls, the outermost call sets the backend for all inner calls;
-  # it's an error if the inner call has a conflicting explicit backend spec.
-  if inner_backend is None:
-    return
-  if (inner_backend != outer_backend and
-      outer_backend not in xb.expand_platform_alias(inner_backend)):
-    raise ValueError(
-        f"Outer-jit backend specification {outer_backend} must match explicit "
-        f"inner-jit backend specification {inner_backend}.")
-
-
-def extend_axis_env(env: AxisEnv, name, size: int):
-  return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
-
-def axis_read(axis_env, axis_name):
-  try:
-    return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
-  except ValueError:
-    raise NameError(f"unbound axis name: {axis_name}") from None
-
-def axis_groups(axis_env: AxisEnv, name) -> tuple[tuple[int, ...]]:
-  if not isinstance(name, (list, tuple)):
-    name = (name,)
-  mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
-  trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
-  assert not ragged
-  mesh_spec = axis_env.sizes + (trailing_size,)
-  return _axis_groups(mesh_spec, mesh_axes)
-
-def _axis_groups(mesh_spec, mesh_axes):
-  """Computes replica group ids for a collective performed over a subset of the mesh.
-
-  Args:
-    mesh_spec: A sequence of integers representing the mesh shape.
-    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
-      indicating over which axes the collective is performed.
-  Returns:
-    A tuple of replica groups (i.e. tuples containing replica ids).
-  """
-  iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
-  groups = np.reshape(
-      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
-      (math.prod(np.take(mesh_spec, mesh_axes)), -1))
-  return tuple(unsafe_map(tuple, groups.T))
-
-
 # TODO(mattjj,skyewm): the functions here are utilities for checking if
 # not-yet-supported features are used with multi-host programming
@@ -329,37 +260,6 @@ def jaxpr_collectives(jaxpr):
   for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
 
 
-### xla_call underlying jit
-
-
-def xla_call_partial_eval_update_params(
-    params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
-  ) -> core.ParamDict:
-  donated_invars = params['donated_invars']
-  if not kept_inputs and donated_invars:
-    # JaxprTrace.post_process_call creates a call with no input tracers
-    donated_invars = (False,) * num_new_inputs
-  else:
-    assert len(kept_inputs) == len(donated_invars)
-    # JaxprTrace.process_call drops known input tracers
-    donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
-    # Any new inputs are prepended to the left, so mark those as not donated.
-    donated_invars = [False] * num_new_inputs + donated_invars
-  return dict(params, donated_invars=tuple(donated_invars))
-
-def xla_call_jvp_update_params(params, nz_tangents):
-  donated_invars = params['donated_invars']
-  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
-  new_donated_invars = (*donated_invars, *donated_tangents)
-  return dict(params, donated_invars=new_donated_invars)
-
-def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
-  donated_invars = params['donated_invars']
-  donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
-  donated_cotangents = [False for nz in nonzero_cts if nz]
-  return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
-
-
 ### translation tables
 
 MYPY = False
@@ -4184,9 +4184,6 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k):
   else:
     return top_k(operand, k=k), (bdim, bdim)
 
-def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
-  return xla.xla_destructure(ctx.builder, xops.TopK(x, k))
-
 top_k_p = Primitive('top_k')
 top_k_p.multiple_results = True
 top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
@@ -690,7 +690,7 @@ def _batched_reduction_collective(
   return vals_out, [batching.not_mapped] * len(vals_out)
 
 def _replica_groups(axis_env, axis_name, axis_index_groups):
-  replica_groups = xla.axis_groups(axis_env, axis_name)
+  replica_groups = pxla.axis_groups(axis_env, axis_name)
   if axis_index_groups is not None:
     replica_groups = [[axis_group[i] for i in axis_index_group]
                       for axis_group in replica_groups
@@ -50,7 +50,6 @@ from jax._src.interpreters.partial_eval import (
   trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
   convert_constvars_jaxpr, new_jaxpr_eqn)
 from jax._src.interpreters import pxla
-from jax._src.interpreters import xla
 from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims,
                            GSPMDSharding)
 from jax._src.sharding_impls import (
@@ -868,7 +867,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
 # NOTE: We don't have to handle spmd_{in|out}_axes here, because
 # SPMD batching always gets involved as the last transform before XLA translation
 ad.JVPTrace.process_xmap = ad.JVPTrace.process_call  # type: ignore
-ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params
+ad.call_param_updaters[xmap_p] = pxla.xla_call_jvp_update_params
 
 def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
   all_args, in_tree_def = tree_flatten(((), args, cts_in))  # empty consts
@@ -1305,7 +1304,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
                                 global_axis_sizes,
                                 spmd_in_axes, spmd_out_axes,
                                 axis_resources, resource_env, backend):
-  xla.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platform)
   # The only way for any of those two assertions to be violated is when xmap
   # is using the SPMD lowering, but then this rule shouldn't even trigger.
   assert spmd_in_axes is None and spmd_out_axes is None
@@ -1382,7 +1381,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
                              donated_invars, global_axis_sizes, spmd_in_axes,
                              spmd_out_axes, axis_resources,
                              resource_env, backend):
-  xla.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platform)
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes)
@@ -1450,7 +1449,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
                                     resource_env, backend):
   assert spmd_in_axes is None and spmd_out_axes is None
   # This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule.
-  xla.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platform)
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes)
   manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
@@ -1061,6 +1061,15 @@ def _outside_call_impl(*args, **params):
 outside_call_p.def_impl(_outside_call_impl)
 
 
+def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
+  """Builds op_fn(*args, **kwargs) with sharding annotation."""
+  builder.set_sharding(sharding_proto)
+  try:
+    return op_fn(*args, **kwargs)
+  finally:
+    builder.clear_sharding()
+
+
 def _outside_call_translation_rule(ctx,
                                    avals_in,
                                    avals_out,
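The try/finally shape of _with_sharding_proto is the point: the sharding annotation is scoped to exactly one op, and the builder state is restored even if op_fn raises. A generic illustration with a dummy builder; all names here are invented for the demo:

```python
class FakeBuilder:  # minimal stand-in for an XLA computation builder
    def set_sharding(self, proto): print(f"sharding set: {proto}")
    def clear_sharding(self): print("sharding cleared")

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
    builder.set_sharding(sharding_proto)
    try:
        return op_fn(*args, **kwargs)  # op built under the annotation
    finally:
        builder.clear_sharding()       # always restored, even on error

with_sharding_proto(FakeBuilder(), "maximal-to-device-0", lambda: "infeed-op")
```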
@@ -1137,7 +1146,7 @@ def _outside_call_translation_rule(ctx,
       build_infeed = functools.partial(xops.InfeedWithToken,
                                        after_outfeed_itoken,
                                        xla_client.Shape.tuple_shape(shape))
-      outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto,
+      outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
                                                build_infeed)
       outs = xops.GetTupleElement(outs_and_token, 0)
       next_itoken = xops.GetTupleElement(outs_and_token, 1)
@@ -57,6 +57,7 @@ from jax._src import random as random_internal
 from jax._src import source_info_util
 from jax._src import util
 from jax._src.interpreters import ad
+from jax._src.interpreters import mlir
 from jax._src.lax import control_flow as lax_control_flow
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import linalg as lax_linalg
@@ -1217,7 +1218,7 @@ def _make_op_metadata(primitive: core.Primitive,
   return xla_client.OpMetadata(
         op_type=primitive.name,
         op_name=eqn_str,
-        source_file=xla.get_canonical_source_file(frame) if frame else None,
+        source_file=mlir.get_canonical_source_file(frame) if frame else None,
         source_line=frame.start_line if frame else None)
@@ -16,20 +16,17 @@ from jax._src.interpreters.xla import (
   TranslationContext as TranslationContext,
   TranslationRule as TranslationRule,
   abstractify as abstractify,
-  axis_groups as axis_groups,
   backend_specific_translations as backend_specific_translations,
   canonicalize_dtype as canonicalize_dtype,
   canonicalize_dtype_handlers as canonicalize_dtype_handlers,
-  check_backend_matches as check_backend_matches,
   parameter as parameter,
   pytype_aval_mappings as pytype_aval_mappings,
   register_collective_primitive as register_collective_primitive,
   register_initial_style_primitive as register_initial_style_primitive,
   register_translation as register_translation,
   sharding_to_proto as sharding_to_proto,
   translations as translations,
   xla_destructure as xla_destructure,
-  xla_shape_handlers as xla_shape_handlers,
 )
+from jax._src.interpreters.pxla import (
+  axis_groups as axis_groups,
+)
 
 from jax._src.core import (
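One consequence of the shim change above: the public jax.interpreters.xla module still exposes axis_groups, but it is now re-exported from pxla rather than defined in xla.py, so both names resolve to the same function (assuming a JAX checkout at this revision):

```python
# Both names point at the implementation that now lives in pxla.py.
from jax.interpreters.xla import axis_groups as public_axis_groups
from jax._src.interpreters.pxla import axis_groups as pxla_axis_groups
assert public_axis_groups is pxla_axis_groups
```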
@@ -123,7 +123,6 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
 core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
 xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
-xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
 
 def sparse_array_mlir_type_handler(a):
   return (
@@ -258,7 +257,6 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
 core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
 xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
 xla.canonicalize_dtype_handlers[Empty] = lambda x: x
-xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
 
 
 @unittest.skip("Test does not work with jax.Array")
@@ -50,7 +50,6 @@ from jax._src.lib import xla_extension
 from jax._src.util import safe_map, safe_zip
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
-from jax.interpreters import xla
 from jax._src import array
 from jax._src.sharding_impls import PmapSharding
 from jax.ad_checkpoint import checkpoint as new_checkpoint
@@ -1082,16 +1081,16 @@ class PythonPmapTest(jtu.JaxTestCase):
 
   def testAxisGroups(self):
     axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2))
-    groups = xla.axis_groups(axis_env, 'i')
+    groups = pxla.axis_groups(axis_env, 'i')
     self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7)))
 
-    groups = xla.axis_groups(axis_env, 'j')
+    groups = pxla.axis_groups(axis_env, 'j')
     self.assertEqual(groups, ((0, 1), (2, 3), (4, 5), (6, 7)))
 
-    groups = xla.axis_groups(axis_env, ('i', 'j'))
+    groups = pxla.axis_groups(axis_env, ('i', 'j'))
     self.assertEqual(groups, ((0, 1, 2, 3, 4, 5, 6, 7,),))
 
-    groups = xla.axis_groups(axis_env, ('j', 'i'))
+    groups = pxla.axis_groups(axis_env, ('j', 'i'))
     self.assertEqual(len(groups), 1)
     self.assertEqual((tuple(sorted(groups[0])),),
                      ((0, 1, 2, 3, 4, 5, 6, 7,),))  # order doesn't matter
@@ -22,7 +22,7 @@ from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax.interpreters import xla
+from jax._src.interpreters import xla
 
 from jax._src.config import config
 config.parse_flags_with_absl()