PiperOrigin-RevId: 652528759
Kevin Gleason 2024-07-15 10:37:14 -07:00 committed by jax authors
parent 26ec43f9e5
commit 5e897c61f5
2 changed files with 27 additions and 6 deletions
jax/experimental/jax2tf/tests
tests/filecheck

@@ -198,6 +198,12 @@ class Jax2TfLimitation(test_harnesses.Limitation):
             devices=("cpu", "gpu"),
             tol=1e-13,
             modes=("eager", "graph", "compiled")),
+        custom_numeric(
+            dtypes=[np.complex64],
+            devices=("tpu",),
+            tol=1e-3,
+            modes=("eager", "graph", "compiled"),
+            native_serialization=Jax2TfLimitation.FOR_NON_NATIVE),
     ]

   @classmethod
@@ -206,7 +212,17 @@ class Jax2TfLimitation(test_harnesses.Limitation):
         custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"),
                        tol=1e-3),
         custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12),
-        cls.helper_get_trig_custom_limitation(np.cosh)
+        Jax2TfLimitation(
+            "TF2XLA impl for Acosh doesn't properly handle large complex types,"
+            " native serialization more closely matches numpy numerics.",
+            dtypes=[np.complex64, np.complex128],
+            devices=("cpu", "gpu", "tpu"),
+            modes="compiled",
+            expect_tf_error=False,
+            skip_comparison=True,
+            native_serialization=Jax2TfLimitation.FOR_NON_NATIVE,
+        ),
+        cls.helper_get_trig_custom_limitation(np.cosh),
     ]

   @classmethod
@@ -281,6 +297,11 @@ class Jax2TfLimitation(test_harnesses.Limitation):
         custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"),
                        tol=1e-3),
         custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12),
+        custom_numeric(dtypes=[np.complex64, np.complex128],
+                       devices=("cpu", "gpu", "tpu"),
+                       modes=("compiled",),
+                       tol=1e-3,
+                       native_serialization=Jax2TfLimitation.FOR_NON_NATIVE),
         custom_numeric(dtypes=[np.complex128], devices=("cpu",),
                        modes=("eager", "compiled", "graph"),
                        tol=1e-13,
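The tol=1e-3 entries added above loosen the complex64 comparison toward NumPy's own numerics for non-native serialization. As a minimal standalone sketch of the kind of check these tolerances govern (the real comparison is driven through the Jax2TfLimitation/custom_numeric harness, not reproduced here):

import numpy as np
import jax.numpy as jnp

# Complex cosh can drift from NumPy by roughly 1e-3 for complex64 on some
# backends; that is the tolerance the new limitation entries encode.
x = np.array([1.0 + 2.0j, 3.0 - 4.0j], dtype=np.complex64)
np.testing.assert_allclose(np.asarray(jnp.cosh(x)), np.cosh(x),
                           rtol=1e-3, atol=1e-3)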

@@ -50,11 +50,11 @@ def main(_):
   aval1 = ShapedArray((2, 3), np.dtype(np.float32))
   aval2 = ShapedArray((3, 4), np.dtype(np.int64))

   # CHECK-LABEL: TEST: simple
-  # CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  # CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<2x3xf32>) -> tensor<3x4xi64>
   print_custom_call("simple", [aval1], [aval2])

   # CHECK-LABEL: TEST: sideeffect
-  # CHECK: stablehlo.custom_call @sideeffect(%arg0) {has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  # CHECK: stablehlo.custom_call @sideeffect(%arg0) {backend_config = "", has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
   print_custom_call("sideeffect", [aval1], [aval2], api_version=1,
                     has_side_effect=True)
@@ -64,17 +64,17 @@ def main(_):
                     backend_config=b"hello")

   # CHECK-LABEL: TEST: calledcomputations
-  # CHECK: stablehlo.custom_call @calledcomputations(%arg0) {called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  # CHECK: stablehlo.custom_call @calledcomputations(%arg0) {backend_config = "", called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
   print_custom_call("calledcomputations", [aval1], [aval2], api_version=1,
                     called_computations=["a", "b"])

   # CHECK-LABEL: TEST: aliases
-  # CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 1, operand_tuple_indices = []>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
+  # CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 1, operand_tuple_indices = []>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
   print_custom_call("aliases", [aval1, aval2], [aval2, aval1], api_version=1,
                     operand_output_aliases={1: 0})

   # CHECK-LABEL: TEST: layouts
-  # CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
+  # CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {backend_config = "", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
   print_custom_call("layouts", [aval1, aval2], [aval2, aval1], api_version=1,
                     operand_layouts=[[0, 1], [1, 0]],
                     result_layouts=[[1, 0], [0, 1]])
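The updated CHECK lines reflect that stablehlo.custom_call ops are now printed with an explicit backend_config attribute even when it is empty. One way to see such attributes outside the FileCheck harness is to lower a function through JAX's AOT API and inspect the StableHLO text. A hedged sketch: it assumes jax.jit(...).lower(...).as_text() (a stable public API) and that jnp.linalg.cholesky lowers to a LAPACK custom call on CPU, which is backend behavior this commit does not touch.

import jax
import jax.numpy as jnp
import numpy as np

# Lower a function whose CPU implementation typically goes through a
# custom call, then grep the printed StableHLO for the attribute dict.
a = np.eye(3, dtype=np.float32)
lowered = jax.jit(jnp.linalg.cholesky).lower(a)
for line in lowered.as_text().splitlines():
    if "stablehlo.custom_call" in line:
        print(line)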