# Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
# 383 lines · 14 KiB · Python
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pallas utility functions."""

from __future__ import annotations

from typing import overload

import jax
from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
import jax.numpy as jnp
import numpy as np

|
@overload
|
|
def cdiv(a: int, b: int) -> int:
|
|
...
|
|
|
|
@overload
|
|
def cdiv(a: int, b: jax.Array) -> jax.Array:
|
|
...
|
|
|
|
@overload
|
|
def cdiv(a: jax.Array, b: int) -> jax.Array:
|
|
...
|
|
|
|
@overload
|
|
def cdiv(a: jax.Array, b: jax.Array) -> jax.Array:
|
|
...
|
|
|
|
def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array:
|
|
if isinstance(a, int) and isinstance(b, int):
|
|
return (a + b - 1) // b
|
|
return lax.div(a + b - 1, b)
|
|
|
|
|
|
def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
|
|
size = np.prod(shape)
|
|
strides = []
|
|
for s in shape:
|
|
size = size // s
|
|
strides.append(int(size))
|
|
return tuple(strides)
|
|
|
|
|
|
def next_power_of_2(x: int) -> int:
|
|
"""Returns the next power of two greater than or equal to `x`."""
|
|
if x < 0:
|
|
raise ValueError("`next_power_of_2` requires a non-negative integer.")
|
|
return 1 if x == 0 else 2 ** (x - 1).bit_length()
|
|
|
|
def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int:
|
|
if jnp.issubdtype(dtype, jnp.integer):
|
|
return jnp.iinfo(dtype).bits
|
|
return np.dtype(dtype).itemsize * 8
|
|
|
|
def pattern_match_scan_to_fori_loop(
|
|
jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int
|
|
) -> tuple[jax_core.Jaxpr, bool]:
|
|
if num_carry > 0:
|
|
# Pattern match onto fori_loop:
|
|
# We expect the first carry argument to the jaxpr to be the loop index and
|
|
# for the loop index + 1 to be returned as the first value out of the loop.
|
|
in_index_var = jaxpr.invars[num_consts]
|
|
out_index_var = jaxpr.outvars[0]
|
|
# Check that the loop index argument is an int32 scalar
|
|
if (in_index_var.aval.shape or
|
|
in_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
|
|
raise NotImplementedError(
|
|
f"not a fori_loop index in: {in_index_var.aval} {jaxpr=}")
|
|
if (out_index_var.aval.shape or
|
|
out_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
|
|
raise NotImplementedError(
|
|
f"not a fori_loop index out: {out_index_var.aval} {jaxpr=}")
|
|
# Look for the equation that increments the loop index
|
|
for i, eqn in enumerate(jaxpr.eqns):
|
|
if eqn.primitive == lax.add_p:
|
|
if eqn.invars[0] == in_index_var:
|
|
if isinstance(eqn.invars[1], jax_core.Literal):
|
|
if eqn.invars[1].val == 1:
|
|
if eqn.outvars[0] == out_index_var:
|
|
eqn_index = i
|
|
break
|
|
else:
|
|
raise NotImplementedError("Unable to match fori_loop pattern")
|
|
# Delete the equation that increments and remove the loop index from the
|
|
# output. Incrementing the loop index will be done implicitly.
|
|
jaxpr = jaxpr.replace(
|
|
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
|
|
outvars=jaxpr.outvars[1:])
|
|
has_loop_index = True
|
|
else:
|
|
# If there's no carry, the loop index has been DCEd and the body does *not*
|
|
# expect a loop index as an argument.
|
|
has_loop_index = False
|
|
return jaxpr, has_loop_index
|
|
|
|
|
|
def pattern_match_while_to_fori_loop(
|
|
cond_jaxpr: jax_core.Jaxpr,
|
|
cond_nconsts: int,
|
|
body_jaxpr: jax_core.Jaxpr,
|
|
body_nconsts: int,
|
|
) -> tuple[jax_core.Jaxpr | None, str | None]:
|
|
# Try to pattern match to fori loop.
|
|
# Successful matches produce (jaxpr, None), while failures use the str
|
|
# component of the return tuple to capture information about the failure.
|
|
if cond_nconsts:
|
|
return (None, "Conditional jaxpr can't contain consts.")
|
|
_, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts])
|
|
cond_in_avals = [v.aval for v in cond_invars]
|
|
if len(cond_in_avals) < 2:
|
|
return (None, "Conditional jaxpr have only two carry args.")
|
|
# Check that the first two carry values are scalar ints
|
|
a1, a2 = cond_in_avals[:2]
|
|
if a1.shape or a1.dtype not in (jnp.int32, jnp.int64):
|
|
return (None, "First conditional jaxpr carry arg is not a scalar int.")
|
|
if a2.shape or a2.dtype not in (jnp.int32, jnp.int64):
|
|
return (None, "Second conditional jaxpr carry arg is not a scalar int.")
|
|
# Check that the only eqn in the cond checks the loop index condition
|
|
v1, v2 = cond_invars[:2]
|
|
outvar = cond_jaxpr.jaxpr.outvars[0]
|
|
assert outvar.aval.dtype == jnp.bool_
|
|
if len(cond_jaxpr.jaxpr.eqns) != 1:
|
|
return (None, "Non-trivial conditional jaxprs not supported.")
|
|
eqn = cond_jaxpr.jaxpr.eqns[0]
|
|
if eqn.primitive != lax.lt_p:
|
|
return (None, "Non-trivial conditional jaxprs not supported.")
|
|
if eqn.outvars != [outvar]:
|
|
return (None, "Non-trivial conditional jaxprs not supported.")
|
|
if eqn.invars != [v1, v2]:
|
|
return (None, "Non-trivial conditional jaxprs not supported.")
|
|
# Check that the carry is updated in the body appropriately
|
|
_, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts])
|
|
v1, v2 = body_invars[:2]
|
|
vo1, vo2 = body_jaxpr.jaxpr.outvars[:2]
|
|
# Upper bound should be constant
|
|
if v2 is not vo2:
|
|
return (None, "Loop upper bound is not constant.")
|
|
# Check that we increment the loop index in the body
|
|
for i, eqn in enumerate(body_jaxpr.jaxpr.eqns):
|
|
if eqn.primitive is lax.add_p:
|
|
if eqn.invars[0] is v1:
|
|
if isinstance(eqn.invars[1], jax_core.Literal):
|
|
if eqn.invars[1].val == 1:
|
|
if eqn.outvars[0] == vo1:
|
|
eqn_index = i
|
|
break
|
|
else:
|
|
return (None, "Loop index not incremented in body.")
|
|
jaxpr = body_jaxpr.jaxpr
|
|
new_invars = (
|
|
*jaxpr.invars[:body_nconsts],
|
|
jaxpr.invars[body_nconsts],
|
|
*jaxpr.invars[body_nconsts + 2 :],
|
|
)
|
|
new_outvars = tuple(jaxpr.outvars[2:])
|
|
jaxpr = jaxpr.replace(
|
|
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1 :],
|
|
invars=new_invars,
|
|
outvars=new_outvars,
|
|
)
|
|
return jaxpr, None
|
|
|
|
|
|
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
|
|
def _erf_inv_32_lowering_helper(x):
|
|
k_degree = 9
|
|
w_lt_5_constants = [
|
|
2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
|
|
-4.39150654e-06, 0.00021858087, -0.00125372503,
|
|
-0.00417768164, 0.246640727, 1.50140941,
|
|
]
|
|
w_gt_5_constants = [
|
|
-0.000200214257, 0.000100950558, 0.00134934322,
|
|
-0.00367342844, 0.00573950773, -0.0076224613,
|
|
0.00943887047, 1.00167406, 2.83297682,
|
|
]
|
|
|
|
w = -jnp.log1p(x * -x)
|
|
w_lt_5 = w < 5.0
|
|
|
|
w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0)
|
|
|
|
p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0])
|
|
for i in range(1, k_degree):
|
|
c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i])
|
|
p = c + p * w
|
|
|
|
return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
|
|
|
|
|
|
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L696-L802
|
|
def _erf_inv_64_lowering_helper(x):
|
|
w_lt_625_constants = [
|
|
-3.6444120640178196996e-21, -1.685059138182016589e-19,
|
|
1.2858480715256400167e-18, 1.115787767802518096e-17,
|
|
-1.333171662854620906e-16, 2.0972767875968561637e-17,
|
|
6.6376381343583238325e-15, -4.0545662729752068639e-14,
|
|
-8.1519341976054721522e-14, 2.6335093153082322977e-12,
|
|
-1.2975133253453532498e-11, -5.4154120542946279317e-11,
|
|
1.051212273321532285e-09, -4.1126339803469836976e-09,
|
|
-2.9070369957882005086e-08, 4.2347877827932403518e-07,
|
|
-1.3654692000834678645e-06, -1.3882523362786468719e-05,
|
|
0.0001867342080340571352, -0.00074070253416626697512,
|
|
-0.0060336708714301490533, 0.24015818242558961693,
|
|
1.6536545626831027356
|
|
]
|
|
|
|
w_lt_16_constants = [
|
|
2.2137376921775787049e-09, 9.0756561938885390979e-08,
|
|
-2.7517406297064545428e-07, 1.8239629214389227755e-08,
|
|
1.5027403968909827627e-06, -4.013867526981545969e-06,
|
|
2.9234449089955446044e-06, 1.2475304481671778723e-05,
|
|
-4.7318229009055733981e-05, 6.8284851459573175448e-05,
|
|
2.4031110387097893999e-05, -0.0003550375203628474796,
|
|
0.00095328937973738049703, -0.0016882755560235047313,
|
|
0.0024914420961078508066, -0.0037512085075692412107,
|
|
0.005370914553590063617, 1.0052589676941592334,
|
|
3.0838856104922207635,
|
|
]
|
|
|
|
w_gt_16_constants = [
|
|
-2.7109920616438573243e-11, -2.5556418169965252055e-10,
|
|
1.5076572693500548083e-09, -3.7894654401267369937e-09,
|
|
7.6157012080783393804e-09, -1.4960026627149240478e-08,
|
|
2.9147953450901080826e-08, -6.7711997758452339498e-08,
|
|
2.2900482228026654717e-07, -9.9298272942317002539e-07,
|
|
4.5260625972231537039e-06, -1.9681778105531670567e-05,
|
|
7.5995277030017761139e-05, -0.00021503011930044477347,
|
|
-0.00013871931833623122026, 1.0103004648645343977,
|
|
4.8499064014085844221,
|
|
] # should add "as jnp.float64 array"?
|
|
|
|
w = -jnp.log1p(x * -x)
|
|
w_lt_625 = w < 6.25
|
|
w_lt_16 = w < 16.0
|
|
|
|
def get_coefficient(i):
|
|
c = w_lt_625_constants[i]
|
|
if i < 19:
|
|
c = jnp.where(w_lt_625, c, w_lt_16_constants[i])
|
|
if i < 17:
|
|
c = jnp.where(w_lt_16, c, w_gt_16_constants[i])
|
|
return c
|
|
|
|
select2 = jnp.where(w_lt_16, 3.25, 5.0)
|
|
select2_result = jnp.sqrt(w) - select2
|
|
w = jnp.where(w_lt_625, w - 3.125, select2_result)
|
|
|
|
p = get_coefficient(0)
|
|
for i in range(1, 17):
|
|
p = get_coefficient(i) + p * w
|
|
for i in range(17, 19):
|
|
p = jnp.where(w_lt_16, get_coefficient(i) + p * w, p)
|
|
for i in range(19, 23):
|
|
p = jnp.where(w_lt_625, get_coefficient(i) + p * w, p)
|
|
|
|
return jnp.where(jnp.abs(x) == 1.0, np.inf * x, p * x)
|
|
|
|
|
|
def erf_inv_lowering_helper(x):
  """Dispatches inverse-erf lowering on the float width of `x`."""
  dtype = x.dtype
  if dtype == jnp.float32:
    return _erf_inv_32_lowering_helper(x)
  if dtype == jnp.float64:
    return _erf_inv_64_lowering_helper(x)
  raise NotImplementedError(f"erf_inv_lowering_helper not implemented for {x.dtype}")
|
def sign_lowering_helper(x):
  """Computes `sign(x)` for unsigned/signed integer and floating dtypes."""
  dtype = x.dtype

  # Unsigned: no negative values, so sign is just a 0/1 nonzero mask.
  if jnp.issubdtype(dtype, jnp.unsignedinteger):
    return (x != 0).astype(dtype)

  if jnp.issubdtype(dtype, jnp.integer):
    return (x > 0).astype(dtype) - (x < 0).astype(dtype)

  if jnp.issubdtype(dtype, jnp.floating):
    signum = (x > 0.).astype(dtype) - (x < 0.).astype(dtype)
    # NaN inputs propagate as NaN.
    return jnp.where(jnp.isnan(x), jnp.nan, signum)

  raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}")
|
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L1339-L1422
def nextafter_lowering_helper(x, y):
  """Emulates `nextafter(x, y)` by bit-twiddling the integer view of `x`.

  `x` and `y` must be float32 or float64 arrays of the same dtype (enforced
  below). Exploits the IEEE-754 layout: for floats of the same sign, the
  integer interpretation of the bits is monotonic in magnitude, so stepping
  toward `y` is an integer +-1 on `x`'s bits, with NaN/equal/zero handled as
  special cases.
  """
  if x.dtype != y.dtype:
    raise ValueError(
        "The two inputs to `nextafter` must have the same dtype, but got"
        f" {x.dtype} and {y.dtype}"
    )

  if x.dtype not in (jnp.float32, jnp.float64):
    raise ValueError(
        f"`nextafter` only supports float32 and float64, but got {x.dtype}"
    )

  # Pick the matching-width jnp/np float, unsigned, and signed integer types.
  jnp_float, jnp_uint, np_float, np_uint, np_int = (
      jnp.float32, jnp.uint32, np.float32, np.uint32, np.int32,
  ) if x.dtype == jnp.float32 else (
      jnp.float64, jnp.uint64, np.float64, np.uint64, np.int64,
  )

  bitwidth = dtype_bitwidth(x.dtype)

  # Reinterpret (not convert) the float bits as unsigned integers.
  x_as_int = x.view(jnp_uint)
  y_as_int = y.view(jnp_uint)

  # The result is NaN if either "x" or "y" are NaN.
  nan_input = jnp.isnan(x) | jnp.isnan(y)
  result_for_nan = jnp.full_like(x_as_int, np_float(np.nan).view(np_uint))

  # The sign bit is the MSB.
  sign_bit = jnp_uint(1 << (bitwidth - 1))
  # Discard the sign bit to make the result non-negative.
  sign_mask = sign_bit
  negated_sign_mask = ~sign_bit
  x_abs = x_as_int & negated_sign_mask
  y_abs = y_as_int & negated_sign_mask

  # When both "x" and "y" are equal, the result is "y".
  x_and_y_are_equal = x == y
  result_for_equal = y_as_int

  # When both "x" and "y" are 0, the result is "y". This is a separate case
  # from above because "x" and "y" might have a different sign.
  zero = jnp.zeros_like(x_as_int)
  x_is_zero = x_abs == zero
  y_is_zero = y_abs == zero
  result_for_both_zero = y_as_int

  x_sign = x_as_int & sign_mask
  y_sign = y_as_int & sign_mask

  # If x == 0 && y != 0, we need to return the smallest subnormal number
  # signed like "y".
  one = jnp.ones_like(x_as_int)
  result_for_x_zero_y_non_zero = y_sign | one

  # If the sign of "x" and "y" disagree:
  # - we need to make the magnitude of "from" smaller so that it is closer to
  # zero.
  #
  # Otherwise the signs agree:
  # - "x" with a magnitude larger than "y" means we need to make the magnitude
  # smaller.
  # - "x" with a magnitude smaller than "y" means we need to make the magnitude
  # larger.
  signs_disagree = x_sign != y_sign
  x_magnitude_larger_than_y = x_abs > y_abs
  result_has_smaller_magnitude = x_magnitude_larger_than_y | signs_disagree
  # -1 reinterpreted as unsigned is the all-ones pattern, so adding it is an
  # unsigned decrement (with wraparound, which the special cases avoid).
  minus_one = jnp.full_like(x_as_int, np_int(-1).view(np_uint))
  magnitude_adjustment = jnp.where(result_has_smaller_magnitude, minus_one, one)
  result = x_as_int + magnitude_adjustment

  # Handle x == +-0.
  result = jnp.where(
      x_is_zero,
      jnp.where(y_is_zero, result_for_both_zero, result_for_x_zero_y_non_zero),
      result,
  )

  # Handle x == y.
  result = jnp.where(x_and_y_are_equal, result_for_equal, result)

  # Handle isnan(x) || isnan(y).
  result = jnp.where(nan_input, result_for_nan, result)

  # Cast back to the original type.
  return result.view(jnp_float)