Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Remove axis_name from unmapped_aval
PiperOrigin-RevId: 718558713
This commit is contained in:
parent f6243ff8e1
commit 23d360bded
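
In short, this change removes the `axis_name` parameter from `core.unmapped_aval` and from the per-aval unmapping handlers, so callers pass only the axis size, the axis position, and the aval. A minimal before/after sketch (illustrative only; `jax._src.core` is an internal module and the toy aval values are made up):

    from jax._src import core
    import numpy as np

    aval = core.ShapedArray((3, 4), np.dtype('float32'))

    # Before this commit, the axis name had to be threaded through explicitly:
    #   core.unmapped_aval(8, core.no_axis_name, 0, aval)
    # After this commit, only size, axis position, and aval are passed:
    out = core.unmapped_aval(8, 0, aval)   # ShapedArray with shape (8, 3, 4)
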
@@ -2415,8 +2415,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
     raise ValueError("`devices` argument to `device_put_replicated must be "
                      "a non-empty sequence.")
   def _device_put_replicated(x):
-    aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
-                              core.get_aval(x))
+    aval = core.unmapped_aval(len(devices), 0, core.get_aval(x))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:

@@ -159,8 +159,7 @@ def callback_batching_rule(
   new_args = [arg if dim is batching.not_mapped else
               batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
   batched_result_avals = tuple(
-      core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
-      for aval in result_avals)
+      core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)

   # For FFI calls we must update the layouts. We handle the output layouts
   # here, but the input layout updates depend on the vmap_method parameter.

@@ -2346,11 +2346,11 @@ def mapped_aval(size: AxisSize, axis: int | None,
   else:
     raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")

-def unmapped_aval(size: AxisSize, axis_name, axis: int | None,
+def unmapped_aval(size: AxisSize, axis: int | None,
                   aval: AbstractValue) -> AbstractValue:
   _, handler = aval_mapping_handlers.get(type(aval), (None, None))
   if handler is not None:
-    return handler(size, axis_name, axis, aval)
+    return handler(size, axis, aval)
   else:
     raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")

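
As the hunk above shows, `unmapped_aval` is the inverse of `mapped_aval`: it inserts a dimension of the given size at the given axis, and the dispatched handler now receives `(size, axis, aval)`. A round-trip sketch on a plain ShapedArray (illustrative; internal APIs, toy values):

    from jax._src import core
    import numpy as np

    small = core.ShapedArray((3, 4), np.dtype('float32'))
    big = core.unmapped_aval(8, 0, small)   # insert a leading axis of size 8 -> (8, 3, 4)
    back = core.mapped_aval(8, 0, big)      # remove it again -> (3, 4)
    assert back.shape == small.shape
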
@@ -2366,11 +2366,10 @@ def _map_shaped_array(
                      weak_type=aval.weak_type, sharding=sharding)

 def _unmap_shaped_array(
-    size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
-  ) -> ShapedArray:
+    size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
   if axis is None: return aval
   elif type(axis) is int:
-    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name))
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, None))
                 if config.sharding_in_types.value else None)
     return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                        weak_type=aval.weak_type, sharding=sharding)

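
The one place the axis name used to matter semantically is `_unmap_shaped_array`: when shardings are tracked in types, the spec entry inserted for the new dimension is now `None` (no named partitioning) rather than the axis name. A small sketch of just the spec manipulation, with a local stand-in for the `tuple_insert` helper and made-up spec values:

    def tuple_insert(t, idx, val):
      # assumed semantics of the helper used above: insert val before position idx
      return tuple(t[:idx]) + (val,) + tuple(t[idx:])

    spec = ('data', None)                  # spec of the mapped (per-shard) aval, hypothetical
    # Before: tuple_insert(spec, 0, axis_name)  ->  (axis_name, 'data', None)
    print(tuple_insert(spec, 0, None))     # after: (None, 'data', None)
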
@@ -2383,7 +2382,7 @@ def _map_dshaped_array(
                        aval.weak_type)

 def _unmap_dshaped_array(
-    size: AxisSize, axis_name: AxisName, axis: int | None, aval: DShapedArray
+    size: AxisSize, axis: int | None, aval: DShapedArray
   ) -> DShapedArray:
   if axis is None: return aval
   elif type(axis) is int:

@@ -2396,7 +2395,7 @@ AvalMapHandlerPair = tuple[Callable, Callable]
 aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
     DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
     ShapedArray: (_map_shaped_array, _unmap_shaped_array),
-    AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
+    AbstractToken: (lambda _, __, a: a, lambda _, __, a: a)
 }

 # When a mapped function is given no axis name, we generate a name object based

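
After this change both entries of the `AbstractToken` handler pair take three positional arguments, `(size, axis, aval)`, and both are identities, since tokens carry no batch dimension. A trivial check, for illustration only (internal APIs):

    from jax._src import core

    map_token, unmap_token = core.aval_mapping_handlers[core.AbstractToken]
    tok = core.abstract_token              # the AbstractToken singleton
    assert map_token(4, 0, tok) is tok
    assert unmap_token(4, 0, tok) is tok
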
@@ -2777,7 +2776,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
     raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter")
   out_axes = params["out_axes"]

-  binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
+  binder_avals = [unmapped_aval(axis_size, in_axis, v.aval)
                   if in_axis is not None else v.aval
                   for v, in_axis in zip(call_jaxpr.invars, in_axes)]
   for binder_aval, in_aval in zip(binder_avals, in_avals):

@@ -2789,7 +2788,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
   _check_jaxpr(ctx_factory, call_jaxpr)

   mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
-  out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
+  out_avals = [unmapped_aval(axis_size, out_axis, aval)
               if out_axis is not None else aval
               for aval, out_axis in zip(mapped_out_avals, out_axes)]
   return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})

@@ -1000,7 +1000,7 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
   assert len(in_axes) == len(arg_cts)
   def unmap_zero(zero, in_axis):
     return (zero if in_axis is None else
-            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'], in_axis, zero.aval)))
+            Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
   arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
              arg_ct if in_axis is not None else
              arg_ct.sum(0)

@@ -223,7 +223,7 @@ def _update_annotation(
       if isinstance(d, RaggedAxis):
         raise NotImplementedError
       else:
-        new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore
+        new_avals.append(core.unmapped_aval(sz, d, a))  # type: ignore

   mentioned = {d for a in new_avals if type(a) is core.DShapedArray
                for d in a.shape if type(d) is Name}

@@ -750,7 +750,7 @@ def _batch_jaxpr2(
       handle_ragged(closed_jaxpr.in_avals, dim, aval)
       if isinstance(dim, RaggedAxis) else (dim, aval)
       for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
-  avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval)
+  avals_in2 = [core.unmapped_aval(axis_data.size, b, aval)
               if b is not not_mapped else aval
               for aval, b in unsafe_zip(avals_in, in_axes2)]
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)

@@ -787,7 +787,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
   f, out_axes = _batch_jaxpr_inner(f, axis_data)
   f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
   f = _batch_jaxpr_outer(f, axis_data, in_axes)
-  avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped
+  avals_in = [core.unmapped_aval(axis_data.size, b, aval) if b is not not_mapped
              else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
   return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

@@ -906,9 +906,9 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
     return x
   elif type(src) == type(dst) == int:
     aval = core.mapped_aval(sz, src, x.aval)
-    return Zero(core.unmapped_aval(sz, name, dst, aval))
+    return Zero(core.unmapped_aval(sz, dst, aval))
   elif src is not_mapped and dst is not not_mapped:
-    return Zero(core.unmapped_aval(sz, name, dst, x.aval))
+    return Zero(core.unmapped_aval(sz, dst, x.aval))
   elif dst is not_mapped and sum_match:
     return Zero(core.mapped_aval(sz, src, x.aval))
   else:

@@ -372,7 +372,7 @@ class JaxprTrace(Trace['JaxprTracer']):
           out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
       del staged_params['out_axes_thunk']
       # The outputs of the staged-out call are Tracers with the new eqn as recipe.
-      out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
+      out_avals = [unmapped_aval(params['axis_size'], ax, a)
                    for ax, a in zip(staged_out_axes, out_avals_mapped)]
       out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                      for a in out_avals]

@@ -1956,7 +1956,7 @@ class DynamicJaxprTrace(core.Trace):
       raise ValueError("Ordered effects not supported for "
                        f"map primitives: {ordered_effects}")
     out_axes = params['out_axes_thunk']()
-    out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
+    out_avals = [core.unmapped_aval(axis_size, out_axis, a)
                 if out_axis is not None else a
                 for a, out_axis in zip(reduced_out_avals, out_axes)]
     source_info = source_info_util.current()

@@ -914,7 +914,7 @@ _pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
 def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
                         aval: core.AbstractValue) -> core.AbstractValue:
   if not config.pmap_no_rank_reduction.value:
-    return core.unmapped_aval(size, axis_name, axis, aval)
+    return core.unmapped_aval(size, axis, aval)

   _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
   if handler is not None:

@@ -1350,7 +1350,7 @@ def _pmap_partial_eval_custom_params_updater(
   return new_params_known, new_params_staged

 def _pmap_partial_eval_custom_res_maker(params_known, aval):
-  return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)
+  return core.unmapped_aval(params_known['axis_size'], 0, aval)

 def _pmap_dce_rule(used_outputs, eqn):
   # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes

@@ -520,7 +520,7 @@ def _stage_jaxpr_abstract_eval(*_, jaxpr):
   return jaxpr.out_avals, jaxpr.effects

 def _prepend_dim_to_aval(sz, aval):
-  return core.unmapped_aval(sz, None, 0, aval)
+  return core.unmapped_aval(sz, 0, aval)

 def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                         linear, unroll, _split_transpose):

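
In the scan rules this call is what prepends the loop length to each per-iteration output aval; the only difference after the change is that the `None` axis-name placeholder is gone. A toy illustration of `_prepend_dim_to_aval`'s effect (illustrative; internal APIs, made-up shapes):

    from jax._src import core
    import numpy as np

    y_aval = core.ShapedArray((5,), np.dtype('int32'))   # per-iteration scan output
    ys_aval = core.unmapped_aval(7, 0, y_aval)           # stacked output over length=7
    assert ys_aval.shape == (7, 5)
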
@@ -704,7 +704,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
   extensive_res = _map(trace.new_instantiated_const, extensive_res)
   # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
   carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
-  ys_avals = [core.unmapped_aval(length, None, 0, y_aval)
+  ys_avals = [core.unmapped_aval(length, 0, y_aval)
              for y_aval in y_avals]
   out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                  for a in itertools.chain(carry_avals, ys_avals)]

@@ -1071,7 +1071,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):

   # Create residual variables.
   intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
-  ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a)
+  ext_avals = [core.unmapped_aval(eqn.params['length'], 0, a)
               for a in ext_avals_mapped]
   newvar = core.gensym()
   intensive_res = _map(newvar, intensive_avals)

@@ -1149,7 +1149,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
       jaxpr.in_avals, [num_consts, num_carry])
   carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
   x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
-  y_avals = [core.unmapped_aval(length, None, 0, a)
+  y_avals = [core.unmapped_aval(length, 0, a)
             for a in y_avals_mapped]

   if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):

@@ -367,9 +367,8 @@ class AbstractRef(core.AbstractValue):
 def _map_ref(size, axis, ref_aval):
   return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))

-def _unmap_ref(size, axis_name, axis, ref_aval):
-  return AbstractRef(core.unmapped_aval(size, axis_name, axis,
-                                        ref_aval.inner_aval))
+def _unmap_ref(size, axis, ref_aval):
+  return AbstractRef(core.unmapped_aval(size, axis, ref_aval.inner_aval))

 core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)

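
The `AbstractRef` hunk above is the template for any wrapper aval that participates in mapping: its unmap handler drops the `axis_name` parameter and simply forwards to `core.unmapped_aval` with the new arity. A hedged sketch of the same pattern for a hypothetical wrapper type (`MyBox` and its handlers are invented for illustration; the registration line is left commented out):

    from jax._src import core
    import numpy as np

    class MyBox:                                    # hypothetical wrapper around an inner aval
      def __init__(self, inner_aval):
        self.inner_aval = inner_aval

    def _map_box(size, axis, box_aval):
      return MyBox(core.mapped_aval(size, axis, box_aval.inner_aval))

    def _unmap_box(size, axis, box_aval):           # new three-argument handler arity
      return MyBox(core.unmapped_aval(size, axis, box_aval.inner_aval))

    # core.aval_mapping_handlers[MyBox] = (_map_box, _unmap_box)

    boxed = MyBox(core.ShapedArray((3,), np.dtype('float32')))
    assert _unmap_box(4, 0, boxed).inner_aval.shape == (4, 3)
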
@@ -1613,7 +1613,7 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
     res, args = split_list(res_and_args, [len(jaxpr.constvars)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
     return core.eval_jaxpr(jaxpr, res, *args)
-  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for v, w in zip(jaxpr.constvars, which)]
   in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
   jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)

@@ -1740,7 +1740,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
     res_, ins = split_list(args, [len(which)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
     return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
-  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
   avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
   jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)