rocm_jax/jax/_src/pallas/utils.py
Sharad Vikram 02f4531310 [Pallas TPU] Add helpers for writing collectives
PiperOrigin-RevId: 723250661
2025-02-04 15:39:10 -08:00

383 lines
14 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas utility functions."""
from __future__ import annotations
from typing import overload
import jax
from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
import jax.numpy as jnp
import numpy as np
@overload
def cdiv(a: int, b: int) -> int:
...
@overload
def cdiv(a: int, b: jax.Array) -> jax.Array:
...
@overload
def cdiv(a: jax.Array, b: int) -> jax.Array:
...
@overload
def cdiv(a: jax.Array, b: jax.Array) -> jax.Array:
...
def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array:
if isinstance(a, int) and isinstance(b, int):
return (a + b - 1) // b
return lax.div(a + b - 1, b)
def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
  """Returns the row-major (C-order) element strides for `shape`.

  The stride of each dimension is the product of all dimensions after it, so
  the last dimension always has stride 1 and an empty shape yields `()`.

  Strides are accumulated from the innermost dimension outward rather than by
  dividing the total size. The division approach divides 0 by 0 for
  zero-sized dimensions and produces wrong strides (e.g. `(0, 3)` gave
  `(0, 0)` instead of `(3, 1)`).
  """
  strides = []
  running = 1
  for dim in reversed(shape):
    # Convert to a plain Python int, as the previous implementation did.
    strides.append(int(running))
    running *= dim
  return tuple(reversed(strides))
def next_power_of_2(x: int) -> int:
  """Returns the smallest power of two that is >= `x` (1 when `x` is 0).

  Raises:
    ValueError: if `x` is negative.
  """
  if x < 0:
    raise ValueError("`next_power_of_2` requires a non-negative integer.")
  if x == 0:
    return 1
  # (x - 1) has exactly as many significant bits as the exponent we need.
  return 1 << (x - 1).bit_length()
def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int:
if jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).bits
return np.dtype(dtype).itemsize * 8
def pattern_match_scan_to_fori_loop(
jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int
) -> tuple[jax_core.Jaxpr, bool]:
if num_carry > 0:
# Pattern match onto fori_loop:
# We expect the first carry argument to the jaxpr to be the loop index and
# for the loop index + 1 to be returned as the first value out of the loop.
in_index_var = jaxpr.invars[num_consts]
out_index_var = jaxpr.outvars[0]
# Check that the loop index argument is an int32 scalar
if (in_index_var.aval.shape or
in_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
raise NotImplementedError(
f"not a fori_loop index in: {in_index_var.aval} {jaxpr=}")
if (out_index_var.aval.shape or
out_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
raise NotImplementedError(
f"not a fori_loop index out: {out_index_var.aval} {jaxpr=}")
# Look for the equation that increments the loop index
for i, eqn in enumerate(jaxpr.eqns):
if eqn.primitive == lax.add_p:
if eqn.invars[0] == in_index_var:
if isinstance(eqn.invars[1], jax_core.Literal):
if eqn.invars[1].val == 1:
if eqn.outvars[0] == out_index_var:
eqn_index = i
break
else:
raise NotImplementedError("Unable to match fori_loop pattern")
# Delete the equation that increments and remove the loop index from the
# output. Incrementing the loop index will be done implicitly.
jaxpr = jaxpr.replace(
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
outvars=jaxpr.outvars[1:])
has_loop_index = True
else:
# If there's no carry, the loop index has been DCEd and the body does *not*
# expect a loop index as an argument.
has_loop_index = False
return jaxpr, has_loop_index
def pattern_match_while_to_fori_loop(
cond_jaxpr: jax_core.Jaxpr,
cond_nconsts: int,
body_jaxpr: jax_core.Jaxpr,
body_nconsts: int,
) -> tuple[jax_core.Jaxpr | None, str | None]:
# Try to pattern match to fori loop.
# Successful matches produce (jaxpr, None), while failures use the str
# component of the return tuple to capture information about the failure.
if cond_nconsts:
return (None, "Conditional jaxpr can't contain consts.")
_, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts])
cond_in_avals = [v.aval for v in cond_invars]
if len(cond_in_avals) < 2:
return (None, "Conditional jaxpr have only two carry args.")
# Check that the first two carry values are scalar ints
a1, a2 = cond_in_avals[:2]
if a1.shape or a1.dtype not in (jnp.int32, jnp.int64):
return (None, "First conditional jaxpr carry arg is not a scalar int.")
if a2.shape or a2.dtype not in (jnp.int32, jnp.int64):
return (None, "Second conditional jaxpr carry arg is not a scalar int.")
# Check that the only eqn in the cond checks the loop index condition
v1, v2 = cond_invars[:2]
outvar = cond_jaxpr.jaxpr.outvars[0]
assert outvar.aval.dtype == jnp.bool_
if len(cond_jaxpr.jaxpr.eqns) != 1:
return (None, "Non-trivial conditional jaxprs not supported.")
eqn = cond_jaxpr.jaxpr.eqns[0]
if eqn.primitive != lax.lt_p:
return (None, "Non-trivial conditional jaxprs not supported.")
if eqn.outvars != [outvar]:
return (None, "Non-trivial conditional jaxprs not supported.")
if eqn.invars != [v1, v2]:
return (None, "Non-trivial conditional jaxprs not supported.")
# Check that the carry is updated in the body appropriately
_, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts])
v1, v2 = body_invars[:2]
vo1, vo2 = body_jaxpr.jaxpr.outvars[:2]
# Upper bound should be constant
if v2 is not vo2:
return (None, "Loop upper bound is not constant.")
# Check that we increment the loop index in the body
for i, eqn in enumerate(body_jaxpr.jaxpr.eqns):
if eqn.primitive is lax.add_p:
if eqn.invars[0] is v1:
if isinstance(eqn.invars[1], jax_core.Literal):
if eqn.invars[1].val == 1:
if eqn.outvars[0] == vo1:
eqn_index = i
break
else:
return (None, "Loop index not incremented in body.")
jaxpr = body_jaxpr.jaxpr
new_invars = (
*jaxpr.invars[:body_nconsts],
jaxpr.invars[body_nconsts],
*jaxpr.invars[body_nconsts + 2 :],
)
new_outvars = tuple(jaxpr.outvars[2:])
jaxpr = jaxpr.replace(
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1 :],
invars=new_invars,
outvars=new_outvars,
)
return jaxpr, None
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
def _erf_inv_32_lowering_helper(x):
  """Approximates erfinv(x) elementwise for float32 `x`.

  Mirrors XLA's float32 ErfInv legalization. With w = -log1p(-x**2), a
  9-coefficient polynomial is evaluated by Horner's rule and then multiplied
  by `x`. The polynomial is in the variable (w - 2.5) when w < 5, and in
  (sqrt(w) - 3.0) otherwise. Inputs with |x| == 1 map to +/-inf explicitly.
  """
  # Horner coefficients, leading (highest-degree) coefficient first.
  small_w_coeffs = (
      2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
      -4.39150654e-06, 0.00021858087, -0.00125372503,
      -0.00417768164, 0.246640727, 1.50140941,
  )
  large_w_coeffs = (
      -0.000200214257, 0.000100950558, 0.00134934322,
      -0.00367342844, 0.00573950773, -0.0076224613,
      0.00943887047, 1.00167406, 2.83297682,
  )
  log_term = -jnp.log1p(x * -x)
  use_small = log_term < 5.0
  t = jnp.where(use_small, log_term - 2.5, jnp.sqrt(log_term) - 3.0)
  coeffs = [
      jnp.where(use_small, lo, hi)
      for lo, hi in zip(small_w_coeffs, large_w_coeffs)
  ]
  poly = coeffs[0]
  for c in coeffs[1:]:
    poly = c + poly * t
  return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, poly * x)
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L696-L802
def _erf_inv_64_lowering_helper(x):
  """Approximates the inverse error function elementwise for float64 `x`.

  Mirrors XLA's float64 ErfInv legalization (chlo_legalize_to_hlo.cc). With
  w = -log1p(-x**2), one of three polynomials is evaluated by Horner's rule
  and the result is multiplied by `x`:

  * w < 6.25:        23 coefficients, in the variable w - 3.125;
  * 6.25 <= w < 16:  19 coefficients, in the variable sqrt(w) - 3.25;
  * w >= 16:         17 coefficients, in the variable sqrt(w) - 5.0.

  Inputs with |x| == 1 map to +/-inf explicitly.
  """
  # Coefficient tables, leading (highest-degree) coefficient first.
  w_lt_625_constants = [
      -3.6444120640178196996e-21, -1.685059138182016589e-19,
      1.2858480715256400167e-18, 1.115787767802518096e-17,
      -1.333171662854620906e-16, 2.0972767875968561637e-17,
      6.6376381343583238325e-15, -4.0545662729752068639e-14,
      -8.1519341976054721522e-14, 2.6335093153082322977e-12,
      -1.2975133253453532498e-11, -5.4154120542946279317e-11,
      1.051212273321532285e-09, -4.1126339803469836976e-09,
      -2.9070369957882005086e-08, 4.2347877827932403518e-07,
      -1.3654692000834678645e-06, -1.3882523362786468719e-05,
      0.0001867342080340571352, -0.00074070253416626697512,
      -0.0060336708714301490533, 0.24015818242558961693,
      1.6536545626831027356
  ]
  w_lt_16_constants = [
      2.2137376921775787049e-09, 9.0756561938885390979e-08,
      -2.7517406297064545428e-07, 1.8239629214389227755e-08,
      1.5027403968909827627e-06, -4.013867526981545969e-06,
      2.9234449089955446044e-06, 1.2475304481671778723e-05,
      -4.7318229009055733981e-05, 6.8284851459573175448e-05,
      2.4031110387097893999e-05, -0.0003550375203628474796,
      0.00095328937973738049703, -0.0016882755560235047313,
      0.0024914420961078508066, -0.0037512085075692412107,
      0.005370914553590063617, 1.0052589676941592334,
      3.0838856104922207635,
  ]
  w_gt_16_constants = [
      -2.7109920616438573243e-11, -2.5556418169965252055e-10,
      1.5076572693500548083e-09, -3.7894654401267369937e-09,
      7.6157012080783393804e-09, -1.4960026627149240478e-08,
      2.9147953450901080826e-08, -6.7711997758452339498e-08,
      2.2900482228026654717e-07, -9.9298272942317002539e-07,
      4.5260625972231537039e-06, -1.9681778105531670567e-05,
      7.5995277030017761139e-05, -0.00021503011930044477347,
      -0.00013871931833623122026, 1.0103004648645343977,
      4.8499064014085844221,
  ]  # NOTE(review): these are plain Python floats. float64 inputs presumably
     # require x64 mode, where JAX's weak typing makes them float64 when
     # combined with float64 arrays; confirm no float32 rounding occurs.
  # -log1p(-x^2) == -log(1 - x^2); log1p keeps precision for small |x|.
  w = -jnp.log1p(x * -x)
  w_lt_625 = w < 6.25
  w_lt_16 = w < 16.0
  def get_coefficient(i):
    # Elementwise picks the i-th coefficient of the table for each region.
    # The w < 16 and w >= 16 tables only have 19 and 17 entries. For terms past
    # a table's length, the value computed for that region is discarded by the
    # jnp.where in the Horner loops below.
    c = w_lt_625_constants[i]
    if i < 19:
      c = jnp.where(w_lt_625, c, w_lt_16_constants[i])
    if i < 17:
      c = jnp.where(w_lt_16, c, w_gt_16_constants[i])
    return c
  select2 = jnp.where(w_lt_16, 3.25, 5.0)
  select2_result = jnp.sqrt(w) - select2
  # Per-region change of variable: w - 3.125, sqrt(w) - 3.25 or sqrt(w) - 5.0.
  w = jnp.where(w_lt_625, w - 3.125, select2_result)
  p = get_coefficient(0)
  # Terms 0..16 exist in all three tables.
  for i in range(1, 17):
    p = get_coefficient(i) + p * w
  # Terms 17..18 exist only for w < 16; elsewhere p is left unchanged.
  for i in range(17, 19):
    p = jnp.where(w_lt_16, get_coefficient(i) + p * w, p)
  # Terms 19..22 exist only for w < 6.25; elsewhere p is left unchanged.
  for i in range(19, 23):
    p = jnp.where(w_lt_625, get_coefficient(i) + p * w, p)
  # At |x| == 1, w = -log1p(-1) is infinite, so return +/-inf directly.
  return jnp.where(jnp.abs(x) == 1.0, np.inf * x, p * x)
def erf_inv_lowering_helper(x):
  """Computes erfinv(x) elementwise by dispatching on `x.dtype`.

  Only float32 and float64 are supported.

  Raises:
    NotImplementedError: for any other dtype.
  """
  if x.dtype == jnp.float32:
    impl = _erf_inv_32_lowering_helper
  elif x.dtype == jnp.float64:
    impl = _erf_inv_64_lowering_helper
  else:
    raise NotImplementedError(f"erf_inv_lowering_helper not implemented for {x.dtype}")
  return impl(x)
def sign_lowering_helper(x):
  """Computes sign(x) elementwise using comparisons, casts and selects.

  * Unsigned integers: 0 for zero, 1 otherwise.
  * Signed integers: -1, 0 or 1.
  * Floats: -1.0 or 1.0 for non-zero values. Zeros are returned unchanged,
    so -0.0 stays -0.0 as in XLA's `Sign`. NaN inputs give NaN.

  Raises:
    NotImplementedError: for any other dtype (e.g. bool or complex).
  """
  if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
    return (x != 0).astype(x.dtype)
  if jnp.issubdtype(x.dtype, jnp.integer):
    return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
  if jnp.issubdtype(x.dtype, jnp.floating):
    out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
    # The comparisons above map both +0.0 and -0.0 to +0.0. Pass zeros
    # through so that the sign of zero is preserved.
    out = jnp.where(x == 0., x, out)
    return jnp.where(jnp.isnan(x), jnp.nan, out)
  raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}")
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L1339-L1422
def nextafter_lowering_helper(x, y):
  """Returns the next representable float after `x` in the direction of `y`.

  Works elementwise on the unsigned-integer bit patterns of `x` and `y`, using
  bitcasts, integer arithmetic and selects:

  * NaN in either input gives NaN.
  * `x == y` gives `y`.
  * A zero `x` gives `y` if `y` is also zero. Otherwise it gives the smallest
    subnormal with the sign of `y`.
  * Otherwise the bit pattern of `x` is stepped by -1 (towards zero) when the
    signs differ or |x| > |y|, and by +1 otherwise.

  Raises:
    ValueError: if `x` and `y` have different dtypes, or the dtype is neither
      float32 nor float64.
  """
  if x.dtype != y.dtype:
    raise ValueError(
        "The two inputs to `nextafter` must have the same dtype, but got"
        f" {x.dtype} and {y.dtype}"
    )
  if x.dtype not in (jnp.float32, jnp.float64):
    raise ValueError(
        f"`nextafter` only supports float32 and float64, but got {x.dtype}"
    )
  if x.dtype == jnp.float32:
    float_t, uint_t = jnp.float32, jnp.uint32
    np_float_t, np_uint_t, np_int_t = np.float32, np.uint32, np.int32
  else:
    float_t, uint_t = jnp.float64, jnp.uint64
    np_float_t, np_uint_t, np_int_t = np.float64, np.uint64, np.int64
  nbits = np.dtype(np_float_t).itemsize * 8

  x_bits = x.view(uint_t)
  y_bits = y.view(uint_t)

  # The sign bit is the MSB; masking it off leaves the magnitude bits.
  sign_bit = uint_t(1 << (nbits - 1))
  magnitude_mask = ~sign_bit
  x_mag = x_bits & magnitude_mask
  y_mag = y_bits & magnitude_mask
  x_sgn = x_bits & sign_bit
  y_sgn = y_bits & sign_bit

  zero = jnp.zeros_like(x_bits)
  one = jnp.ones_like(x_bits)
  all_ones = jnp.full_like(x_bits, np_int_t(-1).view(np_uint_t))

  # Moving towards zero (smaller magnitude) is needed when the signs disagree
  # or |x| > |y|; adding all-ones bits is a wrapping decrement by one.
  toward_zero = (x_mag > y_mag) | (x_sgn != y_sgn)
  stepped = x_bits + jnp.where(toward_zero, all_ones, one)

  # From +/-0: return `y` if it is also zero, else the smallest subnormal with
  # the sign of `y`.
  from_zero = jnp.where(y_mag == zero, y_bits, y_sgn | one)
  out = jnp.where(x_mag == zero, from_zero, stepped)
  out = jnp.where(x == y, y_bits, out)

  nan_bits = jnp.full_like(x_bits, np_float_t(np.nan).view(np_uint_t))
  out = jnp.where(jnp.isnan(x) | jnp.isnan(y), nan_bits, out)
  return out.view(float_t)