mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[NFC] Rename standard_insert_pbroadcast
to standard_insert_pvary
PiperOrigin-RevId: 747943230
This commit is contained in:
parent
c527ddb7bf
commit
6e00b5e02d
@ -31,7 +31,7 @@ T = TypeVar('T')
|
||||
map = safe_map
|
||||
|
||||
def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return add_jaxvals_p.bind(x, y)
|
||||
|
||||
add_jaxvals_p = Primitive('add_any')
|
||||
|
@ -2026,7 +2026,7 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups):
|
||||
pvary_p.def_abstract_eval(_pvary_abstract_eval)
|
||||
|
||||
|
||||
def standard_insert_pbroadcast(*args):
|
||||
def standard_insert_pvary(*args):
|
||||
if not config.varying_axes_in_types.value:
|
||||
return args
|
||||
if not config._check_rep.value:
|
||||
|
@ -499,7 +499,7 @@ def ffi_call(
|
||||
"and an output with a different layout "
|
||||
f"{static_output_layouts[o_idx]}.")
|
||||
static_input_output_aliases += ((i_idx, o_idx),)
|
||||
args = core.standard_insert_pbroadcast(*args)
|
||||
args = core.standard_insert_pvary(*args)
|
||||
results = ffi_call_p.bind(
|
||||
*args,
|
||||
result_avals=result_avals,
|
||||
|
@ -310,7 +310,7 @@ def custom_linear_solve(
|
||||
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
|
||||
|
||||
args = _flatten(all_consts) + b_flat
|
||||
args = core.standard_insert_pbroadcast(*args)
|
||||
args = core.standard_insert_pvary(*args)
|
||||
out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs)
|
||||
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
@ -158,7 +158,7 @@ def conv_general_dilated(
|
||||
preferred_element_type = (
|
||||
None if preferred_element_type is None else
|
||||
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
|
||||
lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs)
|
||||
lhs, rhs = core.standard_insert_pvary(lhs, rhs)
|
||||
return conv_general_dilated_p.bind(
|
||||
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
|
||||
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
|
||||
|
@ -369,7 +369,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
|
||||
"""
|
||||
x1, x2 = core.standard_insert_pbroadcast(x1, x2)
|
||||
x1, x2 = core.standard_insert_pvary(x1, x2)
|
||||
return nextafter_p.bind(x1, x2)
|
||||
|
||||
@export
|
||||
@ -775,7 +775,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return atan2_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -845,7 +845,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return complex_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -917,7 +917,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
|
||||
.. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return pow_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1072,7 +1072,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.and: https://openxla.org/stablehlo/spec#and
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return and_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1099,7 +1099,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.or: https://openxla.org/stablehlo/spec#or
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return or_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1126,7 +1126,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return xor_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1191,7 +1191,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.add: https://openxla.org/stablehlo/spec#add
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return add_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1215,7 +1215,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return sub_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1239,7 +1239,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return mul_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1269,7 +1269,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return div_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1297,7 +1297,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return rem_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1323,7 +1323,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return max_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1349,7 +1349,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return min_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1375,7 +1375,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return shift_left_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1402,7 +1402,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return shift_right_arithmetic_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1429,7 +1429,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return shift_right_logical_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1460,7 +1460,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return eq_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1491,7 +1491,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return ne_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1522,7 +1522,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return ge_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1553,7 +1553,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return gt_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1584,7 +1584,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return le_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1615,7 +1615,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
x, y = core.standard_insert_pbroadcast(x, y)
|
||||
x, y = core.standard_insert_pvary(x, y)
|
||||
return lt_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1771,7 +1771,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
|
||||
x & \text{otherwise}
|
||||
\end{cases}`.
|
||||
"""
|
||||
min, x, max = core.standard_insert_pbroadcast(min, x, max)
|
||||
min, x, max = core.standard_insert_pvary(min, x, max)
|
||||
return clamp_p.bind(min, x, max)
|
||||
|
||||
|
||||
@ -1878,7 +1878,7 @@ def composite(
|
||||
closed_jaxpr, out_tree = _trace_composite_to_jaxpr(
|
||||
partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info
|
||||
)
|
||||
flat_args = core.standard_insert_pbroadcast(*flat_args)
|
||||
flat_args = core.standard_insert_pvary(*flat_args)
|
||||
out_flat = composite_p.bind(
|
||||
*flat_args,
|
||||
name=name,
|
||||
@ -1996,7 +1996,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
|
||||
op, = operands
|
||||
if isinstance(op, Array):
|
||||
return op
|
||||
operands = core.standard_insert_pbroadcast(*operands)
|
||||
operands = core.standard_insert_pvary(*operands)
|
||||
return concatenate_p.bind(*operands, dimension=dimension)
|
||||
|
||||
|
||||
@ -2520,7 +2520,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
|
||||
preferred_element_type = (
|
||||
None if preferred_element_type is None else
|
||||
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
|
||||
lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs)
|
||||
lhs, rhs = core.standard_insert_pvary(lhs, rhs)
|
||||
return dot_general_p.bind(lhs, rhs,
|
||||
dimension_numbers=(cdims, bdims),
|
||||
precision=canonicalize_precision(precision),
|
||||
@ -2656,7 +2656,7 @@ def ragged_dot_general(
|
||||
extra leading dimension of size `g` in the case where the lhs ragged
|
||||
dimension is a contracting dimension.
|
||||
"""
|
||||
lhs, rhs, group_sizes = core.standard_insert_pbroadcast(lhs, rhs, group_sizes)
|
||||
lhs, rhs, group_sizes = core.standard_insert_pvary(lhs, rhs, group_sizes)
|
||||
return ragged_dot_general_p.bind(
|
||||
lhs,
|
||||
rhs,
|
||||
@ -2840,7 +2840,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike,
|
||||
[-1, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
|
||||
"""
|
||||
operand, padding_value = core.standard_insert_pbroadcast(operand, padding_value)
|
||||
operand, padding_value = core.standard_insert_pvary(operand, padding_value)
|
||||
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
|
||||
|
||||
def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
|
||||
@ -2873,7 +2873,7 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
|
||||
"""
|
||||
# Caution! The select_n_p primitive has the *opposite* order of arguments to
|
||||
# select(). This is because it implements `select_n`.
|
||||
pred, on_false, on_true = core.standard_insert_pbroadcast(
|
||||
pred, on_false, on_true = core.standard_insert_pvary(
|
||||
pred, on_false, on_true)
|
||||
return select_n_p.bind(pred, on_false, on_true)
|
||||
|
||||
@ -2900,7 +2900,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array:
|
||||
"""
|
||||
if len(cases) == 0:
|
||||
raise ValueError("select_n() must have at least one case")
|
||||
which, *cases = core.standard_insert_pbroadcast(which, *cases)
|
||||
which, *cases = core.standard_insert_pvary(which, *cases)
|
||||
return select_n_p.bind(which, *cases)
|
||||
|
||||
|
||||
@ -3262,7 +3262,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1,
|
||||
if not (1 <= num_keys <= len(operand)):
|
||||
raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}")
|
||||
dimension = canonicalize_axis(dimension, len(operand[0].shape))
|
||||
operand = core.standard_insert_pbroadcast(*operand)
|
||||
operand = core.standard_insert_pvary(*operand)
|
||||
return tuple(sort_p.bind(*operand, dimension=dimension,
|
||||
is_stable=is_stable,
|
||||
num_keys=num_keys))
|
||||
@ -8111,7 +8111,7 @@ def after_all(*operands):
|
||||
"""Merges one or more XLA token values. Experimental.
|
||||
|
||||
Wraps the XLA AfterAll operator."""
|
||||
operands = core.standard_insert_pbroadcast(*operands)
|
||||
operands = core.standard_insert_pvary(*operands)
|
||||
return after_all_p.bind(*operands)
|
||||
|
||||
def _after_all_abstract_eval(*operands):
|
||||
@ -8246,7 +8246,7 @@ def rng_uniform(a, b, shape):
|
||||
|
||||
This API may be removed at any time.
|
||||
"""
|
||||
a, b = core.standard_insert_pbroadcast(a, b)
|
||||
a, b = core.standard_insert_pvary(a, b)
|
||||
return rng_uniform_p.bind(a, b, shape=tuple(shape))
|
||||
|
||||
def _rng_uniform_abstract_eval(a, b, *, shape):
|
||||
@ -8930,7 +8930,7 @@ def optimization_barrier(operand, /):
|
||||
"""
|
||||
flat_args, treedef = tree_util.tree_flatten(operand)
|
||||
# TODO(yashkatariya): Enable this
|
||||
# flat_args = core.standard_insert_pbroadcast(flat_args)
|
||||
# flat_args = core.standard_insert_pvary(flat_args)
|
||||
out = optimization_barrier_p.bind(*flat_args)
|
||||
return tree_util.tree_unflatten(treedef, out)
|
||||
|
||||
|
@ -121,7 +121,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
|
||||
A new upper-triangular matrix :math:`R` defining the Cholesky decomposition
|
||||
of :math:`A + w \, w^T`.
|
||||
"""
|
||||
r_matrix, w_vector = core.standard_insert_pbroadcast(r_matrix, w_vector)
|
||||
r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector)
|
||||
return cholesky_update_p.bind(r_matrix, w_vector)
|
||||
|
||||
|
||||
@ -269,7 +269,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
|
||||
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
|
||||
containing the products of the elementary Householder reflectors.
|
||||
"""
|
||||
a, taus = core.standard_insert_pbroadcast(a, taus)
|
||||
a, taus = core.standard_insert_pvary(a, taus)
|
||||
return householder_product_p.bind(a, taus)
|
||||
|
||||
|
||||
@ -547,7 +547,7 @@ def symmetric_product(
|
||||
``symmetrize_output`` is ``True``, the upper triangle is filled with the
|
||||
transpose of the lower triangle, and the whole matrix is valid.
|
||||
"""
|
||||
a_matrix, c_matrix = core.standard_insert_pbroadcast(a_matrix, c_matrix)
|
||||
a_matrix, c_matrix = core.standard_insert_pvary(a_matrix, c_matrix)
|
||||
result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
|
||||
if symmetrize_output:
|
||||
upper_half = lax.transpose(
|
||||
@ -605,7 +605,7 @@ def triangular_solve(
|
||||
singleton = np.ndim(b) == np.ndim(a) - 1
|
||||
if singleton:
|
||||
b = lax.expand_dims(b, (-1 if left_side else -2,))
|
||||
a, b = core.standard_insert_pbroadcast(a, b)
|
||||
a, b = core.standard_insert_pvary(a, b)
|
||||
out = triangular_solve_p.bind(
|
||||
a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
|
||||
conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
|
||||
@ -665,7 +665,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
|
||||
Returns:
|
||||
Solution ``X`` of tridiagonal system.
|
||||
"""
|
||||
dl, d, du, b = core.standard_insert_pbroadcast(dl, d, du, b)
|
||||
dl, d, du, b = core.standard_insert_pvary(dl, d, du, b)
|
||||
return tridiagonal_solve_p.bind(dl, d, du, b)
|
||||
|
||||
|
||||
@ -1658,7 +1658,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size):
|
||||
if m == 0 or k == 0:
|
||||
return permutation
|
||||
upper = np.array(k, np.int32) if is_constant_dim(k) else k
|
||||
permutation, swaps = core.standard_insert_pbroadcast(permutation, swaps)
|
||||
permutation, swaps = core.standard_insert_pvary(permutation, swaps)
|
||||
result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn,
|
||||
(permutation, swaps))
|
||||
return result
|
||||
@ -1774,7 +1774,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *,
|
||||
elementary Householder reflectors, and ``jpvt`` is the column-pivot indices
|
||||
such that ``a[:, jpvt] = q @ r``.
|
||||
"""
|
||||
a, jpvt = core.standard_insert_pbroadcast(a, jpvt)
|
||||
a, jpvt = core.standard_insert_pvary(a, jpvt)
|
||||
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma)
|
||||
return a_out, jpvt_out, taus
|
||||
|
||||
|
@ -173,7 +173,7 @@ def dynamic_slice(
|
||||
else:
|
||||
dynamic_sizes = []
|
||||
static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore
|
||||
operand, *start_indices = core.standard_insert_pbroadcast(
|
||||
operand, *start_indices = core.standard_insert_pvary(
|
||||
operand, *start_indices)
|
||||
return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
|
||||
slice_sizes=tuple(static_sizes))
|
||||
@ -236,7 +236,7 @@ def dynamic_update_slice(
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(
|
||||
operand, start_indices, allow_negative_indices)
|
||||
operand, update, *start_indices = core.standard_insert_pbroadcast(
|
||||
operand, update, *start_indices = core.standard_insert_pvary(
|
||||
operand, update, *start_indices)
|
||||
return dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
|
||||
@ -420,7 +420,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
|
||||
raise ValueError(f"Unsupported dtype for gather fill_value {dtype}")
|
||||
else:
|
||||
fill_value = None
|
||||
operand, start_indices = core.standard_insert_pbroadcast(operand, start_indices)
|
||||
operand, start_indices = core.standard_insert_pvary(operand, start_indices)
|
||||
return gather_p.bind(
|
||||
operand, start_indices, dimension_numbers=dimension_numbers,
|
||||
slice_sizes=core.canonicalize_shape(slice_sizes),
|
||||
@ -510,7 +510,7 @@ def scatter_add(
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.add,
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, updates = core.standard_insert_pvary(
|
||||
operand, scatter_indices, updates)
|
||||
return scatter_add_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
@ -566,7 +566,7 @@ def scatter_sub(
|
||||
jaxpr, consts = lax._reduction_jaxpr(
|
||||
lax.sub, core.get_aval(lax._const(operand, 0))
|
||||
)
|
||||
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, updates = core.standard_insert_pvary(
|
||||
operand, scatter_indices, updates)
|
||||
return scatter_sub_p.bind(
|
||||
operand,
|
||||
@ -622,7 +622,7 @@ def scatter_mul(
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul,
|
||||
core.get_aval(lax._const(operand, 1)))
|
||||
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, updates = core.standard_insert_pvary(
|
||||
operand, scatter_indices, updates)
|
||||
return scatter_mul_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
@ -671,7 +671,7 @@ def scatter_min(
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.min,
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, updates = core.standard_insert_pvary(
|
||||
operand, scatter_indices, updates)
|
||||
return scatter_min_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
@ -720,7 +720,7 @@ def scatter_max(
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.max,
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, updates = core.standard_insert_pvary(
|
||||
operand, scatter_indices, updates)
|
||||
return scatter_max_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
@ -786,7 +786,7 @@ def scatter_apply(
|
||||
pass
|
||||
jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand)))
|
||||
# TODO: implement this via its own primitive so we can define appropriate autodiff rules.
|
||||
operand, scatter_indices, unused = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, unused = core.standard_insert_pvary(
|
||||
operand, scatter_indices, unused)
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, unused, update_jaxpr=jaxpr,
|
||||
@ -871,7 +871,7 @@ def scatter(
|
||||
... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
|
||||
Array([0., 2., 3., 0., 4.], dtype=float32)
|
||||
"""
|
||||
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
|
||||
operand, scatter_indices, updates = core.standard_insert_pvary(
|
||||
operand, scatter_indices, updates)
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=None,
|
||||
|
@ -59,7 +59,7 @@ def _up_and_broadcast(doit):
|
||||
|
||||
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
|
||||
r"""Elementwise regularized incomplete beta integral."""
|
||||
a, b, x = core.standard_insert_pbroadcast(a, b, x)
|
||||
a, b, x = core.standard_insert_pvary(a, b, x)
|
||||
return regularized_incomplete_beta_p.bind(a, b, x)
|
||||
|
||||
def lgamma(x: ArrayLike) -> Array:
|
||||
@ -72,33 +72,33 @@ def digamma(x: ArrayLike) -> Array:
|
||||
|
||||
def polygamma(m: ArrayLike, x: ArrayLike) -> Array:
|
||||
r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`."""
|
||||
m, x = core.standard_insert_pbroadcast(m, x)
|
||||
m, x = core.standard_insert_pvary(m, x)
|
||||
return polygamma_p.bind(m, x)
|
||||
|
||||
def igamma(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
r"""Elementwise regularized incomplete gamma function."""
|
||||
a, x = core.standard_insert_pbroadcast(a, x)
|
||||
a, x = core.standard_insert_pvary(a, x)
|
||||
return igamma_p.bind(a, x)
|
||||
|
||||
def igammac(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
r"""Elementwise complementary regularized incomplete gamma function."""
|
||||
a, x = core.standard_insert_pbroadcast(a, x)
|
||||
a, x = core.standard_insert_pvary(a, x)
|
||||
return igammac_p.bind(a, x)
|
||||
|
||||
def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
r"""Elementwise derivative of the regularized incomplete gamma function."""
|
||||
a, x = core.standard_insert_pbroadcast(a, x)
|
||||
a, x = core.standard_insert_pvary(a, x)
|
||||
return igamma_grad_a_p.bind(a, x)
|
||||
|
||||
@_up_and_broadcast
|
||||
def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array:
|
||||
r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
|
||||
a, x = core.standard_insert_pbroadcast(a, x)
|
||||
a, x = core.standard_insert_pvary(a, x)
|
||||
return random_gamma_grad_impl(a, x, dtype=dtype)
|
||||
|
||||
def zeta(x: ArrayLike, q: ArrayLike) -> Array:
|
||||
r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`"""
|
||||
x, q = core.standard_insert_pbroadcast(x, q)
|
||||
x, q = core.standard_insert_pvary(x, q)
|
||||
return zeta_p.bind(x, q)
|
||||
|
||||
def bessel_i0e(x: ArrayLike) -> Array:
|
||||
|
@ -97,7 +97,7 @@ def _reduce_window(
|
||||
raise ValueError(
|
||||
'reduce_window output must have the same tree structure as the operands'
|
||||
f' {operand_tree} vs. {out_tree}')
|
||||
flat_operands = core.standard_insert_pbroadcast(*flat_operands)
|
||||
flat_operands = core.standard_insert_pvary(*flat_operands)
|
||||
out_flat = reduce_window_p.bind(
|
||||
*flat_operands,
|
||||
*flat_init_values,
|
||||
@ -251,7 +251,7 @@ def _select_and_scatter(operand: Array, select: Callable,
|
||||
select, core.get_aval(init_value))
|
||||
scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
|
||||
scatter, core.get_aval(init_value))
|
||||
operand, source, init_value = core.standard_insert_pbroadcast(
|
||||
operand, source, init_value = core.standard_insert_pvary(
|
||||
operand, source, init_value)
|
||||
return select_and_scatter_p.bind(
|
||||
operand, source, init_value, select_jaxpr=select_jaxpr,
|
||||
@ -264,7 +264,7 @@ def _select_and_scatter_add(source: Array, operand: Array,
|
||||
window_dimensions: core.Shape,
|
||||
window_strides: Sequence[int],
|
||||
padding: Sequence[tuple[int, int]]) -> Array:
|
||||
source, operand = core.standard_insert_pbroadcast(source, operand)
|
||||
source, operand = core.standard_insert_pvary(source, operand)
|
||||
return select_and_scatter_add_p.bind(
|
||||
source, operand, select_prim=select_prim,
|
||||
window_dimensions=tuple(window_dimensions),
|
||||
@ -300,7 +300,7 @@ def _select_and_gather_add(tangents: Array, operand: Array,
|
||||
An array containing the elements in `tangents` corresponding to the output
|
||||
of the reduction of `operand` fin each window.
|
||||
"""
|
||||
tangents, operand = core.standard_insert_pbroadcast(tangents, operand)
|
||||
tangents, operand = core.standard_insert_pvary(tangents, operand)
|
||||
return select_and_gather_add_p.bind(
|
||||
tangents, operand, select_prim=select_prim,
|
||||
window_dimensions=tuple(window_dimensions),
|
||||
|
@ -618,7 +618,7 @@ mlir.register_lowering(random_split_p, random_split_lowering)
|
||||
|
||||
def random_fold_in(keys, msgs):
|
||||
msgs = jnp.asarray(msgs)
|
||||
keys, msgs = core.standard_insert_pbroadcast(keys, msgs)
|
||||
keys, msgs = core.standard_insert_pvary(keys, msgs)
|
||||
return random_fold_in_p.bind(keys, msgs)
|
||||
|
||||
random_fold_in_p = core.Primitive('random_fold_in')
|
||||
|
Loading…
x
Reference in New Issue
Block a user