Cleanup: replace lax._abstractify with core.get_aval

This commit is contained in:
Jake VanderPlas 2024-12-12 14:08:17 -08:00
parent 97459ba9aa
commit 67b3413b96
3 changed files with 16 additions and 20 deletions

View File

@ -1490,7 +1490,7 @@ def reduce(operands: Any,
return _convert_element_type(monoid_reducer(*flat_operands, dimensions),
weak_type=weak_type)
else:
flat_init_avals = safe_map(_abstractify, flat_init_values)
flat_init_avals = safe_map(core.get_aval, flat_init_values)
closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
@ -2761,8 +2761,8 @@ def _add_transpose(t, x, y):
# some places (e.g. in custom_jvp) it may not always hold. For example, see
# api_test.py's CustomJVPTest.test_jaxpr_zeros.
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
if type(t) is ad_util.Zero:
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
else:
@ -2792,8 +2792,8 @@ def _sub_transpose(t, x, y):
# Morally the following assertion is true, but see the comment in add_p's
# transpose rule.
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
if type(t) is ad_util.Zero:
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
else:
@ -6385,10 +6385,6 @@ def _eq_meet(a, b):
return eq(a, b)
def _abstractify(x):
return core.get_aval(x)
def empty(dtype):
return empty_p.bind(dtype=dtype)
empty_p = core.Primitive('empty')

View File

@ -483,7 +483,7 @@ def scatter_add(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.add,
lax._abstractify(lax._const(operand, 0)))
core.get_aval(lax._const(operand, 0)))
return scatter_add_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -536,7 +536,7 @@ def scatter_sub(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(
lax.sub, lax._abstractify(lax._const(operand, 0))
lax.sub, core.get_aval(lax._const(operand, 0))
)
return scatter_sub_p.bind(
operand,
@ -591,7 +591,7 @@ def scatter_mul(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.mul,
lax._abstractify(lax._const(operand, 1)))
core.get_aval(lax._const(operand, 1)))
return scatter_mul_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -638,7 +638,7 @@ def scatter_min(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.min,
lax._abstractify(lax._const(operand, 0)))
core.get_aval(lax._const(operand, 0)))
return scatter_min_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -685,7 +685,7 @@ def scatter_max(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.max,
lax._abstractify(lax._const(operand, 0)))
core.get_aval(lax._const(operand, 0)))
return scatter_max_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -748,7 +748,7 @@ def scatter_apply(
_apply = _scatter_apply_cache.setdefault(func, _apply)
except TypeError: # func is not weak referenceable
pass
jaxpr, consts = lax._reduction_jaxpr(_apply, lax._abstractify(lax._zero(operand)))
jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand)))
# TODO: implement this via its own primitive so we can define appropriate autodiff rules.
return scatter_p.bind(
operand, scatter_indices, unused, update_jaxpr=jaxpr,

View File

@ -90,7 +90,7 @@ def _reduce_window(
return monoid_reducer(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
else:
flat_init_avals = map(lax._abstractify, flat_init_values)
flat_init_avals = map(core.get_aval, flat_init_values)
jaxpr, out_tree = lax._variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree
)
@ -176,7 +176,7 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None) -> Array:
init_value = lax._const(operand, 1)
jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(init_value))
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
@ -226,7 +226,7 @@ def _reduce_window_logaddexp(
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None) -> Array:
init_value = lax._const(operand, -np.inf)
jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value))
jaxpr, consts = lax._reduction_jaxpr(logaddexp, core.get_aval(init_value))
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
@ -245,9 +245,9 @@ def _select_and_scatter(operand: Array, select: Callable,
padding: Sequence[tuple[int, int]], source: Array,
init_value: Array, scatter: Callable) -> Array:
select_jaxpr, select_consts = lax._reduction_jaxpr(
select, lax._abstractify(init_value))
select, core.get_aval(init_value))
scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
scatter, lax._abstractify(init_value))
scatter, core.get_aval(init_value))
return select_and_scatter_p.bind(
operand, source, init_value, select_jaxpr=select_jaxpr,
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,