mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Integrate StableHLO at openxla/stablehlo@8817ff1d
PiperOrigin-RevId: 652528759
This commit is contained in:
parent
26ec43f9e5
commit
5e897c61f5
@ -198,6 +198,12 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
devices=("cpu", "gpu"),
|
||||
tol=1e-13,
|
||||
modes=("eager", "graph", "compiled")),
|
||||
custom_numeric(
|
||||
dtypes=[np.complex64],
|
||||
devices=("tpu",),
|
||||
tol=1e-3,
|
||||
modes=("eager", "graph", "compiled"),
|
||||
native_serialization=Jax2TfLimitation.FOR_NON_NATIVE),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@ -206,7 +212,17 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"),
|
||||
tol=1e-3),
|
||||
custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12),
|
||||
cls.helper_get_trig_custom_limitation(np.cosh)
|
||||
Jax2TfLimitation(
|
||||
"TF2XLA impl for Acosh doesn't properly handle large complex types,"
|
||||
" native serialization more closely matches numpy numerics.",
|
||||
dtypes=[np.complex64, np.complex128],
|
||||
devices=("cpu", "gpu", "tpu"),
|
||||
modes="compiled",
|
||||
expect_tf_error=False,
|
||||
skip_comparison=True,
|
||||
native_serialization=Jax2TfLimitation.FOR_NON_NATIVE,
|
||||
),
|
||||
cls.helper_get_trig_custom_limitation(np.cosh),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@ -281,6 +297,11 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"),
|
||||
tol=1e-3),
|
||||
custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12),
|
||||
custom_numeric(dtypes=[np.complex64, np.complex128],
|
||||
devices=("cpu", "gpu", "tpu"),
|
||||
modes=("compiled",),
|
||||
tol=1e-3,
|
||||
native_serialization=Jax2TfLimitation.FOR_NON_NATIVE),
|
||||
custom_numeric(dtypes=[np.complex128], devices=("cpu",),
|
||||
modes=("eager", "compiled", "graph"),
|
||||
tol=1e-13,
|
||||
|
@ -50,11 +50,11 @@ def main(_):
|
||||
aval1 = ShapedArray((2, 3), np.dtype(np.float32))
|
||||
aval2 = ShapedArray((3, 4), np.dtype(np.int64))
|
||||
# CHECK-LABEL: TEST: simple
|
||||
# CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
# CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("simple", [aval1], [aval2])
|
||||
|
||||
# CHECK-LABEL: TEST: sideeffect
|
||||
# CHECK: stablehlo.custom_call @sideeffect(%arg0) {has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
# CHECK: stablehlo.custom_call @sideeffect(%arg0) {backend_config = "", has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("sideeffect", [aval1], [aval2], api_version=1,
|
||||
has_side_effect=True)
|
||||
|
||||
@ -64,17 +64,17 @@ def main(_):
|
||||
backend_config=b"hello")
|
||||
|
||||
# CHECK-LABEL: TEST: calledcomputations
|
||||
# CHECK: stablehlo.custom_call @calledcomputations(%arg0) {called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
# CHECK: stablehlo.custom_call @calledcomputations(%arg0) {backend_config = "", called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("calledcomputations", [aval1], [aval2], api_version=1,
|
||||
called_computations=["a", "b"])
|
||||
|
||||
# CHECK-LABEL: TEST: aliases
|
||||
# CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 1, operand_tuple_indices = []>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
|
||||
# CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 1, operand_tuple_indices = []>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
|
||||
print_custom_call("aliases", [aval1, aval2], [aval2, aval1], api_version=1,
|
||||
operand_output_aliases={1: 0})
|
||||
|
||||
# CHECK-LABEL: TEST: layouts
|
||||
# CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
|
||||
# CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {backend_config = "", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
|
||||
print_custom_call("layouts", [aval1, aval2], [aval2, aval1], api_version=1,
|
||||
operand_layouts=[[0, 1], [1, 0]],
|
||||
result_layouts=[[1, 0], [0, 1]])
|
||||
|
Loading…
x
Reference in New Issue
Block a user