Another attempt to land #20445

Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1

PiperOrigin-RevId: 636926775
This commit is contained in:
Sergei Lebedev 2024-05-24 08:23:31 -07:00 committed by jax authors
parent 1e01fa7b0f
commit 15b974c90b
4 changed files with 65 additions and 153 deletions

View File

@ -2409,6 +2409,8 @@ class ClosedCallPrimitive(CallPrimitive):
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
closed_call_p.def_effectful_abstract_eval(
lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))
outfeed_primitives: set[Primitive] = set()
@ -2852,7 +2854,7 @@ custom_typechecks: dict[Primitive, Callable] = {}
def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
if list(in_avals) != list(call_jaxpr.in_avals):
if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
raise JaxprTypeError("Closed call in_avals mismatch")
return call_jaxpr.out_avals, call_jaxpr.effects
custom_typechecks[closed_call_p] = _check_closed_call

View File

@ -1403,7 +1403,10 @@ def lower_jaxpr_to_fun(
[num_dim_vars, num_tokens])
tokens_in = TokenSet(zip(effects, token_args))
args: list[list[ir.Value]] = unflattened_args
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
if name is not None:
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
else:
callee_name_stack = name_stack
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = jaxpr_subcomp(
ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
@ -1860,7 +1863,7 @@ def core_call_lowering(ctx: LoweringRuleContext,
register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(core_call_lowering, name="core_closed_call"))
partial(core_call_lowering, name=None))
def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:

View File

@ -248,7 +248,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(length)):
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat]
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
ys.append(y)
stack = lambda *ys: jax.numpy.stack(ys)
@ -385,164 +385,69 @@ def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
'the shapes do not match' * shape_mismatch)
return ''
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
carry = init
ys = []
for i in range(length):
i_ = length - i - 1 if reverse else i
x = _map(partial(_index_array, i_), x_avals, xs)
out = f_impl(*consts, *carry, *x)
carry, y = split_list(out, [num_carry])
ys.append(y)
ys = list(reversed(ys)) if reverse else ys
ys = list(zip(*ys))
ys = _map(_stack, y_avals, ys)
return (*carry, *ys)
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
def cond_fun(vals):
i, *_ = vals
return i < length
def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
i_ = length - i - 1 if reverse else i
# TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
# because the scan body may consume any keys within it.
xs_unconsumed = _map(jax.random.clone, xs)
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])
ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
return [i + 1] + carry_out + ys_out
# TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them.
ys_init = _map(partial(_empty_array, length), y_avals)
if length == 0:
return init + ys_init
else:
init_val = [lax._const(length, 0)] + init + ys_init
_, *outs = while_loop(cond_fun, body_fun, init_val)
return outs
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
linear, block_length, f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
num_blocks, rem = divmod(length, block_length)
assert rem == 0
partition = partial(_partition_leading, num_blocks, block_length)
xs_block = _map(partition, x_avals, xs)
prepend_aval = partial(_prepend_dim_to_aval, block_length)
x_block_avals = _map(prepend_aval, x_avals)
y_block_avals = _map(prepend_aval, y_avals)
f_impl_block = partial(
_scan_impl_unrolled, reverse=reverse, length=block_length,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
outs = _scan_impl_loop(
*consts, *init, *xs_block, reverse=reverse, length=num_blocks,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)
carry, ys_blocks = split_list(outs, [num_carry])
combine = partial(_combine_leading, num_blocks, block_length)
ys = _map(combine, y_avals, ys_blocks)
return (*carry, *ys)
# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression.
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
unroll, _split_transpose):
del _split_transpose
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
f_impl = core.jaxpr_as_fun(jaxpr)
if unroll == 1:
return _scan_impl_loop(
*args, reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
y_avals=y_avals)
consts, init, xs = split_list(args, [num_consts, num_carry])
num_blocks, rem = divmod(length, unroll)
length_div = num_blocks * unroll
if rem > 0:
if reverse:
split = partial(_split_leading_dim, rem)
xs_rem, xs = unzip2(_map(split, x_avals, xs))
num_trips, remainder = divmod(length, unroll)
if remainder:
if not reverse:
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
else:
split = partial(_split_leading_dim, length_div)
xs, xs_rem = unzip2(_map(split, x_avals, xs))
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
outs = _scan_impl_block_unrolled(
*consts, *init, *xs, reverse=reverse, length=length_div,
num_consts=num_consts, num_carry=num_carry, linear=linear,
block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
def cond_fun(while_carry):
i, _, _ = while_carry
return i < num_trips
def body_fun(while_carry):
i_, carry, yss = while_carry
i = num_trips - i_ - 1 if reverse else i_
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss]
carry, ys = inner(unroll, carry, xs)
yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0)
for ys, upd in zip(yss, ys)]
return i_ + 1, carry, yss
def inner(n, carry, xs):
ys = []
for i_ in range(n):
i = n - i_ - 1 if reverse else i_
x = [slicing.index_in_dim(x, i, keepdims=False) for x in xs]
carry_y = eval_jaxpr_p.bind(*consts, *carry, *x, jaxpr=jaxpr)
carry, y = split_list(carry_y, [num_carry])
ys.append(y)
ys = list(reversed(ys)) if reverse else ys
return carry, _map(jax.numpy.stack, zip(*ys))
carry, ys = split_list(outs, [num_carry])
if num_trips:
i = lax._const(num_trips, 0)
_, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss))
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
if remainder:
carry, ys_rem = inner(remainder, carry, xs_rem)
ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys)
return [*carry, *ys]
if rem > 0:
outs = _scan_impl_unrolled(
*consts, *carry, *xs_rem, reverse=reverse, length=rem,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
carry, ys_rem = split_list(outs, [num_carry])
if reverse:
ys = _map(_concatenate, y_avals, ys_rem, ys)
else:
ys = _map(_concatenate, y_avals, ys, ys_rem)
def _split_leading(sz, x):
return (slicing.slice_in_dim(x, 0, sz),
slicing.slice_in_dim(x, sz, x.shape[0]))
return (*carry, *ys)
def _concat(a, b): return lax.concatenate([a, b], 0)
def _stack(aval, vals):
vals = [lax.expand_dims(x, (0,)) for x in vals]
return lax.concatenate(vals, 0)
def _empty_array(prefix, aval):
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape))
def _concatenate(aval, x1, x2):
return lax.concatenate([x1, x2], 0)
def _split_leading_dim(i, aval, x):
assert x.ndim >= 1
return (slicing.slice_in_dim(x, 0, i),
slicing.slice_in_dim(x, i, x.shape[0]))
def _dynamic_index_array(i, aval, x):
return slicing.dynamic_index_in_dim(x, i, keepdims=False)
def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
def _partition_leading(sz0, sz1, aval, x):
assert x.ndim >= 1
assert x.shape[0] == sz0 * sz1
return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
def _combine_leading(sz0, sz1, aval, x):
assert x.ndim >= 2
assert x.shape[0] == sz0
assert x.shape[1] == sz1
return lax.collapse(x, 0, 2)
eval_jaxpr_p = core.Primitive('eval_jaxpr')
eval_jaxpr_p.multiple_results = True
def _stage_jaxpr(trace, *tracers, jaxpr):
params = dict(call_jaxpr=jaxpr)
return trace.default_process_primitive(core.closed_call_p, tracers, params)
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
@eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf
def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects
def _prepend_dim_to_aval(sz, aval):
return core.unmapped_aval(sz, core.no_axis_name, 0, aval)

View File

@ -1467,11 +1467,13 @@ class TensorFlowTrace(core.Trace):
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
assert False, f"Encountered unexpected primitive {p}"
# Call primitives are inlined
for unexpected in [core.call_p, maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \
lambda *args, jaxpr: _interpret_jaxpr(
jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None)
# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"clz",