Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)
Merge pull request #13605 from gnecula:temp_2
PiperOrigin-RevId: 495270864
Commit: aaa70bcb3c
@@ -19,7 +19,7 @@ from typing import Union, Sequence
 import numpy as np

 from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
-from jax.core import Primitive
+from jax.core import Primitive, is_constant_shape
 from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax._src.util import prod
@@ -103,14 +103,15 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
   return x.update(shape=shape, dtype=dtype)


 def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
   out_aval, = ctx.avals_out
   return [
       mhlo.FftOp(x, mhlo.FftTypeAttr.get(fft_type.name),
                  mlir.dense_int_elements(fft_lengths)).result
   ]


 def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
   x_aval, = ctx.avals_in
   return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
                                  fft_lengths=fft_lengths)]
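The guard added to _fft_lowering_cpu relies on jax.core.is_constant_shape, which is true only when every dimension of a shape is a concrete integer rather than a symbolic dimension variable. A minimal sketch of the idea (the helper below is illustrative, not part of this change):

from jax.core import is_constant_shape

def all_shapes_static(avals) -> bool:
  # avals are abstract values such as ShapedArray; each exposes a .shape tuple.
  # Under shape polymorphism some entries of .shape are symbolic dimensions,
  # which makes is_constant_shape return False.
  return all(is_constant_shape(a.shape) for a in avals)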
@@ -32,7 +32,7 @@ from jax.interpreters import xla
 from jax.interpreters import ad
 from jax.interpreters import batching
 from jax._src.util import prod
-from jax.core import Primitive, ShapedArray, raise_to_shaped
+from jax.core import Primitive, ShapedArray, raise_to_shaped, is_constant_shape
 from jax._src.lax.lax import (
   standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
   _input_dtype)
@@ -423,6 +423,8 @@ def _cholesky_lowering(ctx, x):
 mlir.register_lowering(cholesky_p, _cholesky_lowering)

 def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (cholesky); b/261671778")
   operand_aval, = ctx.avals_in
   out_aval, = ctx.avals_out
   batch_dims = operand_aval.shape[:-2]
@@ -483,6 +485,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,

 def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
                       compute_right_eigenvectors):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (eig); b/261671778")
   operand_aval, = ctx.avals_in
   out_aval = ctx.avals_out[0]
   batch_dims = operand_aval.shape[:-2]
@@ -632,6 +636,9 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):

 def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
                            sort_eigenvalues):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (eigh); b/261671778")
+
   del sort_eigenvalues  # The CPU/GPU implementations always sort.
   operand_aval, = ctx.avals_in
   v_aval, w_aval = ctx.avals_out
@@ -1171,6 +1178,8 @@ def _lu_batching_rule(batched_args, batch_dims):
   return lu_p.bind(x), (0, 0, 0)

 def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (lu); b/261671778")
   operand_aval, = ctx.avals_in
   out_aval, pivot_aval, perm_aval = ctx.avals_out
   batch_dims = operand_aval.shape[:-2]
@@ -1315,6 +1324,8 @@ def _geqrf_translation_rule(ctx, avals_in, avals_out, operand):
   return xops.QrDecomposition(operand)

 def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (geqrf); b/261671778")
   a_aval, taus_aval = ctx.avals_out
   *batch_dims, m, n = a_aval.shape
   batch = prod(batch_dims)
@@ -1404,6 +1415,8 @@ def _householder_product_translation_rule(ctx, avals_in, avals_out, a, taus):
   return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]

 def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (householder product); b/261671778")
   a_aval, _ = ctx.avals_in
   *batch_dims, m, n = a_aval.shape
@@ -1614,6 +1627,8 @@ def _empty_svd(a, *, full_matrices, compute_uv):

 def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
                           compute_uv):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (svd); b/261671778")
   operand_aval, = ctx.avals_in
   s_aval = ctx.avals_out[0]
   m, n = operand_aval.shape[-2:]
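The same two-line check recurs in each of the CPU/GPU lowerings above (cholesky, eig, eigh, lu, geqrf, householder_product, svd). Factored out, it would look roughly like the sketch below; the helper name and signature are hypothetical, since the commit repeats the check inline at each call site:

from jax.core import is_constant_shape

def _check_static_shapes(ctx, op_name: str, bug: str) -> None:
  # ctx is the lowering-rule context; avals_in/avals_out hold the abstract
  # values whose shapes must be fully static before a custom call is emitted.
  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
    raise NotImplementedError(
        f"Shape polymorphism for custom call is not implemented ({op_name}); {bug}")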
@@ -2032,6 +2032,35 @@ def _flatten_harnesses(harnesses):
       res.append(h)
   return res

+# Set of harness.group_name:platform that are implemented with custom call
+custom_call_harnesses = {
+    "cholesky:cpu", "cholesky:gpu", "eig:cpu",
+    "eigh:cpu", "eigh:gpu", "fft:cpu",
+    "householder_product:cpu", "householder_product:gpu",
+    "geqrf:cpu", "geqrf:gpu", "lu:cpu", "lu:gpu", "qr:cpu", "qr:gpu",
+    "random_gamma:gpu", "random_categorical:gpu",
+    "random_randint:gpu", "random_uniform:gpu", "random_split:gpu",
+    "svd:cpu", "svd:gpu"}
+
+# Set of harness.group_name or harness.group_name:platform that are implemented with HLO fallback lowering rules
+fallback_lowering_harnesses = {
+    "approx_top_k:cpu", "bessel_i0e", "eigh:tpu",
+    "erf_inv", "igamma", "igammac", "lu",
+    "regularized_incomplete_beta", "qr:tpu",
+    "random_gamma:cpu", "random_gamma:tpu", "svd:tpu"}
+
+def _exclude_native_lowering_harnesses(harness: Harness):
+  if config.jax2tf_default_experimental_native_lowering and not harness.params.get("enable_xla", True):
+    raise unittest.SkipTest("disabled for experimental_native_lowering and enable_xla=False")
+  if config.jax2tf_default_experimental_native_lowering:
+    if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
+      raise unittest.SkipTest("native lowering with shape polymorphism not implemented for custom calls; b/261671778")
+  if (config.jax2tf_default_experimental_native_lowering and
+      (harness.group_name in fallback_lowering_harnesses or
+       f"{harness.group_name}:{jtu.device_under_test()}" in fallback_lowering_harnesses)):
+    raise unittest.SkipTest(
+        "native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")

 class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
   """Tests for primitives that take shape values as parameters."""
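The skip logic above keys each harness by its group name, optionally qualified with the platform it runs on. A small standalone illustration of that lookup (the function and parameter names here are illustrative only):

def uses_custom_call(group_name: str, platform: str, custom_call_harnesses) -> bool:
  # A harness is skipped when "group:platform" appears in the custom-call set.
  return f"{group_name}:{platform}" in custom_call_harnesses

# For example: uses_custom_call("cholesky", "cpu", {"cholesky:cpu", "svd:gpu"}) -> True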
@@ -2043,11 +2072,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
   # to parameterized below.
   @primitive_harness.parameterized(
       _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
-      #one_containing="roll_axis=None",
+      #one_containing="",
   )
   def test_prim(self, harness: Harness):
-    if config.jax2tf_default_experimental_native_lowering and not harness.params.get("enable_xla", True):
-      raise unittest.SkipTest("disabled for experimental_native_lowering and enable_xla=False")
+    _exclude_native_lowering_harnesses(harness)
     _test_one_harness(self, harness)

   def test_vmap_while(self):
@@ -2207,6 +2235,7 @@ class ShapePolyVmapPrimitivesTest(tf_test_util.JaxToTfTestCase):
       one_containing=""
   )
   def test_vmap_prim(self, harness: Harness):
+    _exclude_native_lowering_harnesses(harness)
     return _test_one_harness(self, harness)
@@ -1541,6 +1541,12 @@ def xla_fallback_lowering(prim: core.Primitive):
       axis_env = axis_ctx.unsafe_axis_env
     else:
       axis_env = module_ctx.axis_env
+
+    if any(hasattr(a, "shape") and
+           not core.is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+      raise NotImplementedError(
+          f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")
+
     xla_computation = xla.primitive_subcomputation(
         module_ctx.platform, axis_env, prim, ctx.avals_in,
         ctx.avals_out, **params)
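Unlike the custom-call guards, the fallback-lowering check first tests hasattr(a, "shape"): some abstract values passed to a fallback rule (for example token avals) carry no shape at all and must be skipped. A short sketch of that filter, assuming only jax.core:

from jax import core

def any_dynamic_shaped_aval(avals) -> bool:
  # Ignore avals without a shape (e.g. tokens); flag any aval whose shape
  # contains a symbolic (non-constant) dimension.
  return any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
             for a in avals)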
@@ -49,6 +49,8 @@ def _threefry2x32_lowering(prng, platform, keys, data):
           ir.IntegerType.get_unsigned(32)), keys[0].type
   typ = keys[0].type
   dims = ir.RankedTensorType(typ).shape
+  if any(d < 0 for d in dims):
+    raise NotImplementedError("Shape polymorphism for custom call is not implemented (threefry); b/261671778")

   for x in itertools.chain(keys, data):
     assert x.type == typ, (x.type, typ)
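The threefry check above works at the MLIR level rather than on avals: in the MLIR Python bindings a RankedTensorType reports dynamic dimensions as negative sizes, so any negative entry in .shape signals a polymorphic dimension. A minimal sketch, assuming jaxlib's MLIR bindings are importable as jaxlib.mlir.ir:

import jaxlib.mlir.ir as ir

def has_dynamic_dims(value: ir.Value) -> bool:
  # View the value's type as a ranked tensor and look for dynamic (negative)
  # dimension sizes.
  return any(d < 0 for d in ir.RankedTensorType(value.type).shape)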