[NFC] Rename standard_insert_pbroadcast to standard_insert_pvary

PiperOrigin-RevId: 747943230
This commit is contained in:
Yash Katariya 2025-04-15 11:01:49 -07:00 committed by jax authors
parent c527ddb7bf
commit 6e00b5e02d
11 changed files with 69 additions and 69 deletions

View File

@ -31,7 +31,7 @@ T = TypeVar('T')
map = safe_map
def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return add_jaxvals_p.bind(x, y)
add_jaxvals_p = Primitive('add_any')

View File

@ -2026,7 +2026,7 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups):
pvary_p.def_abstract_eval(_pvary_abstract_eval)
def standard_insert_pbroadcast(*args):
def standard_insert_pvary(*args):
if not config.varying_axes_in_types.value:
return args
if not config._check_rep.value:

View File

@ -499,7 +499,7 @@ def ffi_call(
"and an output with a different layout "
f"{static_output_layouts[o_idx]}.")
static_input_output_aliases += ((i_idx, o_idx),)
args = core.standard_insert_pbroadcast(*args)
args = core.standard_insert_pvary(*args)
results = ffi_call_p.bind(
*args,
result_avals=result_avals,

View File

@ -310,7 +310,7 @@ def custom_linear_solve(
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
args = _flatten(all_consts) + b_flat
args = core.standard_insert_pbroadcast(*args)
args = core.standard_insert_pvary(*args)
out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs)
return tree_unflatten(out_tree, out_flat)

View File

@ -158,7 +158,7 @@ def conv_general_dilated(
preferred_element_type = (
None if preferred_element_type is None else
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs)
lhs, rhs = core.standard_insert_pvary(lhs, rhs)
return conv_general_dilated_p.bind(
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),

View File

@ -369,7 +369,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
"""
x1, x2 = core.standard_insert_pbroadcast(x1, x2)
x1, x2 = core.standard_insert_pvary(x1, x2)
return nextafter_p.bind(x1, x2)
@export
@ -775,7 +775,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return atan2_p.bind(x, y)
@export
@ -845,7 +845,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return complex_p.bind(x, y)
@export
@ -917,7 +917,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return pow_p.bind(x, y)
@export
@ -1072,7 +1072,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.and: https://openxla.org/stablehlo/spec#and
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return and_p.bind(x, y)
@export
@ -1099,7 +1099,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.or: https://openxla.org/stablehlo/spec#or
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return or_p.bind(x, y)
@export
@ -1126,7 +1126,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return xor_p.bind(x, y)
@export
@ -1191,7 +1191,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.add: https://openxla.org/stablehlo/spec#add
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return add_p.bind(x, y)
@export
@ -1215,7 +1215,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return sub_p.bind(x, y)
@export
@ -1239,7 +1239,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return mul_p.bind(x, y)
@export
@ -1269,7 +1269,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return div_p.bind(x, y)
@export
@ -1297,7 +1297,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return rem_p.bind(x, y)
@export
@ -1323,7 +1323,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return max_p.bind(x, y)
@export
@ -1349,7 +1349,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return min_p.bind(x, y)
@export
@ -1375,7 +1375,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return shift_left_p.bind(x, y)
@export
@ -1402,7 +1402,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return shift_right_arithmetic_p.bind(x, y)
@export
@ -1429,7 +1429,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return shift_right_logical_p.bind(x, y)
@export
@ -1460,7 +1460,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return eq_p.bind(x, y)
@export
@ -1491,7 +1491,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return ne_p.bind(x, y)
@export
@ -1522,7 +1522,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return ge_p.bind(x, y)
@export
@ -1553,7 +1553,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return gt_p.bind(x, y)
@export
@ -1584,7 +1584,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return le_p.bind(x, y)
@export
@ -1615,7 +1615,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
x, y = core.standard_insert_pbroadcast(x, y)
x, y = core.standard_insert_pvary(x, y)
return lt_p.bind(x, y)
@export
@ -1771,7 +1771,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
x & \text{otherwise}
\end{cases}`.
"""
min, x, max = core.standard_insert_pbroadcast(min, x, max)
min, x, max = core.standard_insert_pvary(min, x, max)
return clamp_p.bind(min, x, max)
@ -1878,7 +1878,7 @@ def composite(
closed_jaxpr, out_tree = _trace_composite_to_jaxpr(
partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info
)
flat_args = core.standard_insert_pbroadcast(*flat_args)
flat_args = core.standard_insert_pvary(*flat_args)
out_flat = composite_p.bind(
*flat_args,
name=name,
@ -1996,7 +1996,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
op, = operands
if isinstance(op, Array):
return op
operands = core.standard_insert_pbroadcast(*operands)
operands = core.standard_insert_pvary(*operands)
return concatenate_p.bind(*operands, dimension=dimension)
@ -2520,7 +2520,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
preferred_element_type = (
None if preferred_element_type is None else
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs)
lhs, rhs = core.standard_insert_pvary(lhs, rhs)
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(cdims, bdims),
precision=canonicalize_precision(precision),
@ -2656,7 +2656,7 @@ def ragged_dot_general(
extra leading dimension of size `g` in the case where the lhs ragged
dimension is a contracting dimension.
"""
lhs, rhs, group_sizes = core.standard_insert_pbroadcast(lhs, rhs, group_sizes)
lhs, rhs, group_sizes = core.standard_insert_pvary(lhs, rhs, group_sizes)
return ragged_dot_general_p.bind(
lhs,
rhs,
@ -2840,7 +2840,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike,
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
"""
operand, padding_value = core.standard_insert_pbroadcast(operand, padding_value)
operand, padding_value = core.standard_insert_pvary(operand, padding_value)
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
@ -2873,7 +2873,7 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
"""
# Caution! The select_n_p primitive has the *opposite* order of arguments to
# select(). This is because it implements `select_n`.
pred, on_false, on_true = core.standard_insert_pbroadcast(
pred, on_false, on_true = core.standard_insert_pvary(
pred, on_false, on_true)
return select_n_p.bind(pred, on_false, on_true)
@ -2900,7 +2900,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array:
"""
if len(cases) == 0:
raise ValueError("select_n() must have at least one case")
which, *cases = core.standard_insert_pbroadcast(which, *cases)
which, *cases = core.standard_insert_pvary(which, *cases)
return select_n_p.bind(which, *cases)
@ -3262,7 +3262,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1,
if not (1 <= num_keys <= len(operand)):
raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}")
dimension = canonicalize_axis(dimension, len(operand[0].shape))
operand = core.standard_insert_pbroadcast(*operand)
operand = core.standard_insert_pvary(*operand)
return tuple(sort_p.bind(*operand, dimension=dimension,
is_stable=is_stable,
num_keys=num_keys))
@ -8111,7 +8111,7 @@ def after_all(*operands):
"""Merges one or more XLA token values. Experimental.
Wraps the XLA AfterAll operator."""
operands = core.standard_insert_pbroadcast(*operands)
operands = core.standard_insert_pvary(*operands)
return after_all_p.bind(*operands)
def _after_all_abstract_eval(*operands):
@ -8246,7 +8246,7 @@ def rng_uniform(a, b, shape):
This API may be removed at any time.
"""
a, b = core.standard_insert_pbroadcast(a, b)
a, b = core.standard_insert_pvary(a, b)
return rng_uniform_p.bind(a, b, shape=tuple(shape))
def _rng_uniform_abstract_eval(a, b, *, shape):
@ -8930,7 +8930,7 @@ def optimization_barrier(operand, /):
"""
flat_args, treedef = tree_util.tree_flatten(operand)
# TODO(yashkatariya): Enable this
# flat_args = core.standard_insert_pbroadcast(flat_args)
# flat_args = core.standard_insert_pvary(flat_args)
out = optimization_barrier_p.bind(*flat_args)
return tree_util.tree_unflatten(treedef, out)

View File

@ -121,7 +121,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
A new upper-triangular matrix :math:`R` defining the Cholesky decomposition
of :math:`A + w \, w^T`.
"""
r_matrix, w_vector = core.standard_insert_pbroadcast(r_matrix, w_vector)
r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector)
return cholesky_update_p.bind(r_matrix, w_vector)
@ -269,7 +269,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors.
"""
a, taus = core.standard_insert_pbroadcast(a, taus)
a, taus = core.standard_insert_pvary(a, taus)
return householder_product_p.bind(a, taus)
@ -547,7 +547,7 @@ def symmetric_product(
``symmetrize_output`` is ``True``, the upper triangle is filled with the
transpose of the lower triangle, and the whole matrix is valid.
"""
a_matrix, c_matrix = core.standard_insert_pbroadcast(a_matrix, c_matrix)
a_matrix, c_matrix = core.standard_insert_pvary(a_matrix, c_matrix)
result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
if symmetrize_output:
upper_half = lax.transpose(
@ -605,7 +605,7 @@ def triangular_solve(
singleton = np.ndim(b) == np.ndim(a) - 1
if singleton:
b = lax.expand_dims(b, (-1 if left_side else -2,))
a, b = core.standard_insert_pbroadcast(a, b)
a, b = core.standard_insert_pvary(a, b)
out = triangular_solve_p.bind(
a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
@ -665,7 +665,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
Returns:
Solution ``X`` of tridiagonal system.
"""
dl, d, du, b = core.standard_insert_pbroadcast(dl, d, du, b)
dl, d, du, b = core.standard_insert_pvary(dl, d, du, b)
return tridiagonal_solve_p.bind(dl, d, du, b)
@ -1658,7 +1658,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size):
if m == 0 or k == 0:
return permutation
upper = np.array(k, np.int32) if is_constant_dim(k) else k
permutation, swaps = core.standard_insert_pbroadcast(permutation, swaps)
permutation, swaps = core.standard_insert_pvary(permutation, swaps)
result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn,
(permutation, swaps))
return result
@ -1774,7 +1774,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *,
elementary Householder reflectors, and ``jpvt`` is the column-pivot indices
such that ``a[:, jpvt] = q @ r``.
"""
a, jpvt = core.standard_insert_pbroadcast(a, jpvt)
a, jpvt = core.standard_insert_pvary(a, jpvt)
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma)
return a_out, jpvt_out, taus

View File

@ -173,7 +173,7 @@ def dynamic_slice(
else:
dynamic_sizes = []
static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore
operand, *start_indices = core.standard_insert_pbroadcast(
operand, *start_indices = core.standard_insert_pvary(
operand, *start_indices)
return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
slice_sizes=tuple(static_sizes))
@ -236,7 +236,7 @@ def dynamic_update_slice(
"""
start_indices = _dynamic_slice_indices(
operand, start_indices, allow_negative_indices)
operand, update, *start_indices = core.standard_insert_pbroadcast(
operand, update, *start_indices = core.standard_insert_pvary(
operand, update, *start_indices)
return dynamic_update_slice_p.bind(operand, update, *start_indices)
@ -420,7 +420,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
raise ValueError(f"Unsupported dtype for gather fill_value {dtype}")
else:
fill_value = None
operand, start_indices = core.standard_insert_pbroadcast(operand, start_indices)
operand, start_indices = core.standard_insert_pvary(operand, start_indices)
return gather_p.bind(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=core.canonicalize_shape(slice_sizes),
@ -510,7 +510,7 @@ def scatter_add(
"""
jaxpr, consts = lax._reduction_jaxpr(lax.add,
core.get_aval(lax._const(operand, 0)))
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
operand, scatter_indices, updates = core.standard_insert_pvary(
operand, scatter_indices, updates)
return scatter_add_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
@ -566,7 +566,7 @@ def scatter_sub(
jaxpr, consts = lax._reduction_jaxpr(
lax.sub, core.get_aval(lax._const(operand, 0))
)
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
operand, scatter_indices, updates = core.standard_insert_pvary(
operand, scatter_indices, updates)
return scatter_sub_p.bind(
operand,
@ -622,7 +622,7 @@ def scatter_mul(
"""
jaxpr, consts = lax._reduction_jaxpr(lax.mul,
core.get_aval(lax._const(operand, 1)))
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
operand, scatter_indices, updates = core.standard_insert_pvary(
operand, scatter_indices, updates)
return scatter_mul_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
@ -671,7 +671,7 @@ def scatter_min(
"""
jaxpr, consts = lax._reduction_jaxpr(lax.min,
core.get_aval(lax._const(operand, 0)))
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
operand, scatter_indices, updates = core.standard_insert_pvary(
operand, scatter_indices, updates)
return scatter_min_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
@ -720,7 +720,7 @@ def scatter_max(
"""
jaxpr, consts = lax._reduction_jaxpr(lax.max,
core.get_aval(lax._const(operand, 0)))
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
operand, scatter_indices, updates = core.standard_insert_pvary(
operand, scatter_indices, updates)
return scatter_max_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
@ -786,7 +786,7 @@ def scatter_apply(
pass
jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand)))
# TODO: implement this via its own primitive so we can define appropriate autodiff rules.
operand, scatter_indices, unused = core.standard_insert_pbroadcast(
operand, scatter_indices, unused = core.standard_insert_pvary(
operand, scatter_indices, unused)
return scatter_p.bind(
operand, scatter_indices, unused, update_jaxpr=jaxpr,
@ -871,7 +871,7 @@ def scatter(
... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
Array([0., 2., 3., 0., 4.], dtype=float32)
"""
operand, scatter_indices, updates = core.standard_insert_pbroadcast(
operand, scatter_indices, updates = core.standard_insert_pvary(
operand, scatter_indices, updates)
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=None,

View File

@ -59,7 +59,7 @@ def _up_and_broadcast(doit):
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise regularized incomplete beta integral."""
a, b, x = core.standard_insert_pbroadcast(a, b, x)
a, b, x = core.standard_insert_pvary(a, b, x)
return regularized_incomplete_beta_p.bind(a, b, x)
def lgamma(x: ArrayLike) -> Array:
@ -72,33 +72,33 @@ def digamma(x: ArrayLike) -> Array:
def polygamma(m: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`."""
m, x = core.standard_insert_pbroadcast(m, x)
m, x = core.standard_insert_pvary(m, x)
return polygamma_p.bind(m, x)
def igamma(a: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise regularized incomplete gamma function."""
a, x = core.standard_insert_pbroadcast(a, x)
a, x = core.standard_insert_pvary(a, x)
return igamma_p.bind(a, x)
def igammac(a: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise complementary regularized incomplete gamma function."""
a, x = core.standard_insert_pbroadcast(a, x)
a, x = core.standard_insert_pvary(a, x)
return igammac_p.bind(a, x)
def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise derivative of the regularized incomplete gamma function."""
a, x = core.standard_insert_pbroadcast(a, x)
a, x = core.standard_insert_pvary(a, x)
return igamma_grad_a_p.bind(a, x)
@_up_and_broadcast
def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array:
r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
a, x = core.standard_insert_pbroadcast(a, x)
a, x = core.standard_insert_pvary(a, x)
return random_gamma_grad_impl(a, x, dtype=dtype)
def zeta(x: ArrayLike, q: ArrayLike) -> Array:
r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`"""
x, q = core.standard_insert_pbroadcast(x, q)
x, q = core.standard_insert_pvary(x, q)
return zeta_p.bind(x, q)
def bessel_i0e(x: ArrayLike) -> Array:

View File

@ -97,7 +97,7 @@ def _reduce_window(
raise ValueError(
'reduce_window output must have the same tree structure as the operands'
f' {operand_tree} vs. {out_tree}')
flat_operands = core.standard_insert_pbroadcast(*flat_operands)
flat_operands = core.standard_insert_pvary(*flat_operands)
out_flat = reduce_window_p.bind(
*flat_operands,
*flat_init_values,
@ -251,7 +251,7 @@ def _select_and_scatter(operand: Array, select: Callable,
select, core.get_aval(init_value))
scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
scatter, core.get_aval(init_value))
operand, source, init_value = core.standard_insert_pbroadcast(
operand, source, init_value = core.standard_insert_pvary(
operand, source, init_value)
return select_and_scatter_p.bind(
operand, source, init_value, select_jaxpr=select_jaxpr,
@ -264,7 +264,7 @@ def _select_and_scatter_add(source: Array, operand: Array,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]]) -> Array:
source, operand = core.standard_insert_pbroadcast(source, operand)
source, operand = core.standard_insert_pvary(source, operand)
return select_and_scatter_add_p.bind(
source, operand, select_prim=select_prim,
window_dimensions=tuple(window_dimensions),
@ -300,7 +300,7 @@ def _select_and_gather_add(tangents: Array, operand: Array,
An array containing the elements in `tangents` corresponding to the output
of the reduction of `operand` fin each window.
"""
tangents, operand = core.standard_insert_pbroadcast(tangents, operand)
tangents, operand = core.standard_insert_pvary(tangents, operand)
return select_and_gather_add_p.bind(
tangents, operand, select_prim=select_prim,
window_dimensions=tuple(window_dimensions),

View File

@ -618,7 +618,7 @@ mlir.register_lowering(random_split_p, random_split_lowering)
def random_fold_in(keys, msgs):
msgs = jnp.asarray(msgs)
keys, msgs = core.standard_insert_pbroadcast(keys, msgs)
keys, msgs = core.standard_insert_pvary(keys, msgs)
return random_fold_in_p.bind(keys, msgs)
random_fold_in_p = core.Primitive('random_fold_in')