mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Cleanup: replace lax._abstractify with core.get_aval
This commit is contained in:
parent
97459ba9aa
commit
67b3413b96
@ -1490,7 +1490,7 @@ def reduce(operands: Any,
|
||||
return _convert_element_type(monoid_reducer(*flat_operands, dimensions),
|
||||
weak_type=weak_type)
|
||||
else:
|
||||
flat_init_avals = safe_map(_abstractify, flat_init_values)
|
||||
flat_init_avals = safe_map(core.get_aval, flat_init_values)
|
||||
closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
|
||||
computation, tuple(flat_init_avals), init_value_tree)
|
||||
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
|
||||
@ -2761,8 +2761,8 @@ def _add_transpose(t, x, y):
|
||||
# some places (e.g. in custom_jvp) it may not always hold. For example, see
|
||||
# api_test.py's CustomJVPTest.test_jaxpr_zeros.
|
||||
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
|
||||
if type(t) is ad_util.Zero:
|
||||
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
||||
else:
|
||||
@ -2792,8 +2792,8 @@ def _sub_transpose(t, x, y):
|
||||
# Morally the following assertion is true, but see the comment in add_p's
|
||||
# transpose rule.
|
||||
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
|
||||
if type(t) is ad_util.Zero:
|
||||
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
||||
else:
|
||||
@ -6385,10 +6385,6 @@ def _eq_meet(a, b):
|
||||
return eq(a, b)
|
||||
|
||||
|
||||
def _abstractify(x):
|
||||
return core.get_aval(x)
|
||||
|
||||
|
||||
def empty(dtype):
|
||||
return empty_p.bind(dtype=dtype)
|
||||
empty_p = core.Primitive('empty')
|
||||
|
@ -483,7 +483,7 @@ def scatter_add(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.add,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
return scatter_add_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -536,7 +536,7 @@ def scatter_sub(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(
|
||||
lax.sub, lax._abstractify(lax._const(operand, 0))
|
||||
lax.sub, core.get_aval(lax._const(operand, 0))
|
||||
)
|
||||
return scatter_sub_p.bind(
|
||||
operand,
|
||||
@ -591,7 +591,7 @@ def scatter_mul(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul,
|
||||
lax._abstractify(lax._const(operand, 1)))
|
||||
core.get_aval(lax._const(operand, 1)))
|
||||
return scatter_mul_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -638,7 +638,7 @@ def scatter_min(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.min,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
return scatter_min_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -685,7 +685,7 @@ def scatter_max(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.max,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
return scatter_max_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -748,7 +748,7 @@ def scatter_apply(
|
||||
_apply = _scatter_apply_cache.setdefault(func, _apply)
|
||||
except TypeError: # func is not weak referenceable
|
||||
pass
|
||||
jaxpr, consts = lax._reduction_jaxpr(_apply, lax._abstractify(lax._zero(operand)))
|
||||
jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand)))
|
||||
# TODO: implement this via its own primitive so we can define appropriate autodiff rules.
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, unused, update_jaxpr=jaxpr,
|
||||
|
@ -90,7 +90,7 @@ def _reduce_window(
|
||||
return monoid_reducer(operand, window_dimensions, window_strides, padding,
|
||||
base_dilation, window_dilation)
|
||||
else:
|
||||
flat_init_avals = map(lax._abstractify, flat_init_values)
|
||||
flat_init_avals = map(core.get_aval, flat_init_values)
|
||||
jaxpr, out_tree = lax._variadic_reduction_jaxpr(
|
||||
computation, tuple(flat_init_avals), init_value_tree
|
||||
)
|
||||
@ -176,7 +176,7 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
|
||||
base_dilation: Sequence[int] | None = None,
|
||||
window_dilation: Sequence[int] | None = None) -> Array:
|
||||
init_value = lax._const(operand, 1)
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(init_value))
|
||||
if base_dilation is None:
|
||||
base_dilation = (1,) * len(window_dimensions)
|
||||
if window_dilation is None:
|
||||
@ -226,7 +226,7 @@ def _reduce_window_logaddexp(
|
||||
base_dilation: Sequence[int] | None = None,
|
||||
window_dilation: Sequence[int] | None = None) -> Array:
|
||||
init_value = lax._const(operand, -np.inf)
|
||||
jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value))
|
||||
jaxpr, consts = lax._reduction_jaxpr(logaddexp, core.get_aval(init_value))
|
||||
if base_dilation is None:
|
||||
base_dilation = (1,) * len(window_dimensions)
|
||||
if window_dilation is None:
|
||||
@ -245,9 +245,9 @@ def _select_and_scatter(operand: Array, select: Callable,
|
||||
padding: Sequence[tuple[int, int]], source: Array,
|
||||
init_value: Array, scatter: Callable) -> Array:
|
||||
select_jaxpr, select_consts = lax._reduction_jaxpr(
|
||||
select, lax._abstractify(init_value))
|
||||
select, core.get_aval(init_value))
|
||||
scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
|
||||
scatter, lax._abstractify(init_value))
|
||||
scatter, core.get_aval(init_value))
|
||||
return select_and_scatter_p.bind(
|
||||
operand, source, init_value, select_jaxpr=select_jaxpr,
|
||||
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
|
||||
|
Loading…
x
Reference in New Issue
Block a user