Merge pull request #13605 from gnecula:temp_2

PiperOrigin-RevId: 495270864
This commit is contained in:
jax authors 2022-12-14 03:46:15 -08:00
commit aaa70bcb3c
5 changed files with 60 additions and 7 deletions

View File

@ -19,7 +19,7 @@ from typing import Union, Sequence
import numpy as np
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax.core import Primitive
from jax.core import Primitive, is_constant_shape
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.util import prod
@ -103,14 +103,15 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
return x.update(shape=shape, dtype=dtype)
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
out_aval, = ctx.avals_out
return [
mhlo.FftOp(x, mhlo.FftTypeAttr.get(fft_type.name),
mlir.dense_int_elements(fft_lengths)).result
mlir.dense_int_elements(fft_lengths)).result
]
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
x_aval, = ctx.avals_in
return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]

View File

@ -32,7 +32,7 @@ from jax.interpreters import xla
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.util import prod
from jax.core import Primitive, ShapedArray, raise_to_shaped
from jax.core import Primitive, ShapedArray, raise_to_shaped, is_constant_shape
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
@ -423,6 +423,8 @@ def _cholesky_lowering(ctx, x):
mlir.register_lowering(cholesky_p, _cholesky_lowering)
def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (cholesky); b/261671778")
operand_aval, = ctx.avals_in
out_aval, = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
@ -483,6 +485,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
compute_right_eigenvectors):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (eig); b/261671778")
operand_aval, = ctx.avals_in
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
@ -632,6 +636,9 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
sort_eigenvalues):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (eigh); b/261671778")
del sort_eigenvalues # The CPU/GPU implementations always sort.
operand_aval, = ctx.avals_in
v_aval, w_aval = ctx.avals_out
@ -1171,6 +1178,8 @@ def _lu_batching_rule(batched_args, batch_dims):
return lu_p.bind(x), (0, 0, 0)
def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (lu); b/261671778")
operand_aval, = ctx.avals_in
out_aval, pivot_aval, perm_aval = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
@ -1315,6 +1324,8 @@ def _geqrf_translation_rule(ctx, avals_in, avals_out, operand):
return xops.QrDecomposition(operand)
def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (geqrf); b/261671778")
a_aval, taus_aval = ctx.avals_out
*batch_dims, m, n = a_aval.shape
batch = prod(batch_dims)
@ -1404,6 +1415,8 @@ def _householder_product_translation_rule(ctx, avals_in, avals_out, a, taus):
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]
def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (householder product); b/261671778")
a_aval, _ = ctx.avals_in
*batch_dims, m, n = a_aval.shape
@ -1614,6 +1627,8 @@ def _empty_svd(a, *, full_matrices, compute_uv):
def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
compute_uv):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (svd); b/261671778")
operand_aval, = ctx.avals_in
s_aval = ctx.avals_out[0]
m, n = operand_aval.shape[-2:]

View File

@ -2032,6 +2032,35 @@ def _flatten_harnesses(harnesses):
res.append(h)
return res
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"cholesky:cpu", "cholesky:gpu", "eig:cpu",
"eigh:cpu", "eigh:gpu", "fft:cpu",
"householder_product:cpu", "householder_product:gpu",
"geqrf:cpu", "geqrf:gpu", "lu:cpu", "lu:gpu", "qr:cpu", "qr:gpu",
"random_gamma:gpu", "random_categorical:gpu",
"random_randint:gpu", "random_uniform:gpu", "random_split:gpu",
"svd:cpu", "svd:gpu"}
# Set of harness.group_name or harness.group_name:platform that are implemented with HLO fallback lowering rules
fallback_lowering_harnesses = {
"approx_top_k:cpu", "bessel_i0e", "eigh:tpu",
"erf_inv", "igamma", "igammac", "lu",
"regularized_incomplete_beta", "qr:tpu",
"random_gamma:cpu", "random_gamma:tpu", "svd:tpu"}
def _exclude_native_lowering_harnesses(harness: Harness):
if config.jax2tf_default_experimental_native_lowering and not harness.params.get("enable_xla", True):
raise unittest.SkipTest("disabled for experimental_native_lowering and enable_xla=False")
if config.jax2tf_default_experimental_native_lowering:
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native lowering with shape polymorphism not implemented for custom calls; b/261671778")
if (config.jax2tf_default_experimental_native_lowering and
(harness.group_name in fallback_lowering_harnesses or
f"{harness.group_name}:{jtu.device_under_test()}" in fallback_lowering_harnesses)):
raise unittest.SkipTest(
"native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"""Tests for primitives that take shape values as parameters."""
@ -2043,11 +2072,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# to parameterized below.
@primitive_harness.parameterized(
_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
#one_containing="roll_axis=None",
#one_containing="",
)
def test_prim(self, harness: Harness):
if config.jax2tf_default_experimental_native_lowering and not harness.params.get("enable_xla", True):
raise unittest.SkipTest("disabled for experimental_native_lowering and enable_xla=False")
_exclude_native_lowering_harnesses(harness)
_test_one_harness(self, harness)
def test_vmap_while(self):
@ -2207,6 +2235,7 @@ class ShapePolyVmapPrimitivesTest(tf_test_util.JaxToTfTestCase):
one_containing=""
)
def test_vmap_prim(self, harness: Harness):
_exclude_native_lowering_harnesses(harness)
return _test_one_harness(self, harness)

View File

@ -1541,6 +1541,12 @@ def xla_fallback_lowering(prim: core.Primitive):
axis_env = axis_ctx.unsafe_axis_env
else:
axis_env = module_ctx.axis_env
if any(hasattr(a, "shape") and
not core.is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError(
f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")
xla_computation = xla.primitive_subcomputation(
module_ctx.platform, axis_env, prim, ctx.avals_in,
ctx.avals_out, **params)

View File

@ -49,6 +49,8 @@ def _threefry2x32_lowering(prng, platform, keys, data):
ir.IntegerType.get_unsigned(32)), keys[0].type
typ = keys[0].type
dims = ir.RankedTensorType(typ).shape
if any(d < 0 for d in dims):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (threefry); b/261671778")
for x in itertools.chain(keys, data):
assert x.type == typ, (x.type, typ)