mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Another attempt to land #20445
Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1 PiperOrigin-RevId: 636926775
This commit is contained in:
parent
1e01fa7b0f
commit
15b974c90b
@ -2409,6 +2409,8 @@ class ClosedCallPrimitive(CallPrimitive):
|
||||
|
||||
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
|
||||
closed_call_p.def_impl(call_impl)
|
||||
closed_call_p.def_effectful_abstract_eval(
|
||||
lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))
|
||||
|
||||
|
||||
outfeed_primitives: set[Primitive] = set()
|
||||
@ -2852,7 +2854,7 @@ custom_typechecks: dict[Primitive, Callable] = {}
|
||||
|
||||
def _check_closed_call(_, *in_atoms, call_jaxpr):
|
||||
in_avals = [x.aval for x in in_atoms]
|
||||
if list(in_avals) != list(call_jaxpr.in_avals):
|
||||
if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
|
||||
raise JaxprTypeError("Closed call in_avals mismatch")
|
||||
return call_jaxpr.out_avals, call_jaxpr.effects
|
||||
custom_typechecks[closed_call_p] = _check_closed_call
|
||||
|
@ -1403,7 +1403,10 @@ def lower_jaxpr_to_fun(
|
||||
[num_dim_vars, num_tokens])
|
||||
tokens_in = TokenSet(zip(effects, token_args))
|
||||
args: list[list[ir.Value]] = unflattened_args
|
||||
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
|
||||
if name is not None:
|
||||
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
|
||||
else:
|
||||
callee_name_stack = name_stack
|
||||
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
|
||||
out_vals, tokens_out = jaxpr_subcomp(
|
||||
ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
|
||||
@ -1860,7 +1863,7 @@ def core_call_lowering(ctx: LoweringRuleContext,
|
||||
|
||||
register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
|
||||
register_lowering(core.closed_call_p,
|
||||
partial(core_call_lowering, name="core_closed_call"))
|
||||
partial(core_call_lowering, name=None))
|
||||
|
||||
def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
|
||||
broadcast_dimensions) -> ir.Value:
|
||||
|
@ -248,7 +248,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
ys = []
|
||||
maybe_reversed = reversed if reverse else lambda x: x
|
||||
for i in maybe_reversed(range(length)):
|
||||
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
|
||||
xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat]
|
||||
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
|
||||
ys.append(y)
|
||||
stack = lambda *ys: jax.numpy.stack(ys)
|
||||
@ -385,164 +385,69 @@ def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
|
||||
'the shapes do not match' * shape_mismatch)
|
||||
return ''
|
||||
|
||||
|
||||
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
|
||||
f_impl, x_avals, y_avals):
|
||||
consts, init, xs = split_list(args, [num_consts, num_carry])
|
||||
|
||||
carry = init
|
||||
ys = []
|
||||
|
||||
for i in range(length):
|
||||
i_ = length - i - 1 if reverse else i
|
||||
x = _map(partial(_index_array, i_), x_avals, xs)
|
||||
out = f_impl(*consts, *carry, *x)
|
||||
carry, y = split_list(out, [num_carry])
|
||||
ys.append(y)
|
||||
|
||||
ys = list(reversed(ys)) if reverse else ys
|
||||
ys = list(zip(*ys))
|
||||
ys = _map(_stack, y_avals, ys)
|
||||
return (*carry, *ys)
|
||||
|
||||
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
|
||||
f_impl, x_avals, y_avals):
|
||||
consts, init, xs = split_list(args, [num_consts, num_carry])
|
||||
|
||||
def cond_fun(vals):
|
||||
i, *_ = vals
|
||||
return i < length
|
||||
|
||||
def body_fun(vals):
|
||||
[i], carry, ys = split_list(vals, [1, num_carry])
|
||||
i_ = length - i - 1 if reverse else i
|
||||
# TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
|
||||
# because the scan body may consume any keys within it.
|
||||
xs_unconsumed = _map(jax.random.clone, xs)
|
||||
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
|
||||
out_flat = f_impl(*consts, *carry, *x)
|
||||
carry_out, y_updates = split_list(out_flat, [num_carry])
|
||||
ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
|
||||
return [i + 1] + carry_out + ys_out
|
||||
|
||||
# TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them.
|
||||
|
||||
ys_init = _map(partial(_empty_array, length), y_avals)
|
||||
if length == 0:
|
||||
return init + ys_init
|
||||
else:
|
||||
init_val = [lax._const(length, 0)] + init + ys_init
|
||||
_, *outs = while_loop(cond_fun, body_fun, init_val)
|
||||
return outs
|
||||
|
||||
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
|
||||
linear, block_length, f_impl, x_avals, y_avals):
|
||||
consts, init, xs = split_list(args, [num_consts, num_carry])
|
||||
|
||||
num_blocks, rem = divmod(length, block_length)
|
||||
assert rem == 0
|
||||
|
||||
partition = partial(_partition_leading, num_blocks, block_length)
|
||||
xs_block = _map(partition, x_avals, xs)
|
||||
|
||||
prepend_aval = partial(_prepend_dim_to_aval, block_length)
|
||||
x_block_avals = _map(prepend_aval, x_avals)
|
||||
y_block_avals = _map(prepend_aval, y_avals)
|
||||
|
||||
f_impl_block = partial(
|
||||
_scan_impl_unrolled, reverse=reverse, length=block_length,
|
||||
num_consts=num_consts, num_carry=num_carry, linear=linear,
|
||||
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
|
||||
|
||||
outs = _scan_impl_loop(
|
||||
*consts, *init, *xs_block, reverse=reverse, length=num_blocks,
|
||||
num_consts=num_consts, num_carry=num_carry, linear=linear,
|
||||
f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)
|
||||
|
||||
carry, ys_blocks = split_list(outs, [num_carry])
|
||||
combine = partial(_combine_leading, num_blocks, block_length)
|
||||
ys = _map(combine, y_avals, ys_blocks)
|
||||
return (*carry, *ys)
|
||||
|
||||
# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression.
|
||||
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
unroll, _split_transpose):
|
||||
del _split_transpose
|
||||
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
|
||||
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
|
||||
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
f_impl = core.jaxpr_as_fun(jaxpr)
|
||||
|
||||
if unroll == 1:
|
||||
return _scan_impl_loop(
|
||||
*args, reverse=reverse, length=length, num_consts=num_consts,
|
||||
num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
|
||||
y_avals=y_avals)
|
||||
|
||||
consts, init, xs = split_list(args, [num_consts, num_carry])
|
||||
num_blocks, rem = divmod(length, unroll)
|
||||
length_div = num_blocks * unroll
|
||||
|
||||
if rem > 0:
|
||||
if reverse:
|
||||
split = partial(_split_leading_dim, rem)
|
||||
xs_rem, xs = unzip2(_map(split, x_avals, xs))
|
||||
num_trips, remainder = divmod(length, unroll)
|
||||
if remainder:
|
||||
if not reverse:
|
||||
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
|
||||
else:
|
||||
split = partial(_split_leading_dim, length_div)
|
||||
xs, xs_rem = unzip2(_map(split, x_avals, xs))
|
||||
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
|
||||
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
|
||||
yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
|
||||
|
||||
outs = _scan_impl_block_unrolled(
|
||||
*consts, *init, *xs, reverse=reverse, length=length_div,
|
||||
num_consts=num_consts, num_carry=num_carry, linear=linear,
|
||||
block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
|
||||
def cond_fun(while_carry):
|
||||
i, _, _ = while_carry
|
||||
return i < num_trips
|
||||
def body_fun(while_carry):
|
||||
i_, carry, yss = while_carry
|
||||
i = num_trips - i_ - 1 if reverse else i_
|
||||
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss]
|
||||
carry, ys = inner(unroll, carry, xs)
|
||||
yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0)
|
||||
for ys, upd in zip(yss, ys)]
|
||||
return i_ + 1, carry, yss
|
||||
def inner(n, carry, xs):
|
||||
ys = []
|
||||
for i_ in range(n):
|
||||
i = n - i_ - 1 if reverse else i_
|
||||
x = [slicing.index_in_dim(x, i, keepdims=False) for x in xs]
|
||||
carry_y = eval_jaxpr_p.bind(*consts, *carry, *x, jaxpr=jaxpr)
|
||||
carry, y = split_list(carry_y, [num_carry])
|
||||
ys.append(y)
|
||||
ys = list(reversed(ys)) if reverse else ys
|
||||
return carry, _map(jax.numpy.stack, zip(*ys))
|
||||
|
||||
carry, ys = split_list(outs, [num_carry])
|
||||
if num_trips:
|
||||
i = lax._const(num_trips, 0)
|
||||
_, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss))
|
||||
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
|
||||
if remainder:
|
||||
carry, ys_rem = inner(remainder, carry, xs_rem)
|
||||
ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys)
|
||||
return [*carry, *ys]
|
||||
|
||||
if rem > 0:
|
||||
outs = _scan_impl_unrolled(
|
||||
*consts, *carry, *xs_rem, reverse=reverse, length=rem,
|
||||
num_consts=num_consts, num_carry=num_carry, linear=linear,
|
||||
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
|
||||
carry, ys_rem = split_list(outs, [num_carry])
|
||||
if reverse:
|
||||
ys = _map(_concatenate, y_avals, ys_rem, ys)
|
||||
else:
|
||||
ys = _map(_concatenate, y_avals, ys, ys_rem)
|
||||
def _split_leading(sz, x):
|
||||
return (slicing.slice_in_dim(x, 0, sz),
|
||||
slicing.slice_in_dim(x, sz, x.shape[0]))
|
||||
|
||||
return (*carry, *ys)
|
||||
def _concat(a, b): return lax.concatenate([a, b], 0)
|
||||
|
||||
def _stack(aval, vals):
|
||||
vals = [lax.expand_dims(x, (0,)) for x in vals]
|
||||
return lax.concatenate(vals, 0)
|
||||
def _empty_array(prefix, aval):
|
||||
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape))
|
||||
|
||||
def _concatenate(aval, x1, x2):
|
||||
return lax.concatenate([x1, x2], 0)
|
||||
|
||||
def _split_leading_dim(i, aval, x):
|
||||
assert x.ndim >= 1
|
||||
return (slicing.slice_in_dim(x, 0, i),
|
||||
slicing.slice_in_dim(x, i, x.shape[0]))
|
||||
|
||||
def _dynamic_index_array(i, aval, x):
|
||||
return slicing.dynamic_index_in_dim(x, i, keepdims=False)
|
||||
|
||||
def _index_array(i, aval, x):
|
||||
return slicing.index_in_dim(x, i, keepdims=False)
|
||||
|
||||
def _empty_array(sz, aval):
|
||||
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
|
||||
|
||||
def _update_array(i, aval, xs, x):
|
||||
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
|
||||
|
||||
def _partition_leading(sz0, sz1, aval, x):
|
||||
assert x.ndim >= 1
|
||||
assert x.shape[0] == sz0 * sz1
|
||||
return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
|
||||
|
||||
def _combine_leading(sz0, sz1, aval, x):
|
||||
assert x.ndim >= 2
|
||||
assert x.shape[0] == sz0
|
||||
assert x.shape[1] == sz1
|
||||
return lax.collapse(x, 0, 2)
|
||||
eval_jaxpr_p = core.Primitive('eval_jaxpr')
|
||||
eval_jaxpr_p.multiple_results = True
|
||||
def _stage_jaxpr(trace, *tracers, jaxpr):
|
||||
params = dict(call_jaxpr=jaxpr)
|
||||
return trace.default_process_primitive(core.closed_call_p, tracers, params)
|
||||
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
|
||||
@eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf
|
||||
def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects
|
||||
|
||||
def _prepend_dim_to_aval(sz, aval):
|
||||
return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
|
||||
|
@ -1467,11 +1467,13 @@ class TensorFlowTrace(core.Trace):
|
||||
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
|
||||
assert False, f"Encountered unexpected primitive {p}"
|
||||
|
||||
|
||||
# Call primitives are inlined
|
||||
for unexpected in [core.call_p, maps.xmap_p]:
|
||||
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
|
||||
|
||||
tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \
|
||||
lambda *args, jaxpr: _interpret_jaxpr(
|
||||
jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None)
|
||||
|
||||
# Primitives that are not yet implemented must be explicitly declared here.
|
||||
tf_not_yet_impl = [
|
||||
"clz",
|
||||
|
Loading…
x
Reference in New Issue
Block a user