# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Sequence
import contextlib
import math
from typing import Any
import unittest

from absl import logging
from absl.testing import absltest

import collections
import functools
from functools import partial
import operator as op
import re

import jax
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax import export
from jax import lax
import jax.numpy as jnp
from jax import random
from jax import tree_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import util
from jax._src.export import shape_poly
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
import numpy as np

from jax.experimental.jax2tf.tests import tf_test_util

import tensorflow as tf

config.parse_flags_with_absl()

# Import after parsing flags
from jax._src.internal_test_util import test_harnesses
from jax._src.internal_test_util.test_harnesses import Harness, CustomArg, RandArg, StaticArg
from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation

_f32 = np.float32
_i32 = np.int32

expect_error_associative_scan = (
    NotImplementedError,
    "associative scan over axis of non-constant size",
)


class PolyHarness(Harness):
  """Tests a function with shape polymorphism.

  Converts `fun` with shape polymorphism, creates a `tf.ConcreteFunction`
  given `input_signature` and checks the inferred output shapes to match
  `expected_output_shapes`, then checks that the JAX and the TF functions
  produce the same results.
  """
  def __init__(self,
               group_name: str, name: str,
               fun: Callable,
               *,
               arg_descriptors: Sequence[test_harnesses.ArgDescriptor] = (),
               polymorphic_shapes: Sequence[str | None] = (),
               polymorphic_constraints: Sequence[str] = (),
               input_signature: Sequence[tf.TensorSpec] | None = None,
               expected_output_signature: tf.TensorSpec | None = None,
               expect_error: tuple[Any | None, str | None] = (None, None),
               skip_jax_run: bool = False,
               check_result: bool = True,
               tol: float | None = None,
               limitations: Sequence[Jax2TfLimitation] = (),
               override_jax_config_flags: dict[str, Any] = {}):
    """Args:

      group_name, name: The name for the harness. See `Harness.__init__`.
      fun: the function to be converted, possibly after partial application to
        static arguments from `arg_descriptors`. See `Harness.__init__`.
      arg_descriptors: The argument descriptors. See `Harness.__init__`. May
        be missing, in which case `skip_jax_run` should be `True` and
        `input_signature` must be present.
      polymorphic_shapes: For `jax2tf.convert`.
      polymorphic_constraints: For `jax2tf.convert`.
      input_signature: For `tf.function.get_concrete_function`. If missing,
        generated from `polymorphic_shapes`.
      expected_output_signature: the expected inferred output shape.
      expect_error: a pair of an Exception type and a regular expression to
        match the expected exception string.
      skip_jax_run: If True, then neither the JAX nor the TF functions are
        executed.
      check_result: specifies if we want to check that the shape-polymorphic
        conversion produces the same result as the JAX function.
      tol: the tolerance to use for checking results.
      limitations: if given, then apply the custom_assert and tolerance from the
        Jax2TfLimitations.
      override_jax_config_flags: jax.config flags to override for the duration
        of the test.
    """
    super().__init__(group_name, name, fun, arg_descriptors,
                     dtype=np.float32)
    self.polymorphic_shapes = polymorphic_shapes
    self.polymorphic_constraints = polymorphic_constraints
    self.input_signature = input_signature
    self.expected_output_signature = expected_output_signature
    self.skip_jax_run = skip_jax_run
    self.expect_error = expect_error
    self.tol = tol
    self.check_result = check_result
    self.limitations = limitations
    self.override_jax_config_flags = override_jax_config_flags

  def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> jax.Array | None:
    def log_message(extra: str):
      return f"[{tst._testMethodName}]: {extra}"

    # Check that we have overridden the jax.config flags
    for fname, fvalue in self.override_jax_config_flags.items():
      tst.assertEqual(getattr(jax.config, fname), fvalue, (
          f"Flag {fname} current value {getattr(jax.config, fname)} != {fvalue}"))

    tst.assertIsNotNone(self.polymorphic_shapes)
    polymorphic_shapes = self.polymorphic_shapes
    if not self.skip_jax_run:
      args = self.dyn_args_maker(tst.rng())
    else:
      tst.assertIsNotNone(self.input_signature)

    if self.input_signature is None:
      tst.assertEqual(
          len(polymorphic_shapes), len(args),
          f"polymorphic_shapes {polymorphic_shapes} of length "
          f"{len(polymorphic_shapes)} must match number of arguments {len(args)}")
      args_specs = export.symbolic_args_specs(args, polymorphic_shapes)
      input_signature = [
          tf.TensorSpec(
              [d if isinstance(d, int) else None for d in a.shape],
              dtype=a.dtype) for a in args_specs]
    else:
      input_signature = self.input_signature  # type: ignore

    expect_error_type, expect_error_regex = self.expect_error
    if self.skip_jax_run and not self.arg_descriptors:
      f_jax = self.fun
    else:
      f_jax = self.dyn_fun

    with contextlib.ExitStack() as stack:
      if expect_error_type is not None:
        stack.enter_context(tst.assertRaisesRegex(expect_error_type, expect_error_regex))

      f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
                            polymorphic_constraints=self.polymorphic_constraints)
      # Run in tf.Eager mode first, because it is friendlier to debuggers
      res_tf = f_tf(*args) if not self.skip_jax_run else None
      f_tf_func = tf.function(
          f_tf, autograph=False, input_signature=input_signature)
      # Create tf.ConcreteFunction and check inferred output signature
      concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)

    if expect_error_type is not None:
      return None

    if self.expected_output_signature:
      # Strangely, output_shapes can be a single shape for a function with a
      # single result, or a list/tuple of shapes.
      expected_output_signature = self.expected_output_signature
      concrete_output_tf_shape = concrete_f_tf.output_shapes
      if not isinstance(concrete_output_tf_shape, (tuple, list)):  # Single result
        assert not isinstance(self.expected_output_signature, (tuple, list))
        expected_output_signature = [self.expected_output_signature]
        concrete_output_tf_shape = [concrete_output_tf_shape]
      for expected, found in util.safe_zip(expected_output_signature,
                                           concrete_output_tf_shape):
        tst.assertEqual(tuple(expected.shape), tuple(found))

    # Run the JAX and the TF functions and compare the results
    if not self.skip_jax_run:
      res_jax = f_jax(*args)
      if self.check_result:
        res_tf = tf.nest.map_structure(lambda t: t.numpy(), res_tf)
        custom_assert_lims = [
            l for l in self.limitations if l.custom_assert is not None]
        assert len(custom_assert_lims) <= 1, custom_assert_lims
        tol = None
        if self.tol is not None:
          tol = self.tol
        elif self.limitations:
          max_lim = self.limitations[0].get_max_tolerance_limitation(
              self.limitations)
          if max_lim is not None:
            tol = max_lim.tol

        if not custom_assert_lims:
          tst.assertAllClose(res_jax, res_tf, atol=tol, rtol=tol)
        else:
          logging.info(log_message(
              f"Running custom_assert with tol={tol} due "
              f"to {custom_assert_lims[0]}"))
          custom_assert_lims[0].custom_assert(tst, res_jax, res_tf, args=args,  # type: ignore
                                              tol=tol, err_msg=None)
        return res_tf
      else:
        return None
    else:
      return None


def check_shape_poly(tst, f_jax: Callable, *,
                     arg_descriptors: Sequence[test_harnesses.ArgDescriptor] = (),
                     skip_jax_run: bool = False,
                     polymorphic_shapes: Sequence[str | None] = (),
                     polymorphic_constraints: Sequence[str] = (),
                     input_signature: Sequence[tf.TensorSpec] | None = None,
                     expected_output_signature: tf.TensorSpec | None = None,
                     expect_error=(None, None)) -> jax.Array | None:
  # Makes and tests a harness. See PolyHarness documentation.
  h = PolyHarness("", "", f_jax,
                  arg_descriptors=arg_descriptors,
                  skip_jax_run=skip_jax_run,
                  polymorphic_shapes=polymorphic_shapes,
                  polymorphic_constraints=polymorphic_constraints,
                  input_signature=input_signature,
                  expected_output_signature=expected_output_signature,
                  expect_error=expect_error)
  return h.run_test(tst)
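

# The sketch below is illustrative only and is not used by the tests: it shows,
# under the same assumptions as PolyHarness.run_test, the manual steps that the
# harness automates -- converting with `polymorphic_shapes`, building a
# tf.ConcreteFunction from a partially known input signature, and comparing the
# inferred output shape and the numerical results.
def _example_manual_shape_poly_conversion():
  def f_jax(x):
    return x + jnp.sin(x)

  # Convert with a symbolic batch dimension "b".
  f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b, 3"])
  # Trace with an input signature whose first dimension is unknown.
  concrete_f = tf.function(f_tf, autograph=False).get_concrete_function(
      tf.TensorSpec([None, 3], tf.float32))
  # The inferred output shape keeps the symbolic dimension as None.
  assert tuple(concrete_f.output_shapes) == (None, 3)
  # The converted function agrees numerically with the JAX function.
  x = np.ones((5, 3), dtype=np.float32)
  np.testing.assert_allclose(f_jax(x), f_tf(x).numpy(), rtol=1e-6)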


class ShapePolyTest(tf_test_util.JaxToTfTestCase):

  def test_simple_unary(self):
    """Test shape polymorphism for a simple case, unary function."""

    def f_jax(x):
      return x + jnp.sin(x)

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((2, 3), _f32)],
                     polymorphic_shapes=[None],
                     expected_output_signature=tf.TensorSpec([2, 3]))

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((2, 3), _f32)],
                     polymorphic_shapes=["_, h"],
                     expected_output_signature=tf.TensorSpec([2, None]))

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((3, 3), _f32)],
                     polymorphic_shapes=["h, h"],
                     expected_output_signature=tf.TensorSpec([None, None]))

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((3, 3), _f32)],
                     polymorphic_shapes=["h, h"],
                     expected_output_signature=tf.TensorSpec([None, None]))

  def test_simple_binary(self):
    """Test shape polymorphism for a simple case, binary function."""

    def f_jax(x, y):
      return x + jnp.sin(y)

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)],
                     polymorphic_shapes=[None, None],
                     expected_output_signature=tf.TensorSpec([2, 3]))

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)],
                     polymorphic_shapes=["_, h", "_, h"],
                     input_signature=[tf.TensorSpec([2, None]), tf.TensorSpec([2, 3])],
                     expected_output_signature=(
                         # for native serialization we cannot refine the inferred shape of the
                         # output if the input is more specific than polymorphic_shapes.
                         tf.TensorSpec([2, 3]) if not config.jax2tf_default_native_serialization.value
                         else tf.TensorSpec([2, None])))

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((3, 3), _f32), RandArg((3, 3), _f32)],
                     polymorphic_shapes=["h, h", "h, h"],
                     expected_output_signature=tf.TensorSpec([None, None]))

  @jtu.parameterized_filterable(
      # make_args invoked with op.shape[0]: start, stop, step, dtype
      kwargs=[
          dict(testcase_name=name, make_args=make_args, expect_error=expect_error, expect_msg=expect_msg)
          for name, make_args, expect_error, expect_msg in [
              # make_args invoked with op.shape[0]: start, stop, step, dtype
              ("float_start", lambda b: (0., b, None),
               ValueError, "must be either dimension expressions or integers"),
              ("float_step", lambda b: (0, b, 0.5),
               ValueError, "must be either dimension expressions or integers"),
              ("step_0", lambda b: (0, b, 0),
               ValueError, "has step == 0"),
              ("inconclusive_step_sign", lambda b: (0, b, b - 2),
               core.InconclusiveDimensionOperation,
               "must be resolved statically if it is > 0 or < 0"),
          ]
      ]
  )
  def test_arange_error(self, make_args=lambda b: (0., b, 2),
                        expect_error=ValueError,
                        expect_msg="must be either dimension expressions or integers"):
    def f_jax(x):  # x: i32[b]
      return x[0] + jnp.arange(*(make_args(x.shape[0])))
    x = np.ones((3,), dtype=np.int32)
    with self.assertRaisesRegex(expect_error, expect_msg):
      check_shape_poly(self, f_jax, arg_descriptors=[x],
                       polymorphic_shapes=["b"])

  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"expr={name}", expr=expr)
          for name, expr in [
              ("d + 2", lambda d: d + 2),
              ("2 - d", lambda d: 2 - d),
              ("d * 2", lambda d: d * 2),
              ("d * d", lambda d: d * d),
              ("(- d) * d", lambda d: (- d) * d),
              ("d * d - d", lambda d: d * d - d),
              # Division
              ("d // 2", lambda d: d // 2),
              ("(d + 1) // 2", lambda d: (d + 1) // 2),
              ("d // -2", lambda d: d // -2),
              ("(d + 1) // -2", lambda d: (d + 1) // -2),
              ("(-d) // 2", lambda d: (-d) // 2),
              ("(-d - 1) // 2", lambda d: (-d - 1) // 2),
              ("(-d) // -2", lambda d: (-d) // -2),
              ("(-d - 1) // -2", lambda d: (-d - 1) // -2),
              # Remainder
              ("d % 2", lambda d: d % 2),
              ("(d + 1) % 2", lambda d: (d + 1) % 2),
              ("d % -2", lambda d: d % -2),
              ("(d + 1) % -2", lambda d: (d + 1) % -2),
              ("(-d) % 2", lambda d: (-d) % 2),
              ("(-d - 1) % 2", lambda d: (-d - 1) % 2),
              ("(-d) % -2", lambda d: (-d) % -2),
              ("(-d - 1) % -2", lambda d: (-d - 1) % -2),
          ]
      ])
  def test_non_trivial_dim_expr(self, expr=lambda d: d % -2):
    # Check the lowering for shape expressions
    check_shape_poly(
        self,
        lambda x: x[0] * 0 + expr(x.shape[0]),
        arg_descriptors=[RandArg((3,), np.int64)],
        polymorphic_shapes=["b"])

  def test_static_shape_result(self):
    """The result has static shape."""

    def f_jax(x):
      return jnp.sum(x + jnp.sin(x), axis=0)

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((2, 3), _f32)],
                     polymorphic_shapes=[None],
                     expected_output_signature=tf.TensorSpec([3]))

    check_shape_poly(self,
                     f_jax,
                     arg_descriptors=[RandArg((2, 3), _f32)],
                     polymorphic_shapes=["b, _"],
                     expected_output_signature=tf.TensorSpec([3]))

  def test_forgot_polymorphic_shapes_error(self):
    msg_re = "syntax error in symbolic shape"
    with self.assertRaisesRegex(ValueError, msg_re):
      check_shape_poly(self,
                       jnp.sin,
                       arg_descriptors=[RandArg((1, 3,), _f32)],
                       input_signature=[tf.TensorSpec([1, None])],
                       polymorphic_shapes=[None])
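
    # Illustrative note (an assumption, not part of the original test): the
    # error above arises because the TF input signature has an unknown
    # dimension while no symbolic name is given for it. Naming the dimension,
    # e.g. jax2tf.convert(jnp.sin, polymorphic_shapes=["_, b"]), would be the
    # usual fix.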

  def test_with_constraints(self):
    if not config.jax2tf_default_native_serialization.value:
      self.skipTest("not supported")
    def f_jax(x):  # x: i32[a], with a >= 8
      return lax.dynamic_slice_in_dim(x, 0, 8, 0)
    check_shape_poly(self, f_jax,
                     arg_descriptors=[RandArg((16,), _i32)],
                     polymorphic_shapes=["a"],
                     polymorphic_constraints=["a >= 8"])

  def test_kwargs(self):
    """Test shape polymorphism for a function with kwargs."""

    x = np.ones(3, dtype=np.float32)
    y = np.ones(1, dtype=np.float32)
    def f_jax(x, *, y):
      return x + jnp.sin(y)

    f_tf: Callable[..., Any] = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])
    self.assertAllClose(f_jax(x, y=y), f_tf(x, y=y))

  def test_arg_avals_non_native(self):
    """Test conversion of actual arguments to abstract values."""

    def check_avals(*, arg_shapes: Sequence[Sequence[int | None]],
                    polymorphic_shapes: Sequence[str | None],
                    expected_shapes: Sequence[str] | None = None,
                    expected_shapeenv: dict[str, int] | None = None,
                    eager_mode: bool = False):
      # Use eager mode only when all arg_shapes are known, in order to
      # check expected_shapeenv.
      arg_dtypes = (_f32,) * len(arg_shapes)
      symbolic_scope = shape_poly.SymbolicScope()
      def f_tf(*args_tf):
        avals = tuple(map(
            lambda s, dt, spec: core.ShapedArray(
                export.symbolic_shape(spec, like=s, scope=symbolic_scope),
                dt),
            arg_shapes, arg_dtypes, polymorphic_shapes))
        dim_vars = shape_poly.all_dim_vars(avals)
        dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
            partial(shape_poly.compute_dim_vars_from_arg_shapes,
                    avals,
                    args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
            args_tf, avals, "",
            debug_info=api_util.debug_info("jax2tf dim_vars",
                                           shape_poly.compute_dim_vars_from_arg_shapes,
                                           avals, {}))
        if expected_shapes is not None:
          expected_avals = tree_util.tree_map(
              lambda shape_str: core.ShapedArray(
                  shape_poly.symbolic_shape(shape_str, scope=symbolic_scope),
                  np.float32),
              expected_shapes)
          self.assertEqual(expected_avals, avals)
        return dict(zip(dim_vars, dim_values))
      if eager_mode:
        # If we want to check the shape_env then all arg_shapes must be known
        assert all(all(d is not None for d in a_s)
                   for a_s in arg_shapes)
        shape_env = f_tf(*[tf.ones(a_s, dtype=_f32) for a_s in arg_shapes])
        if expected_shapeenv is not None:
          for v, val in expected_shapeenv.items():
            self.assertEqual(val, shape_env.get(v))
      else:
        f_tf = tf.function(autograph=False)(f_tf)
        f_tf.get_concrete_function(*[tf.TensorSpec(a_s, _f32)
                                     for a_s in arg_shapes])
        assert not expected_shapeenv, "Should use eager_mode=True"

    # Known shapes for the arguments
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=[None],
        expected_shapes=("2, 3",))

    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["(2, 3)"],
        expected_shapes=("2, 3",))

    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["(_, 3)"],
        expected_shapes=("2, 3",))

    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["..."],
        expected_shapes=("2, 3",))

    # Partially known shapes for the arguments
    check_avals(
        arg_shapes=[(None, 3)],
        polymorphic_shapes=["b, ..."],
        expected_shapes=("(b, 3)",))

    check_avals(
        arg_shapes=[(None, None)],
        polymorphic_shapes=["h, h"],
        expected_shapes=("(h, h)",))

    check_avals(
        arg_shapes=[(2, None)],
        polymorphic_shapes=["h, h"],
        expected_shapes=("(h, h)",))

    check_avals(
        arg_shapes=[(None, 3, 4)],
        polymorphic_shapes=["(c, b, a)"],
        expected_shapes=("(c, b, a)",),
    )

    # Check cases when the specifications are polynomials
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["a + 1, b + 2"],
        eager_mode=True,
        expected_shapeenv=dict(a=1, b=1))

    check_avals(
        arg_shapes=[(7, 5)],
        polymorphic_shapes=["2 * a + b, b + 2"],
        eager_mode=True,
        expected_shapeenv=dict(a=2, b=3))

    check_avals(
        arg_shapes=[(7, 11, 4)],
        polymorphic_shapes=["2 * a + b, b * b + 2, b + 1"],
        eager_mode=True,
        expected_shapeenv=dict(a=2, b=3))

    check_avals(
        arg_shapes=[(7, 11, 19, 7)],
        polymorphic_shapes=["2 * a + b, b * b + 2, b + c * c, 2 * c + -1"],
        eager_mode=True,
        expected_shapeenv=dict(a=2, b=3, c=4))
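
    # A worked example of the solving above (explanatory note; the order of
    # steps is an assumption about the solver): for arg_shapes=[(7, 5)] with
    # specification "2 * a + b, b + 2", the univariate dimension gives
    # b = 5 - 2 = 3, and then a = (7 - b) // 2 = 2, matching the
    # expected_shapeenv checked above.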

  def test_arg_avals_errors(self):
    """Test error reporting for shape polymorphism."""
    def conv_and_run(*, arg_shape: core.Shape,
                     polymorphic_shape: str):
      arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
      check_shape_poly(self, lambda x: x,
                       arg_descriptors=[arg],
                       polymorphic_shapes=[polymorphic_shape])

    with self.assertRaisesRegex(ValueError,
                                re.escape("polymorphic shape spec should be")):
      conv_and_run(arg_shape=(2,), polymorphic_shape=5.)

    with self.assertRaisesRegex(ValueError,
                                re.escape("pytree structure error: different types")):
      conv_and_run(arg_shape=(2,), polymorphic_shape=["a list"])

    with self.assertRaisesRegex(ValueError,
                                re.escape("pytree structure error: different types")):
      conv_and_run(arg_shape=(2,), polymorphic_shape=("a tuple",))

    with self.assertRaisesRegex(ValueError,
                                "Cannot solve for values of dimension variables {'b'}"):
      conv_and_run(arg_shape=(4, 36, 3), polymorphic_shape="b * b, b * d * d, d")

    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                "Division had remainder 2 when computing the value of 'b'"):
      conv_and_run(arg_shape=(5, 36), polymorphic_shape="3 * b, ...")

    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                "Expected value >= 1 for dimension variable 'b'"):
      conv_and_run(arg_shape=(10, 3), polymorphic_shape="3 * b + 10, ...")

    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                "Expected value >= 1 for dimension variable 'b'"):
      conv_and_run(arg_shape=(7, 3), polymorphic_shape="3 * b + 10, ...")

    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError,
        re.escape(
            "Found inconsistency between dimension size "
            "args[0].shape[1] (= 3) and the specification 'a' (= 2)")):
      conv_and_run(arg_shape=(2, 3), polymorphic_shape="(a, a)")

  # Tests details of the shape constraints errors.
  # This test exists also in jax_export_test.py.
  @jtu.parameterized_filterable(
      testcase_name=lambda kw: kw["shape"],
      kwargs=[
          dict(shape=(8, 2, 9),  # a = 2, b = 3, c = 4
               poly_spec="(a + 2*b, a, a + b + c)"),
          dict(shape=(2, 2, 6),  # a = 2, b = 0, c = 4
               poly_spec="(a + 2*b, a, a + b + c)",
               expect_error=(
                   "Input shapes do not match the polymorphic shapes specification. "
                   "Expected value >= 1 for dimension variable 'b'. "
                   "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
                   "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
                   "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
                   "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
               )),
          dict(shape=(3, 2, 6),  # a = 2, b = 0.5, c = 4 - b is not integer
               poly_spec="(a + 2*b, a, a + b + c)",
               expect_error=(
                   "Input shapes do not match the polymorphic shapes specification. "
                   "Division had remainder 1 when computing the value of 'b'. "
                   "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
                   "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
                   "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
               )),
          dict(shape=(8, 2, 6),  # a = 2, b = 3 - inconsistency
               poly_spec="(a + 2*b, a, a + b)",
               expect_error=(
                   "Input shapes do not match the polymorphic shapes specification. "
                   "Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
                   "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
                   "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
                   "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
                   "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
               )),
          dict(shape=(7, 2, 36),  # a = 2, b = 3, c = 6 - cannot solve c
               poly_spec="(2 * a + b, a, c * c)",
               expect_error=(
                   "Cannot solve for values of dimension variables {'c'}. "
                   "We can only solve linear uni-variate constraints. "
                   "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
                   "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
                   "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
               )),
      ])
  def test_shape_constraints_errors(self, *,
                                    shape, poly_spec: str, expect_error: str | None = None):
    def f_jax(x):  # x: f32[a + 2*b, a, a + b + c]
      return 0.

    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    with contextlib.ExitStack() as stack:
      if expect_error is not None:
        stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
      _ = check_shape_poly(self, f_jax,
                           arg_descriptors=[x],
                           polymorphic_shapes=[poly_spec])

  def test_pytree(self):
    """Arguments and polymorphic_shapes are pytrees."""

    # Arguments are of the form [([x00, x01], [x10]), dict(a=ya, b=yb)]
    def add_all_jax(x_pair_of_list, y_dict):
      x_list_0, x_list_1 = x_pair_of_list
      return functools.reduce(op.add,
                              x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]])

    input_signature = [([tf.TensorSpec([None]), tf.TensorSpec([None])],
                        [tf.TensorSpec([None])]),
                       dict(a=tf.TensorSpec([None]),
                            b=tf.TensorSpec([None]))]
    check_shape_poly(self,
                     add_all_jax,
                     skip_jax_run=True,
                     input_signature=input_signature,
                     polymorphic_shapes=[(["v", "v"], ["v"]),
                                         dict(a="v", b="v")],
                     expected_output_signature=tf.TensorSpec([None]))

    # Prefix polymorphic shapes
    check_shape_poly(self,
                     add_all_jax,
                     skip_jax_run=True,
                     input_signature=input_signature,
                     polymorphic_shapes="v",
                     expected_output_signature=tf.TensorSpec([None]))

    check_shape_poly(self,
                     add_all_jax,
                     skip_jax_run=True,
                     input_signature=input_signature,
                     polymorphic_shapes=["v", "v"],
                     expected_output_signature=tf.TensorSpec([None]))

    check_shape_poly(self,
                     add_all_jax,
                     skip_jax_run=True,
                     input_signature=input_signature,
                     polymorphic_shapes=[("v", "v"), "v"],
                     expected_output_signature=tf.TensorSpec([None]))

    # Now partial polymorphic_shapes; the parts of the polymorphic_shapes that
    # are not specified must have full input_signatures.
    check_shape_poly(self,
                     add_all_jax,
                     skip_jax_run=True,
                     input_signature=[([tf.TensorSpec([4]), tf.TensorSpec([4])], [tf.TensorSpec([4])]),
                                      dict(a=tf.TensorSpec([4]), b=tf.TensorSpec([4]))],
                     polymorphic_shapes=((["(4,)", "(_,)"], [("4,")]),
                                         dict(a="(_,)", b="(4,)")),
                     expected_output_signature=tf.TensorSpec([4]))

  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=name, polymorphic_shapes=polymorphic_shapes)
          for name, polymorphic_shapes in [
              ("1", ("b", "b", "b")),
              ("2", dict(a="b")),
              ("3", (dict(a="b"), "b")),
          ]]
  )
  def test_pytree_errors(self, polymorphic_shapes=("b", "b", "b")):
    """Arguments and polymorphic_shapes are not-matching pytrees."""

    # Arguments are of the form [([x00, x01], [x10]), dict(a=ya, b=yb)]
    x = np.arange(4, dtype=_f32)
    args = (([x, x], [x]), dict(a=x, b=x))
    def add_all_jax(x_pair_of_list, y_dict):
      x_list_0, x_list_1 = x_pair_of_list
      return functools.reduce(op.add,
                              x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]])

    with self.assertRaisesRegex(ValueError, "pytree structure error"):
      jax2tf.convert(add_all_jax,
                     polymorphic_shapes=polymorphic_shapes)(*args)

  def test_with_nested_jit(self):
    def f_jax(x):  # x: f32[w, h]
      # x + (np.sin(x) + np.broadcast_to(np.arange(x.shape[1]), x.shape))
      return jnp.sin(x) + jnp.arange(x.shape[1], dtype=x.dtype)
    check_shape_poly(self,
                     lambda x: x + jax.jit(f_jax)(x),
                     arg_descriptors=[RandArg((3, 4), _f32)],
                     polymorphic_shapes=["a, b"])

  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=str(polymorphic_shapes), polymorphic_shapes=polymorphic_shapes)
          # The polymorphic_shapes should have three comma-separated DimExpr matching
          # 16, 24, 32
          for polymorphic_shapes in [
              "b1+6,b1+14,b2",  # b1=10, b2=32
              "2*b1,4*b2,b1+b2+18",  # b1=8,b2=6
              "b1+2*b2,4*b2,b1*b1+16",  # b1=4,b2=6
          ]
      ])
  def test_non_trivial_polynomials_spec(self,
                                        polymorphic_shapes="2*b1,4*b2,b1+b2+18"):
    # We can handle non-trivial polynomials in the input shape,
    # as long as all variables also occur in trivial expressions
    check_shape_poly(self,
                     lambda x: 2 * x.shape[0] + 3 * x.shape[1] + 4 * x.shape[2],
                     arg_descriptors=[RandArg((16, 24, 32), _f32)],
                     polymorphic_shapes=[polymorphic_shapes])

  def test_unused_args(self):
    # Tests with functions that do not use their inputs.

    # First arg unused, not polymorphic
    check_shape_poly(self,
                     lambda x_unused, y: y * 2.0,
                     arg_descriptors=[RandArg((2, 3), _f32), RandArg((3,), _f32)],
                     polymorphic_shapes=[None, "b"])

    # Some args unused, not polymorphic
    check_shape_poly(self,
                     lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]),
                     arg_descriptors=[RandArg((3,), _f32), RandArg((4,), _f32),
                                      RandArg((5,), _f32), RandArg((6,), _f32)],
                     polymorphic_shapes=[None, "b1", None, "b2"])

    # A polymorphic arg is not used, but the dimension var appears
    # in a used arg also
    check_shape_poly(self,
                     lambda x_unused, y: y * 2.0,
                     arg_descriptors=[RandArg((3,), _f32), RandArg((3,), _f32)],
                     polymorphic_shapes=["b", "b"])

    # A polymorphic arg is not used, and the dimension var does not appear
    # elsewhere.
    check_shape_poly(self,
                     lambda x_unused, y: y * 2.0,
                     arg_descriptors=[RandArg((4,), _f32), RandArg((3,), _f32)],
                     polymorphic_shapes=["b1", "b2"])

    # A polymorphic arg is not used, and the dimension var does appear
    # elsewhere but not as a trivial monomial.
    check_shape_poly(self,
                     lambda x_unused, y: y * 2.0,
                     arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)],
                     polymorphic_shapes=["b1", "b1 * b1"])

    # It is not sufficient to just use the shape of an input; it is still unused
    check_shape_poly(self,
                     lambda x_unused, y: y + x_unused.shape[0],
                     arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)],
                     polymorphic_shapes=["b1", "b2"])

  def test_with_custom_vjp(self):
    """Shape-polymorphic custom VJP."""

    @jax.custom_vjp
    def f(x):
      # x: [b1, b2, d1, d2] (a batch of matrices)
      # res: [b1, b2, d1, d1]
      return jnp.matmul(x, jnp.transpose(x, axes=(0, 1, 3, 2)))

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      # x: [b1, b2, d1, d2]
      # b: [b1, b2, d1, d1]
      # res: [b1, b2, d1, d1]
      # residual: [b1, b2, d1, d2]
      return f(x), 3. * x

    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      # residual: [b1, b2, d1, d2]
      # ct_b: [b1, b2, d1, d1]
      # ct_a: [b1, b2, d1, d2]
      return jnp.matmul(ct_b, residual),

    f.defvjp(f_fwd, f_bwd)
    x = np.ones((2, 3, 4, 5), dtype=np.float32)
    res_jax = f(x)
    res_jax_grad = jax.grad(lambda x: jnp.sum(f(x)))(x)

    f_tf = jax2tf.convert(f, polymorphic_shapes=["(batch1, batch2, d1, d2)"])
    self.assertAllClose(res_jax, f_tf(x))

    xv = tf.Variable(x, dtype=np.float32)

    def tf_value_and_grad(xv):
      with tf.GradientTape() as tape:
        tape.watch(xv)
        res_tf = f_tf(xv)
        res_tf_grad = tape.gradient(res_tf, xv)
        return res_tf, res_tf_grad

    res_tf, res_tf_grad = tf_value_and_grad(xv)
    self.assertAllClose(res_jax, res_tf)
    self.assertAllClose(res_jax_grad, res_tf_grad)

    # Now use TF tracing for the gradient
    tf_grad = tf.function(
        tf_value_and_grad, autograph=False).get_concrete_function(
            tf.TensorSpec([3, 4, 8, 9]))

    # for native serialization we cannot refine the inferred shape of the
    # output if the input is more specific than polymorphic_shapes.
    if config.jax2tf_default_native_serialization.value:
      self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0]))
      self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1]))
    else:
      self.assertEqual((3, 4, 8, 8), tuple(tf_grad.output_shapes[0]))
      self.assertEqual((3, 4, 8, 9), tuple(tf_grad.output_shapes[1]))

  def test_gradients_pytree(self):
    """Shape polymorphism with gradients and pytrees for inputs and outputs."""

    def f(x):
      # x: dict(x=[b, 3, 4])
      # res: dict(res=[b, 3, 4])
      return dict(res=x["x"] * 2.)

    check_shape_poly(self,
                     f,
                     skip_jax_run=True,
                     input_signature=[dict(x=tf.TensorSpec([None, 3, 4]))],
                     polymorphic_shapes=[dict(x=("b, 3, 4"))])

    f_tf = jax2tf.convert(f, polymorphic_shapes=[dict(x=("b, 3, 4"))])
    x = dict(x=np.ones((2, 3, 4), dtype=np.float32))
    xv = tf.Variable(x["x"], dtype=np.float32)

    def tf_value_and_grad(xv):
      # xv: [b, 3, 4]
      # res_value: dict(res=[b, 3, 4])
      # res_grad: dict(grad=[b, 3, 4])
      with tf.GradientTape() as tape:
        tape.watch(xv)
        res_tf = f_tf(dict(x=xv))
        res_tf_grad = tape.gradient(res_tf, xv)
        return res_tf, dict(grad=res_tf_grad)

    res_tf, res_tf_grad = tf_value_and_grad(xv)
    # Now use TF tracing for the gradient
    tf_grad = tf.function(
        tf_value_and_grad,
        autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 4]))
    # The shape of the value
    self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[0]["res"]))
    # The shape of the gradient should match the input
    self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[1]["grad"]))

  def test_grad_not_var_output(self):
    def f_jax(x):  # : [b, 3]
      return jnp.reshape(x, (-1,))  # : [3b]
    x = np.arange(12, dtype=np.float32).reshape((4, 3))
    xv = tf.Variable(x)

    f_tf = jax2tf.convert(f_jax, with_gradient=True,
                          polymorphic_shapes=["b, ..."])

    with tf.GradientTape() as tape:
      res_tf = f_tf(xv)
    grad_tf = tape.gradient(res_tf, xv)
    self.assertAllClose(np.ones(x.shape, dtype=np.float32), grad_tf.numpy())

  def test_cond(self):
    # Test the primitive under conditional
    def f(x, y):
      # x: f32[B, H], y: f32[H]
      return lax.cond(
          jnp.sum(x) > 0.,
          lambda _: x + y,
          lambda _: jnp.zeros_like(x),
          operand=None)

    x = np.ones((2, 3))
    y = np.ones((3,))
    res_jax = f(x, y)
    self.assertAllClose(
        res_jax,
        check_shape_poly(self, f, arg_descriptors=[x, y],
                         polymorphic_shapes=["(b, h)", "h"]))
|
2020-10-15 08:24:35 +03:00
|
|
|
|
|
|
|
def test_while(self):
|
|
|
|
def f(x):
|
|
|
|
# x: f32[B], iter: i32
|
|
|
|
return lax.while_loop(lambda x_iter: x_iter[1] < 5,
|
|
|
|
lambda x_iter: (x_iter[0] + jnp.arange(x_iter[0].shape[0], dtype=np.float32), x_iter[1] + 1),
|
|
|
|
(x, 0))
|
|
|
|
|
|
|
|
x = np.ones((3,), dtype=np.float32)
|
2023-11-12 18:17:07 +01:00
|
|
|
res_tf = check_shape_poly(self, f, arg_descriptors=[x],
|
|
|
|
polymorphic_shapes=["(b,)"])
|
|
|
|
self.assertAllClose(f(x), res_tf)
|
|
|
|
|
2023-07-03 17:31:31 +03:00
|
|
|
@jtu.parameterized_filterable(
|
|
|
|
kwargs=[dict(with_function=v) for v in [True, False]]
|
|
|
|
)
|
2023-04-04 13:23:43 +02:00
|
|
|
def test_grad_int(self, with_function=False):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/7093
|
2021-07-27 15:50:47 +03:00
|
|
|
# Also issue #6975.
|
2021-06-25 08:43:04 +02:00
|
|
|
x_shape = (2, 3, 4)
|
2023-04-13 11:48:11 -07:00
|
|
|
xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape)
|
2021-06-25 08:43:04 +02:00
|
|
|
yf = xi.astype(np.float32)
|
2021-07-27 15:50:47 +03:00
|
|
|
xi_yf = (xi, yf)
|
|
|
|
zb = np.array([True, False], dtype=np.bool_)
|
|
|
|
def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
|
2024-09-24 12:28:32 -07:00
|
|
|
# results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
|
2021-07-27 15:50:47 +03:00
|
|
|
xi, yf = xi_yf
|
|
|
|
# Return a tuple:
|
|
|
|
# (1) float constant, with 0 tangent;
|
|
|
|
# (2) a tuple with:
|
|
|
|
# (2.1) the integer input;
|
|
|
|
# (2.2) the boolean input;
|
|
|
|
# (2.3) a float depending on both inputs.
|
|
|
|
# TODO: there is a problem if we add a None output
|
|
|
|
return (jnp.zeros(xi.shape, dtype=jnp.float32),
|
|
|
|
(xi, zb, xi.astype(np.float32) * 2. * yf))
|
|
|
|
|
|
|
|
args = (xi_yf, zb)
|
|
|
|
|
|
|
|
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=[("b1, b2, 4", "b1, b2, 4"), "b1"])
|
|
|
|
if with_function:
|
|
|
|
f_tf = tf.function(f_tf, autograph=False)
|
|
|
|
|
|
|
|
res_tf, g_tf = tf_test_util.ComputeTfValueAndGrad(f_tf, args)
|
|
|
|
self.assertAllClose(g_tf[0][0], np.zeros_like(xi))
|
|
|
|
self.assertAllClose(g_tf[0][1], (xi * 2).astype(yf.dtype))
|
|
|
|
self.assertAllClose(g_tf[1], np.zeros_like(zb))
|
2021-06-25 08:43:04 +02:00
|
|
|
|
2022-12-04 08:37:24 +02:00
|
|
|
def test_prng(self):
|
|
|
|
# The PRNG implementation uses opaque key types; test shape polymorphism for them
|
2023-10-12 13:15:22 +01:00
|
|
|
with config.enable_custom_prng(True):
|
2022-12-04 08:37:24 +02:00
|
|
|
|
|
|
|
def f_jax(x): # x: f32[b1, b2]
|
2023-05-01 11:36:49 +02:00
|
|
|
key = random.PRNGKey(123) # key: key<fry>[]
|
2022-12-04 08:37:24 +02:00
|
|
|
# Exercise key operations that have custom lowering rules
|
2023-05-01 11:36:49 +02:00
|
|
|
broadcast_keys = lax.broadcast_in_dim(key, x.shape, ()) # key<fry>[b1, b2]
|
2022-12-04 08:37:24 +02:00
|
|
|
gather_keys = lax.broadcast_in_dim(broadcast_keys[0], (1, x.shape[1]), (1,)) # : key[1, b2]
|
|
|
|
slice_keys1 = lax.slice(broadcast_keys, (0, 0), (1, x.shape[1]), (1, 1)) # key[1, b2]
|
|
|
|
slice_keys2 = lax.dynamic_slice(broadcast_keys, (0, 0), slice_sizes=(1, x.shape[1])) # key[1, b2]
|
|
|
|
upd1 = lax.dynamic_update_slice(slice_keys2, slice_keys1, start_indices=(0, 0)) # key[1, b2]
|
|
|
|
_ = lax.dynamic_update_slice(upd1, gather_keys, start_indices=(0, 0))
|
2023-05-01 11:36:49 +02:00
|
|
|
|
|
|
|
# We need to test the special case for vmap(while)
|
|
|
|
xs = broadcast_keys
|
|
|
|
counts = jnp.arange(broadcast_keys.shape[0], dtype=np.int32)
|
|
|
|
def f_vmap_jax(counts, xs): # counts: i32[b1], xs: key<fry>[b1, b2]
|
|
|
|
def inner(count, x): # count: i32, x: key<fry>[b2]
|
|
|
|
return lax.fori_loop(0, count, lambda _, acc: acc, x)
|
|
|
|
return jax.vmap(inner)(counts, xs)
|
|
|
|
|
|
|
|
_ = f_vmap_jax(counts, xs)
|
2022-12-04 08:37:24 +02:00
|
|
|
return x
|
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
check_shape_poly(self, f_jax,
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
|
|
|
polymorphic_shapes=["b1, b2"])
|
2022-12-04 08:37:24 +02:00
|
|
|
|
2021-06-25 08:43:04 +02:00
|
|
|
def test_saved_model(self):
|
2021-06-10 17:01:22 +02:00
|
|
|
f_jax = jnp.sin
|
|
|
|
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
|
|
|
|
x = np.array([0.7, 0.8], dtype=np.float32)
|
2021-07-28 19:30:44 +03:00
|
|
|
restored_f, _ = tf_test_util.SaveAndLoadFunction(
|
|
|
|
f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
|
2021-06-10 17:01:22 +02:00
|
|
|
self.assertAllClose(f_jax(x), restored_f(x))
|
|
|
|
# Ensure that restored_f works at other batch size as well
|
|
|
|
y = np.concatenate([x, x])
|
|
|
|
self.assertAllClose(f_jax(y), restored_f(y))
|
|
|
|
|
2021-06-25 08:43:04 +02:00
|
|
|
def test_saved_model_int_function(self):
|
2023-02-13 02:47:35 -08:00
|
|
|
|
2021-07-27 15:50:47 +03:00
|
|
|
def f_jax(x): # x:s32[b, 3, 4]
|
|
|
|
return jnp.reshape(x, (-1,)) # : s32[b * 12]
|
2021-06-25 08:43:04 +02:00
|
|
|
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
|
2021-07-27 15:50:47 +03:00
|
|
|
f_tf = tf.function(f_tf, autograph=False)
|
2021-06-25 08:43:04 +02:00
|
|
|
x_shape = (2, 3, 4)
|
2023-04-13 11:48:11 -07:00
|
|
|
x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)
|
2021-06-25 08:43:04 +02:00
|
|
|
|
|
|
|
# When saving the model with gradients, we trace the gradient function
|
|
|
|
# and we used to get an error when creating zeros_like_aval for a
|
|
|
|
# polymorphic shape
|
2021-07-28 19:30:44 +03:00
|
|
|
restored_f, _ = tf_test_util.SaveAndLoadFunction(
|
|
|
|
f_tf, input_signature=[tf.TensorSpec((None,) + x.shape[1:], x.dtype)])
|
2021-06-25 08:43:04 +02:00
|
|
|
f_jax_rt = jax2tf.call_tf(restored_f)
|
|
|
|
res_jax_rt = f_jax_rt(x)
|
|
|
|
self.assertAllClose(f_jax(x), res_jax_rt)
|
|
|
|
|
|
|
|
def test_saved_model_constant_gradient(self):
|
|
|
|
def f_jax(x): # A function whose gradient is a constant
|
2023-02-13 02:47:35 -08:00
|
|
|
return x
|
2021-06-25 08:43:04 +02:00
|
|
|
|
|
|
|
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
|
|
|
|
x = np.array([0.7, 0.8], dtype=np.float32)
|
2021-07-28 19:30:44 +03:00
|
|
|
restored_f, _ = tf_test_util.SaveAndLoadFunction(
|
|
|
|
f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
|
2023-02-13 02:47:35 -08:00
|
|
|
self.assertAllClose(f_jax(x), restored_f(x))
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
@jtu.ignore_warning(
|
2024-10-25 02:30:17 -07:00
|
|
|
message="jax2tf.convert with native_serialization=False has been deprecated"
|
2024-09-24 12:28:32 -07:00
|
|
|
)
|
2022-10-07 09:45:12 +03:00
|
|
|
def test_readme_examples(self):
|
2021-04-15 09:50:00 +03:00
|
|
|
"""Some of the examples from the README."""
|
|
|
|
|
2022-06-23 16:49:47 +03:00
|
|
|
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)),
|
|
|
|
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
|
|
|
|
2023-04-13 11:48:11 -07:00
|
|
|
jax2tf.convert(lambda x: jnp.reshape(x, (math.prod(x.shape),)),
|
2022-06-23 16:49:47 +03:00
|
|
|
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
|
|
|
|
2023-01-18 12:27:02 +02:00
|
|
|
jax2tf.convert(lambda x: x + x.shape[0] + jnp.sin(x.shape[0]),
|
|
|
|
polymorphic_shapes=["b"])(np.ones(3))
|
2022-10-07 09:45:12 +03:00
|
|
|
|
|
|
|
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
|
|
|
|
polymorphic_shapes=["(v, _)"])(np.ones((3, 4)))
|
|
|
|
|
2022-06-23 16:49:47 +03:00
|
|
|
with self.assertRaisesRegex(TypeError,
|
2022-10-07 09:45:12 +03:00
|
|
|
"prod requires ndarray or scalar arguments"):
|
2023-02-13 02:47:35 -08:00
|
|
|
jax2tf.convert(lambda x: jnp.prod(x.shape) + x,
|
2022-06-23 16:49:47 +03:00
|
|
|
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
|
|
|
|
2023-02-13 02:47:35 -08:00
|
|
|
jax2tf.convert(lambda x: jnp.prod(jnp.array(x.shape)) + x,
|
2022-10-07 09:45:12 +03:00
|
|
|
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
four_ones = np.ones((4,))
|
2021-04-08 06:21:12 -07:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")):
|
2022-12-17 05:56:48 +02:00
|
|
|
jax2tf.convert(lambda x, y: x + y,
|
|
|
|
polymorphic_shapes=["(v,)", "(4,)"])(four_ones, four_ones)
|
2020-10-15 08:24:35 +03:00
|
|
|
|
|
|
|
# We get the error even if we use correct actual arguments
|
2021-04-08 06:21:12 -07:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")):
|
|
|
|
jax2tf.convert(
|
|
|
|
lambda x, y: x + y, polymorphic_shapes=["(v,)", "(4,)"])(four_ones,
|
|
|
|
four_ones)
|
2020-10-15 08:24:35 +03:00
|
|
|
|
|
|
|
with self.assertRaisesRegex(TypeError,
|
2022-01-20 22:58:09 -08:00
|
|
|
re.escape("dot_general requires contracting dimensions to have the same shape, got (4,) and (v,)")):
|
2021-04-15 09:50:00 +03:00
|
|
|
jax2tf.convert(lambda x: jnp.matmul(x, x),
|
|
|
|
polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
2023-01-18 12:27:02 +02:00
|
|
|
re.compile("Cannot divide evenly the sizes of shapes \\(b, 5, 7\\) and \\(2, -1\\)",
|
2022-07-13 17:08:51 +03:00
|
|
|
re.DOTALL)):
|
|
|
|
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
|
|
|
|
polymorphic_shapes=["(b, _, _)"])(np.ones((4, 5, 7)))
|
2021-04-15 09:50:00 +03:00
|
|
|
|
|
|
|
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
|
|
|
|
polymorphic_shapes=["(b, _, _)"])(np.ones((4, 5, 6)))
|
|
|
|
jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
|
|
|
|
polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
|
2021-04-15 09:50:00 +03:00
|
|
|
|
2023-03-29 12:09:47 +02:00
|
|
|
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
|
|
|
|
polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
|
2022-07-13 17:08:51 +03:00
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
core.InconclusiveDimensionOperation,
|
2023-02-04 08:30:44 +02:00
|
|
|
re.escape("Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive")):
|
|
|
|
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1,
|
|
|
|
polymorphic_shapes=["(a, b)"])(np.ones((4, 4)))
|
2021-04-05 16:37:35 +03:00
|
|
|
|
2023-07-03 17:31:31 +03:00
|
|
|
# Checking that the dimension variable is >= 1
|
2023-03-12 12:52:46 +02:00
|
|
|
def f1_jax(x): # f32[b]
|
|
|
|
# We have to use "x"
|
|
|
|
return jnp.concatenate([x, jnp.array([0. if x.shape[0] == 0 else 1.],
|
|
|
|
dtype=np.float32)])
|
2020-10-15 08:24:35 +03:00
|
|
|
|
|
|
|
x0 = np.array([], np.float32)
|
2023-03-12 12:52:46 +02:00
|
|
|
self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))
|
2021-04-09 13:46:28 +03:00
|
|
|
|
2023-07-03 17:31:31 +03:00
|
|
|
# We also catch the error with native serialization
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
tf.errors.InvalidArgumentError,
|
2023-07-21 14:46:30 +03:00
|
|
|
re.escape(
|
|
|
|
"Expected value >= 1 for dimension variable 'b'. "
|
|
|
|
"Using the following polymorphic shapes specifications: args[0].shape = (b,). "
|
|
|
|
"Obtained dimension variables: 'b' = 0")):
|
2024-07-16 02:04:59 -07:00
|
|
|
_ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"])(x0)
|
2023-07-03 17:31:31 +03:00
|
|
|
|
|
|
|
# Checking that the actual dimensions denoted by the same
|
2021-07-26 17:48:21 +03:00
|
|
|
# dimension variables have equal sizes.
|
2023-03-12 12:52:46 +02:00
|
|
|
def f2_jax(x): # f32[b, b]
|
|
|
|
# We have to use "x"
|
|
|
|
return jnp.sum(x) + (0. if x.shape[0] != x.shape[1] else 1.)
|
2021-04-09 13:46:28 +03:00
|
|
|
|
|
|
|
x45 = np.ones((4, 5), dtype=np.float32)
|
2023-03-12 12:52:46 +02:00
|
|
|
# JAX with static shapes sees that x.shape[0] != x.shape[1]
|
|
|
|
self.assertEqual(jnp.sum(x45), f2_jax(x45))
|
2021-04-09 13:46:28 +03:00
|
|
|
|
2023-07-03 17:31:31 +03:00
|
|
|
# We also catch the error with native serialization
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
tf.errors.InvalidArgumentError,
|
2023-07-21 14:46:30 +03:00
|
|
|
re.escape(
|
|
|
|
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
|
|
|
|
"and the specification 'b' (= 4)")):
|
2024-07-16 02:04:59 -07:00
|
|
|
_ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"])(x45)
|
2023-07-03 17:31:31 +03:00
|
|
|
|
2022-12-09 17:47:56 +02:00
|
|
|
x = np.ones((5,), dtype=np.float32)
|
2023-03-12 12:52:46 +02:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
"Cannot solve for values of dimension variables"):
|
2023-03-29 12:09:47 +02:00
|
|
|
jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"])(x)
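    # A hedged, illustrative sketch: dimension variables must be solvable from
    # the input shapes alone, so a spec with a single variable per axis works
    # where the "a + b" spec above cannot be inverted.
    _ = jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a"])(x)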
|
|
|
|
|
2020-10-15 08:24:35 +03:00
|
|
|
def test_dynamic_shapes(self):
|
|
|
|
# Test dim_as_value with dynamic shapes.
|
2020-10-15 08:24:35 +03:00
|
|
|
def f(x):
|
2022-10-07 09:45:12 +03:00
|
|
|
return jnp.sum(x, axis=0) * x.shape[0]
|
2020-10-15 08:24:35 +03:00
|
|
|
|
|
|
|
x = np.arange(3.)
|
2023-11-12 18:17:07 +01:00
|
|
|
self.assertAllClose(9.,
|
|
|
|
check_shape_poly(self, f,
|
|
|
|
arg_descriptors=[x],
|
|
|
|
polymorphic_shapes=["(b,)"]))
|
2021-04-08 06:21:12 -07:00
|
|
|
self.assertAllClose(
|
|
|
|
9.,
|
2023-11-12 18:17:07 +01:00
|
|
|
check_shape_poly(self, jax.jit(f),
|
|
|
|
arg_descriptors=[x], polymorphic_shapes=["(b,)"]))
|
2020-10-15 08:24:35 +03:00
|
|
|
|
2023-11-12 18:17:07 +01:00
|
|
|
res_primal, res_tangent = check_shape_poly(self,
|
2021-04-08 06:21:12 -07:00
|
|
|
lambda x, xt: jax.jvp(f, (x,), (xt,)),
|
2023-11-12 18:17:07 +01:00
|
|
|
arg_descriptors=[x, np.array([0.1, 0.2, 0.3])],
|
|
|
|
polymorphic_shapes=["b", "b"])
|
2020-10-15 08:24:35 +03:00
|
|
|
self.assertAllClose((9., 1.8), (res_primal, res_tangent))
|
|
|
|
|
2023-03-29 12:09:47 +02:00
|
|
|
self.assertAllClose(
|
|
|
|
np.array([3., 3., 3.]),
|
2023-11-12 18:17:07 +01:00
|
|
|
check_shape_poly(self, jax.grad(f),
|
|
|
|
arg_descriptors=[x],
|
|
|
|
polymorphic_shapes=["b"]))
|
2020-10-15 08:24:35 +03:00
|
|
|
|
|
|
|
xv = np.arange(24.).reshape((2, 3, 4))
|
|
|
|
res_vmap = jax.vmap(f, in_axes=1)(xv)
|
|
|
|
# Implement by iteration
|
|
|
|
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
|
|
|
|
self.assertAllClose(res_iter, res_vmap)
|
|
|
|
|
2023-11-12 18:17:07 +01:00
|
|
|
res_vmap_tf = check_shape_poly(self, jax.vmap(f, in_axes=1),
|
|
|
|
arg_descriptors=[xv],
|
|
|
|
polymorphic_shapes=["b1, b2, ..."])
|
|
|
|
self.assertAllClose(res_iter, res_vmap_tf)
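    # A hedged, illustrative sketch: a single polymorphic conversion of f can
    # be reused at several batch sizes.
    f_tf_poly = jax2tf.convert(f, polymorphic_shapes=["b"])
    for size in (3, 5):
      x1 = np.arange(size, dtype=np.float32)
      self.assertAllClose(f(x1), f_tf_poly(x1).numpy())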
|
2020-10-15 08:24:35 +03:00
|
|
|
|
2023-02-04 08:30:44 +02:00
|
|
|
def test_with_hash_collision_vmap(self):
|
|
|
|
# Batching caches are keyed on the Jaxpr, and Jaxprs include _DimExprs. If we have
|
|
|
|
# a collision for the hashing of a _DimExpr, then Python will call the
|
|
|
|
# equality, which will raise InconclusiveDimensionOperation.
|
|
|
|
|
|
|
|
def f_jax(x):
|
|
|
|
return jnp.reshape(x, (2, -1,))
|
|
|
|
try:
|
|
|
|
# Override the hashing to create collisions
|
|
|
|
orig_hash = getattr(shape_poly._DimExpr, "__hash__")
|
|
|
|
def collision_hash(obj):
|
|
|
|
return hash(5)
|
|
|
|
|
|
|
|
setattr(shape_poly._DimExpr, "__hash__", collision_hash)
|
|
|
|
xs = np.ones((3, 5, 6), dtype=np.float32)
|
|
|
|
f_toconvert = jax.vmap(pjit.pjit(f_jax))
|
|
|
|
res_1 = jax2tf.convert(f_toconvert)(xs)
|
|
|
|
res_2 = jax2tf.convert(f_toconvert,
|
|
|
|
polymorphic_shapes = "b1, b2, ...")(xs)
|
|
|
|
self.assertAllClose(res_1, res_2)
|
|
|
|
finally:
|
|
|
|
setattr(shape_poly._DimExpr, "__hash__", orig_hash)
|
|
|
|
|
2023-06-16 12:50:50 +03:00
|
|
|
@jtu.parameterized_filterable(
|
|
|
|
kwargs=[
|
|
|
|
dict(testcase_name=op_name, op=op)
|
2022-12-28 12:07:46 +02:00
|
|
|
for op, op_name in [
|
|
|
|
(jnp.array, "array"),
|
|
|
|
(jnp.sin, "sin"),
|
|
|
|
(lambda x: x, "id"),
|
2022-12-25 14:22:29 +02:00
|
|
|
(core.dimension_as_value, "dimension_as_value"),
|
2023-06-16 12:50:50 +03:00
|
|
|
]])
|
2022-12-28 12:07:46 +02:00
|
|
|
def test_poly_unary_op(self, *, op=jnp.array):
|
|
|
|
def f_jax(x): # x: f32[b]
|
|
|
|
poly = 2 * x.shape[0]
|
2023-03-18 16:14:40 +02:00
|
|
|
return (op(poly), x) # Make sure we are using x
|
2020-10-15 08:24:35 +03:00
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
check_shape_poly(self,
|
|
|
|
f_jax,
|
2022-12-28 12:07:46 +02:00
|
|
|
arg_descriptors=[RandArg((3,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b"],
|
2023-03-18 16:14:40 +02:00
|
|
|
expected_output_signature=(tf.TensorSpec([]), tf.TensorSpec((None,), _f32)))
|
2020-10-15 08:24:35 +03:00
|
|
|
|
2023-06-16 12:50:50 +03:00
|
|
|
@jtu.parameterized_filterable(
|
|
|
|
kwargs=[
|
2023-01-18 12:27:02 +02:00
|
|
|
dict(testcase_name=f"_{op.__name__}_other={other}:{type(other)}{'_other_jnp_array' if other_jnp_array else ''}{'_swap' if swap else ''}",
|
|
|
|
op=op, other=other,
|
|
|
|
other_jnp_array=other_jnp_array, swap=swap)
|
2023-01-18 12:27:02 +02:00
|
|
|
for op in [op.add, op.mul, op.sub,
|
|
|
|
op.mod, op.floordiv, op.truediv]
|
2023-01-18 12:27:02 +02:00
|
|
|
for other in [
|
|
|
|
2, np.int32(2), 2., np.float32(2),
|
|
|
|
np.array(2, dtype=np.int32), np.arange(1, 5, dtype=np.int32),
|
|
|
|
np.array(2., dtype=np.float32), np.arange(1., 7., dtype=np.float32)
|
|
|
|
]
|
|
|
|
for other_jnp_array in (
|
|
|
|
[True, False] if np.shape(other) == (7,) else [False]) # type: ignore
|
|
|
|
for swap in [False, True] # The poly is the left op by default
|
2022-12-28 12:07:46 +02:00
|
|
|
])
|
2023-01-18 12:27:02 +02:00
|
|
|
def test_poly_binary_op(self, *, op=op.add,
|
|
|
|
other=np.arange(2, dtype=np.int32),
|
2023-01-18 12:27:02 +02:00
|
|
|
other_jnp_array=False,
|
2023-01-18 12:27:02 +02:00
|
|
|
swap=True):
|
2023-01-18 12:27:02 +02:00
|
|
|
# Test arithmetic operations with poly and a variety of other operand types
|
2022-12-28 12:07:46 +02:00
|
|
|
def f_jax(x): # x: f32[b]
|
2023-01-18 12:27:02 +02:00
|
|
|
poly = 2 * x.shape[0] # This will allow divisions with 2
|
|
|
|
other_wrapped = jnp.array(other) if other_jnp_array else other
|
|
|
|
ops = (poly, other_wrapped) if not swap else (other_wrapped, poly)
|
|
|
|
res = op(*ops)
|
|
|
|
|
|
|
|
# If the other op is an integer then the result is a symbolic dim
|
|
|
|
try:
|
2023-01-18 12:27:02 +02:00
|
|
|
op.index(other)
|
2023-01-18 12:27:02 +02:00
|
|
|
other_isint = True
|
|
|
|
except Exception:
|
|
|
|
other_isint = False
|
|
|
|
|
|
|
|
if (hasattr(poly, "dimension_as_value") and
|
|
|
|
other_isint and
|
2023-01-18 12:27:02 +02:00
|
|
|
op.__name__ != "truediv"):
|
2023-01-18 12:27:02 +02:00
|
|
|
# If we are running under jax2tf and "other" is an integer, the result
|
|
|
|
# should be a symbolic dimension
|
|
|
|
self.assertTrue(isinstance(res, int) or hasattr(res, "dimension_as_value"))
|
|
|
|
|
2023-10-12 13:15:22 +01:00
|
|
|
if config.enable_x64.value:
|
2023-01-18 12:27:02 +02:00
|
|
|
# Outside jax2tf, x.shape[0] is a Python (64-bit) integer and for most
|
|
|
|
# operations here JAX is not involved at all because the other operand
|
|
|
|
# is a Python or NumPy constant. So the result will be 64-bits. But under
|
|
|
|
# jax2tf, x.shape[0] is rewritten to jnp.array(x.shape[0]) which when
|
|
|
|
# used with int32 or float32 values will produce 32-bit values.
|
2023-03-18 16:14:40 +02:00
|
|
|
return (lax.convert_element_type(res, np.float32), x)
|
|
|
|
return (res, x) # Make sure we are using x
|
2022-12-28 12:07:46 +02:00
|
|
|
|
2023-01-18 12:27:02 +02:00
|
|
|
check_shape_poly(self,
|
|
|
|
f_jax,
|
|
|
|
arg_descriptors=[RandArg((3,), np.int32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b"])
|
2022-12-28 12:07:46 +02:00
|
|
|
|
|
|
|
def test_mean0(self):
|
|
|
|
def f_jax(x): # x: f32[b, 4]
|
|
|
|
return jnp.sum(x, axis=0) / x.shape[0]
|
2022-12-17 05:56:48 +02:00
|
|
|
check_shape_poly(self,
|
|
|
|
f_jax,
|
2022-12-28 12:07:46 +02:00
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, _"],
|
2022-12-28 12:07:46 +02:00
|
|
|
expected_output_signature=tf.TensorSpec([4]))
|
|
|
|
|
2022-10-07 09:45:12 +03:00
|
|
|
def test_shape_as_array(self):
|
|
|
|
def f_jax(x):
|
|
|
|
# The entire x.shape is passed to jnp.array
|
2023-02-10 10:59:46 +01:00
|
|
|
return x + jnp.sum(jnp.array(x.shape)).astype(np.int32)
|
2022-03-15 13:14:34 +01:00
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
check_shape_poly(self,
|
|
|
|
f_jax,
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, _"])
|
2022-12-28 12:07:46 +02:00
|
|
|
|
2023-01-18 12:56:48 +02:00
|
|
|
def test_dim_as_value_weak_type(self):
|
|
|
|
def f_jax(x): # x: f32[b]
|
|
|
|
d0 = jnp.array(x.shape[0]) # in JAX should have weak_type=True
|
|
|
|
if isinstance(d0, core.Tracer):
|
|
|
|
self.assertTrue(d0.aval.weak_type), d0
|
|
|
|
|
|
|
|
# And an implicit conversion to array
|
|
|
|
d1 = x.shape[0] + jnp.array(4)
|
|
|
|
if isinstance(d1, core.Tracer):
|
|
|
|
self.assertTrue(d1.aval.weak_type), d1
|
2023-02-10 10:59:46 +01:00
|
|
|
return d0 + np.array(5., dtype=np.float32) + d1 + x[0]
|
2023-01-18 12:56:48 +02:00
|
|
|
|
2023-10-12 13:15:22 +01:00
|
|
|
with config.numpy_dtype_promotion("strict"):
|
2023-01-18 12:56:48 +02:00
|
|
|
# strict type promotion is sensitive to weak_types
|
|
|
|
check_shape_poly(self,
|
|
|
|
f_jax,
|
|
|
|
arg_descriptors=[RandArg((3,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b"])
|
2023-01-18 12:56:48 +02:00
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
def test_vmap_while(self):
|
|
|
|
def cond_func(x): # x: f32[3]
|
|
|
|
return jnp.sum(x) >= 0.
|
|
|
|
def body_func(x): # x: f32[3]
|
|
|
|
return x - 1.
|
|
|
|
def f_jax(x):
|
|
|
|
return lax.while_loop(cond_func, body_func, x)
|
2021-04-09 14:02:44 +03:00
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
check_shape_poly(self,
|
|
|
|
jax.vmap(f_jax),
|
2023-06-16 12:50:50 +03:00
|
|
|
arg_descriptors=[RandArg((5, 3), _f32)],
|
2022-12-17 05:56:48 +02:00
|
|
|
polymorphic_shapes=["b, ..."],
|
|
|
|
expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32)
|
|
|
|
)
|
|
|
|
|
2023-01-17 10:42:20 +02:00
|
|
|
def test_vmap_error(self):
|
|
|
|
# vmap is careful to give nice error messages when mapped axes have
|
|
|
|
# different sizes, but this can be foiled by InconclusiveDimensionOperation
|
|
|
|
x = y = np.ones((3, 5), dtype=np.float32)
|
|
|
|
with self.assertRaisesRegex(ValueError,
|
|
|
|
"vmap got inconsistent sizes for array axes to be mapped"):
|
|
|
|
jax2tf.convert(jax.vmap(lambda x, y: x + y),
|
|
|
|
polymorphic_shapes=["b, ...", None])(x, y)
|
|
|
|
|
|
|
|
z = x
|
|
|
|
with self.assertRaisesRegex(ValueError,
|
|
|
|
"vmap got inconsistent sizes for array axes to be mapped"):
|
|
|
|
jax2tf.convert(jax.vmap(lambda x, y, z: x + y + z),
|
|
|
|
polymorphic_shapes=["b, ...", "c, ...", None])(x, y, z)
|
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
def test_reshape_compiled(self):
|
|
|
|
# We compile the result of conversion for two shapes, hence we need to
|
|
|
|
# involve the TF compiler twice, but we trace only once with shape polymorphism
|
|
|
|
traced = False
|
|
|
|
|
|
|
|
def f_jax(x):
|
|
|
|
nonlocal traced
|
|
|
|
traced = True
|
|
|
|
y = jnp.sin(x)
|
|
|
|
return y.reshape([x.shape[0], -1])
|
|
|
|
|
|
|
|
x = self.rng().rand(4, 2, 3)
|
|
|
|
res_jax = f_jax(x)
|
|
|
|
|
|
|
|
traced = False
|
|
|
|
# If we call get_concrete_function we trace once
|
|
|
|
f_tf = tf.function(
|
2024-01-10 08:45:03 +02:00
|
|
|
jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
autograph=False,
|
|
|
|
jit_compile=True).get_concrete_function(
|
|
|
|
tf.TensorSpec([None, 2, 3], x.dtype))
|
|
|
|
self.assertTrue(traced)
|
|
|
|
traced = False
|
|
|
|
self.assertAllClose(res_jax, f_tf(x))
|
|
|
|
self.assertFalse(traced) # We are not tracing again
|
|
|
|
|
|
|
|
x = self.rng().rand(6, 2, 3)
|
|
|
|
res_jax = f_jax(x)
|
|
|
|
traced = False
|
|
|
|
|
|
|
|
self.assertAllClose(res_jax, f_tf(x))
|
|
|
|
self.assertFalse(traced) # We are not tracing again
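    # A hedged, illustrative sketch: any further batch size reuses the same
    # concrete function obtained above, so there is still no re-tracing.
    x = self.rng().rand(8, 2, 3)
    res_jax = f_jax(x)
    traced = False
    self.assertAllClose(res_jax, f_tf(x))
    self.assertFalse(traced)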
|
2021-04-09 14:02:44 +03:00
|
|
|
|
2023-05-03 10:42:50 +02:00
|
|
|
def test_eval_poly_shapes(self):
|
|
|
|
def f1(x, y): # x: f32[a, 5] y: f32[a, 5] -> f32[a, 10]
|
|
|
|
return jnp.concatenate([x, y], axis=1)
|
|
|
|
def f2(x, z): # x: f32[a, 5] z: f32[a, 10]
|
|
|
|
return jnp.concatenate([x, jax.lax.slice_in_dim(z, 0, 5, axis=1)],
|
|
|
|
axis=1),
|
|
|
|
|
|
|
|
x = np.arange(np.prod((3, 5)), dtype=np.float32).reshape((3, 5))
|
|
|
|
y = x
|
|
|
|
|
|
|
|
x_polymorphic_shape = "a, _"
|
|
|
|
y_polymorphic_shape = x_polymorphic_shape
|
|
|
|
z_spec, z_polymorphic_shape = jax2tf.eval_polymorphic_shape(
|
|
|
|
f1,
|
|
|
|
polymorphic_shapes=[x_polymorphic_shape, y_polymorphic_shape])(x, y)
|
|
|
|
self.assertEqual(np.float32, z_spec.dtype)
|
|
|
|
self.assertEqual("(a, 10)", z_polymorphic_shape)
|
|
|
|
|
|
|
|
# We can use the z_polymorphic_shape for jax2tf.convert
|
|
|
|
z = jax2tf.convert(
|
|
|
|
f1,
|
|
|
|
polymorphic_shapes=[x_polymorphic_shape, y_polymorphic_shape])(x, y)
|
|
|
|
res = jax2tf.convert(
|
|
|
|
f2,
|
|
|
|
polymorphic_shapes=[x_polymorphic_shape, z_polymorphic_shape])(x, z)
|
|
|
|
self.assertAllClose(f2(x, f1(x, y)), res)
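    # A hedged, illustrative sketch: for a function with a single output,
    # jax2tf.eval_polymorphic_shape returns one spec and one shape string.
    sin_spec, sin_poly_shape = jax2tf.eval_polymorphic_shape(
        jnp.sin, polymorphic_shapes=[x_polymorphic_shape])(x)
    self.assertEqual(np.float32, sin_spec.dtype)
    self.assertEqual("(a, 5)", sin_poly_shape)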
|
|
|
|
|
|
|
|
def test_eval_poly_shapes_tuple_output(self):
|
|
|
|
def f1(x, y): # x: f32[a, 5] y: f32[b, 5] -> (f32[a, 5], f32[a + b, 5])
|
|
|
|
return (x, jnp.concatenate([x, y], axis=0))
|
|
|
|
def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 5]
|
|
|
|
return jnp.concatenate([z, w], axis=0)
|
|
|
|
x = np.arange(np.prod((3, 5)), dtype=np.float32).reshape((3, 5))
|
|
|
|
y = np.arange(np.prod((4, 5)), dtype=np.float32).reshape((4, 5))
|
|
|
|
|
|
|
|
x_polymorphic_shape = "a, _"
|
|
|
|
y_polymorphic_shape = "b, _"
|
|
|
|
zw_specs, zw_polymorphic_shapes = jax2tf.eval_polymorphic_shape(
|
|
|
|
f1,
|
|
|
|
polymorphic_shapes=[x_polymorphic_shape, y_polymorphic_shape])(x, y)
|
|
|
|
self.assertEqual(np.float32, zw_specs[0].dtype)
|
|
|
|
self.assertEqual(np.float32, zw_specs[1].dtype)
|
2024-01-05 14:48:53 +07:00
|
|
|
self.assertEqual(("(a, 5)", "(b + a, 5)"), zw_polymorphic_shapes)
|
2023-05-03 10:42:50 +02:00
|
|
|
|
|
|
|
# We can use the zw_polymorphic_shapes for jax2tf.convert
|
|
|
|
z, w = jax2tf.convert(
|
|
|
|
f1,
|
|
|
|
polymorphic_shapes=[x_polymorphic_shape, y_polymorphic_shape])(x, y)
|
|
|
|
res = jax2tf.convert(f2, polymorphic_shapes=zw_polymorphic_shapes)(z, w)
|
|
|
|
self.assertAllClose(f2(* f1(x, y)), res)
|
|
|
|
|
2021-04-09 14:02:44 +03:00
|
|
|
|
|
|
|
# List containing either harnesses, or lists of harnesses
|
2021-04-09 14:02:44 +03:00
|
|
|
_POLY_SHAPE_TEST_HARNESSES = [
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("add", "",
|
|
|
|
jnp.add,
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), RandArg((2, 3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "_, b, _"]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("add_transpose", "",
|
2023-02-10 10:59:46 +01:00
|
|
|
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
|
2022-12-17 05:56:48 +02:00
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-11-12 18:17:07 +01:00
|
|
|
[
|
|
|
|
# make_args is invoked with op.shape[0] and produces the arange args:
|
|
|
|
# start, stop, step, dtype
|
|
|
|
PolyHarness("arange", kwargs["testcase_name"], # type: ignore
|
|
|
|
lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))), # type: ignore
|
|
|
|
arg_descriptors=[RandArg((6,), np.float32)],
|
|
|
|
polymorphic_shapes=["b"])
|
|
|
|
for kwargs in [
|
|
|
|
# Positive step
|
|
|
|
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
|
|
|
|
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
|
|
|
|
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
|
|
|
|
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
|
|
|
|
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
|
|
|
|
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
|
|
|
|
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
|
|
|
|
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
|
|
|
|
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
|
|
|
|
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),
|
|
|
|
# Cannot tell if size >= 0
|
|
|
|
# Negative step
|
|
|
|
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
|
|
|
|
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
|
|
|
|
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
|
|
|
|
dict(testcase_name="5b+1_0_-2",
|
|
|
|
make_args=lambda b: (5 * b + 1, 0, -2, None)),
|
|
|
|
dict(testcase_name="5b+2_0_-2",
|
|
|
|
make_args=lambda b: (5 * b + 2, 0, -2, None)),
|
|
|
|
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),
|
|
|
|
# Cannot tell if size >= 0
|
|
|
|
# Symbolic step
|
|
|
|
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
|
|
|
|
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
|
|
|
|
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
|
|
|
|
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
|
|
|
|
# Float return type
|
|
|
|
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
|
|
|
|
]
|
|
|
|
],
|
2021-11-05 17:03:46 +02:00
|
|
|
# Reduce the poly dimension
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("argmax", "0",
|
|
|
|
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2021-11-05 17:03:46 +02:00
|
|
|
# Reduce the non-poly dimension
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("argmax", "1",
|
|
|
|
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-03-02 14:31:32 +01:00
|
|
|
PolyHarness("jnp.argsort", "",
|
|
|
|
lambda op: jnp.argsort(op),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2021-07-30 10:52:34 +03:00
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("average",
|
|
|
|
f"{axis=}_weights=None",
|
|
|
|
lambda x, axis: jnp.average(x, axis=axis, returned=False, weights=None),
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."])
|
2021-07-30 10:52:34 +03:00
|
|
|
for axis in [None, 0, 1]
|
|
|
|
],
|
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("average",
|
|
|
|
f"{axis=}_weights=Some",
|
|
|
|
lambda x, weights, axis: jnp.average(x, axis=axis, returned=False, weights=weights),
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), _f32), StaticArg(axis)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."])
|
2021-07-30 10:52:34 +03:00
|
|
|
for axis in [None, 0, 1]
|
|
|
|
],
|
2023-03-02 14:31:32 +01:00
|
|
|
PolyHarness("jnp.bincount", "length=constant",
|
|
|
|
lambda x: jnp.bincount(x % 2, length=4),
|
|
|
|
arg_descriptors=[RandArg((12,), np.int32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-03-02 14:31:32 +01:00
|
|
|
PolyHarness("jnp.bincount", "length=poly",
|
|
|
|
lambda x: jnp.bincount(x % 4, length=x.shape[0]),
|
|
|
|
arg_descriptors=[RandArg((12,), np.int32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("broadcast_to", "",
|
|
|
|
lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("broadcast_in_dim", "0",
|
|
|
|
lambda x: lax.broadcast_in_dim(x, [x.shape[0], 4, 5, 6],
|
|
|
|
broadcast_dimensions=(0, 2, 3)),
|
|
|
|
arg_descriptors=[RandArg((3, 1, 6), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("broadcast_in_dim", "poly",
|
|
|
|
lambda x: lax.broadcast_in_dim(x, [x.shape[0], x.shape[0] + x.shape[0], 4],
|
|
|
|
broadcast_dimensions=(0, 1, 2)),
|
|
|
|
arg_descriptors=[RandArg((3, 1, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("broadcast_in_dim", "poly2",
|
|
|
|
lambda x: lax.broadcast_in_dim(x, [x.shape[0], 5, 6, x.shape[2], 4],
|
|
|
|
broadcast_dimensions=(0, 2, 3)),
|
|
|
|
arg_descriptors=[RandArg((3, 1, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b1, _, b2"]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("broadcast_in_dim", "transpose",
|
2023-02-10 10:59:46 +01:00
|
|
|
jax.grad(lambda x: jnp.sum(
|
|
|
|
lax.broadcast_in_dim(jnp.sin(x), [2, x.shape[0], 5, x.shape[2], 4],
|
|
|
|
broadcast_dimensions=(1, 2, 3)))),
|
2022-12-17 05:56:48 +02:00
|
|
|
arg_descriptors=[RandArg((3, 1, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b1, _, b2"]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("clamp", "",
|
|
|
|
lax.clamp,
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 4, 5), _f32),
|
|
|
|
RandArg((3, 4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("collapse", "",
|
|
|
|
lambda x: lax.collapse(x, 1, 4),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1, _, b3, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("concatenate", "",
|
|
|
|
lambda x: jnp.concatenate([x, x], axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1, _"]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("concatenate", "grad",
|
2023-02-10 10:59:46 +01:00
|
|
|
jax.grad(lambda x: jnp.sum(jnp.concatenate([x, jnp.sin(x)], axis=0))),
|
2022-12-17 05:56:48 +02:00
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1, _"]),
|
2022-07-12 16:27:57 +03:00
|
|
|
|
2023-01-18 12:27:02 +02:00
|
|
|
PolyHarness("conv_general_dilated", "1d_stride=1",
|
2022-12-17 05:56:48 +02:00
|
|
|
lambda lhs, rhs: lax.conv_general_dilated(
|
2023-01-18 12:27:02 +02:00
|
|
|
lhs, rhs,
|
|
|
|
window_strides=(1,),
|
2022-12-17 05:56:48 +02:00
|
|
|
padding="SAME",
|
|
|
|
rhs_dilation=None,
|
|
|
|
dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
|
|
|
rhs_spec=(2, 1, 0),
|
|
|
|
out_spec=(0, 2, 1))),
|
2023-01-18 12:27:02 +02:00
|
|
|
arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["_, b, _", None]),
|
2023-01-18 12:27:02 +02:00
|
|
|
# The same example from above, but with stride=2.
|
|
|
|
PolyHarness("conv_general_dilated", "1d_stride=2_even",
|
2022-12-17 05:56:48 +02:00
|
|
|
lambda lhs, rhs: lax.conv_general_dilated(
|
|
|
|
lhs, rhs,
|
|
|
|
window_strides=(2,),
|
|
|
|
padding="SAME",
|
|
|
|
rhs_dilation=None,
|
|
|
|
dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
|
|
|
rhs_spec=(2, 1, 0),
|
|
|
|
out_spec=(0, 2, 1))),
|
|
|
|
arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["_, b, _", None]),
|
2023-01-18 12:27:02 +02:00
|
|
|
# The same example from above, but with stride=2 and odd input size.
|
|
|
|
PolyHarness("conv_general_dilated", "1d_stride=2_odd",
|
|
|
|
lambda lhs, rhs: lax.conv_general_dilated(
|
|
|
|
lhs, rhs,
|
|
|
|
window_strides=(2,),
|
|
|
|
padding="SAME",
|
|
|
|
rhs_dilation=None,
|
|
|
|
dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
|
|
|
rhs_spec=(2, 1, 0),
|
|
|
|
out_spec=(0, 2, 1))),
|
|
|
|
arg_descriptors=[RandArg((1, 13, 16), _f32), RandArg((4, 16, 16), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["_, b, _", None]),
|
|
|
|
PolyHarness("conv_general_dilated", "1d_stride=2_zero_output",
|
|
|
|
lambda lhs, rhs: lax.conv_general_dilated(
|
|
|
|
lhs, rhs,
|
|
|
|
window_strides=(2,),
|
|
|
|
padding="VALID",
|
|
|
|
rhs_dilation=None,
|
|
|
|
dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
|
|
|
rhs_spec=(2, 1, 0),
|
|
|
|
out_spec=(0, 2, 1))
|
|
|
|
).shape[1], # should be 0 in JAX native
|
|
|
|
arg_descriptors=[RandArg((1, 4, 16), _f32),
|
|
|
|
RandArg((8, 16, 16), _f32)],
|
|
|
|
polymorphic_shapes=["_, b, _",
|
2024-07-16 02:04:59 -07:00
|
|
|
None]),
|
2022-07-12 16:27:57 +03:00
|
|
|
# Issue #11402
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("conv_general_dilated", "1d_2",
|
|
|
|
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
|
|
|
|
strides=(2,),
|
|
|
|
padding="SAME",
|
|
|
|
rhs_dilation=None,
|
|
|
|
transpose_kernel=False),
|
|
|
|
arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, _, _", None],
|
2024-07-16 02:04:59 -07:00
|
|
|
tol=5e-5),
|
2022-07-12 16:27:57 +03:00
|
|
|
# Issue #11402
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("conv_general_dilated", "1d_3",
|
|
|
|
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
|
|
|
|
strides=(2,),
|
|
|
|
padding="SAME",
|
|
|
|
rhs_dilation=None,
|
|
|
|
transpose_kernel=False),
|
|
|
|
arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b, _", None],
|
2024-07-16 02:04:59 -07:00
|
|
|
tol=5e-5),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("conv_general_dilated", "",
|
|
|
|
lambda lhs, rhs: lax.conv_general_dilated(
|
|
|
|
lhs, rhs,
|
|
|
|
window_strides=(2, 3),
|
|
|
|
padding=((0, 0), (0, 0)),
|
|
|
|
lhs_dilation=(1, 1),
|
|
|
|
rhs_dilation=(1, 2),
|
|
|
|
dimension_numbers=("NCHW", "OIHW", "NCHW"),
|
|
|
|
feature_group_count=1,
|
|
|
|
batch_group_count=1,
|
|
|
|
precision=None),
|
|
|
|
arg_descriptors=[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2023-06-17 10:33:29 -07:00
|
|
|
[
|
|
|
|
[
|
2023-08-12 08:25:45 +03:00
|
|
|
PolyHarness(cum_name, "reduce_axis_poly",
|
2023-06-17 10:33:29 -07:00
|
|
|
lambda x: cum_func(x, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-08-12 08:25:45 +03:00
|
|
|
PolyHarness(cum_name, "reduce_axis_static",
|
2023-06-17 10:33:29 -07:00
|
|
|
lambda x: cum_func(x, axis=1),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
|
|
|
polymorphic_shapes=["b, ..."])
|
|
|
|
]
|
|
|
|
for cum_name, cum_func in [
|
|
|
|
("cumlogsumexp", lax_control_flow.cumlogsumexp),
|
|
|
|
("cummax", lax_control_flow.cummax),
|
|
|
|
("cummin", lax_control_flow.cummin),
|
|
|
|
("cumsum", lax_control_flow.cumsum),
|
|
|
|
("cumprod", lax_control_flow.cumprod)
|
|
|
|
]
|
|
|
|
],
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("delta", "0",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x,
|
|
|
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dot_general", "",
|
|
|
|
lambda lhs, rhs: lax.dot_general(lhs, rhs,
|
|
|
|
dimension_numbers=(((2,), (1,)), ((0,), (0,)))),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_slice", "idx=tuple_int",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_slice", "idx=tuple_arg",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_slice", "idx=array",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2023-06-14 08:40:47 +03:00
|
|
|
PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-14 08:40:47 +03:00
|
|
|
PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_slice_in_dim", "idx=0",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_update_slice", "idx=tuple_int",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_update_slice", "idx=tuple_arg",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("dynamic_update_slice", "idx=array",
|
|
|
|
# x:shape: (b, 4)
|
|
|
|
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, _", None]),
|
2023-06-16 23:58:37 -07:00
|
|
|
[
|
|
|
|
PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}",
|
|
|
|
lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right),
|
|
|
|
arg_descriptors=[RandArg((3, 5, 5), dtype),
|
|
|
|
StaticArg(left), StaticArg(right)],
|
|
|
|
polymorphic_shapes=[poly],
|
|
|
|
# In non-native serialization we cannot check for an exact match;
# we would have to check the invariants of the result instead.
|
2023-10-12 13:15:22 +01:00
|
|
|
check_result=config.jax2tf_default_native_serialization.value)
|
2023-11-19 08:59:23 -08:00
|
|
|
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
2023-06-16 23:58:37 -07:00
|
|
|
for poly in ["b, ...", "b, w, w"]
|
|
|
|
for left in ([True, False] if dtype == np.float32 else [True])
|
|
|
|
for right in ([True, False] if dtype == np.float32 else [False])
|
|
|
|
],
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "0",
|
|
|
|
lambda x: jnp.einsum("...i->...", x),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "0_alt",
|
|
|
|
lambda x: jnp.einsum(x, (..., 1), [...]),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "1",
|
|
|
|
lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "1_alt",
|
|
|
|
lambda x, y: jnp.einsum(x, [..., 0, 1], y, (..., 1, 2), [..., 0, 2]),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "2",
|
|
|
|
lambda x, y: jnp.einsum("...ij,jk->...ik", x, y),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "2_alt",
|
|
|
|
lambda x, y: jnp.einsum(x, [..., 0, 1], y, [1, 2], [..., 0, 2]),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "3",
|
|
|
|
# Reduced dimension is polymorphic
|
|
|
|
lambda x, y: jnp.einsum("ij,jk->ik", x, y),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "3_alt",
|
|
|
|
# Reduced dimension is polymorphic
|
|
|
|
lambda x, y: jnp.einsum(x, [0, 1], y, [1, 2], [0, 2]),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "4",
|
|
|
|
# Reduced dimension is polymorphic, and is 2*b
|
|
|
|
lambda x, y: jnp.einsum("ij,jk->ik",
|
|
|
|
jnp.concatenate([x, x], axis=1),
|
|
|
|
jnp.concatenate([y, y], axis=0)),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "4_alt",
|
|
|
|
# Reduced dimension is polymorphic, and is 2*b
|
|
|
|
lambda x, y: jnp.einsum(jnp.concatenate([x, x], axis=1), [0, 1],
|
|
|
|
jnp.concatenate([y, y], axis=0), [1, 2],
|
|
|
|
[0, 2]),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "multiple_contractions",
|
|
|
|
lambda x, y, z: jnp.einsum("ab,bc,cd->ad", x, y, z),
|
2022-12-20 15:29:51 +02:00
|
|
|
arg_descriptors=[RandArg((3, 2), _f32), RandArg((2, 3), _f32), RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None, None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("einsum", "incompatible_contractions_error",
|
|
|
|
lambda x, y: jnp.einsum("ab,cb->ac", x, y),
|
|
|
|
arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)],
|
2022-12-20 15:29:51 +02:00
|
|
|
polymorphic_shapes=["(2, b0)", "(2, b1)"],
|
|
|
|
input_signature=[tf.TensorSpec((2, None)), tf.TensorSpec((2, None))],
|
2023-02-04 08:30:44 +02:00
|
|
|
expect_error=(AssertionError,
|
|
|
|
"Incompatible reduction dimensions")),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("eye", "N=poly_M=None",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: jnp.eye(x.shape[0]) + x,
|
|
|
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("eye", "N=poly_M=poly",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: jnp.eye(x.shape[0], M=x.shape[0] + 2) + x,
|
|
|
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-04-04 06:45:17 -07:00
|
|
|
[
|
|
|
|
PolyHarness("fft", f"{fft_type=}_{nr_fft_lengths=}",
|
|
|
|
lambda x, fft_type, nr_fft_lengths: lax.fft_p.bind(
|
|
|
|
x, fft_type=fft_type,
|
|
|
|
fft_lengths=tuple(
|
2024-10-10 08:06:50 -07:00
|
|
|
x.shape[-nr_fft_lengths:] if fft_type != lax.FftType.IRFFT else
|
2023-04-04 06:45:17 -07:00
|
|
|
[(x.shape[-1] - 1) * 2])),
|
|
|
|
arg_descriptors=[
|
|
|
|
RandArg((3, 4, 5, 6),
|
2024-10-10 08:06:50 -07:00
|
|
|
np.float32 if fft_type == lax.FftType.RFFT else np.complex64),
|
2023-04-04 06:45:17 -07:00
|
|
|
StaticArg(fft_type),
|
|
|
|
StaticArg(nr_fft_lengths)],
|
|
|
|
# All axes but the last one are dynamic. This means that the test
|
|
|
|
# with nr_fft_lengths==1 will not have dynamic fft_lengths.
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1, b2, ..."],
|
2023-04-04 06:45:17 -07:00
|
|
|
tol=1e-4)
|
|
|
|
|
2024-10-10 08:06:50 -07:00
|
|
|
for fft_type in (lax.FftType.FFT, lax.FftType.IFFT,
|
|
|
|
lax.FftType.RFFT, lax.FftType.IRFFT)
|
2023-04-04 06:45:17 -07:00
|
|
|
for nr_fft_lengths in (1, 2)
|
|
|
|
],
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("full", "",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: lax.full((x.shape[0], 2), 3.) + x,
|
|
|
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-11-12 18:17:07 +01:00
|
|
|
PolyHarness("gather", "1d",
|
|
|
|
lambda operand, start_indices, x: lax.gather(
|
|
|
|
operand,
|
|
|
|
start_indices,
|
|
|
|
dimension_numbers=lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,),
|
|
|
|
collapsed_slice_dims=(),
|
|
|
|
start_index_map=(0,)),
|
|
|
|
slice_sizes=x.shape,
|
|
|
|
mode="promise_in_bounds"),
|
|
|
|
arg_descriptors=[
|
|
|
|
RandArg((10,), np.float32),
|
|
|
|
np.random.randint(0, high=10, size=(3, 1),
|
|
|
|
dtype=np.int32),
|
|
|
|
np.zeros((10,), dtype=jnp.int32),
|
|
|
|
],
|
|
|
|
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]),
|
2021-07-28 19:30:44 +03:00
|
|
|
# operand is non-poly, index is poly
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=static_idx=poly",
|
|
|
|
lambda a, i: a[i],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=[None, "b0, ..."]),
|
2021-07-28 19:30:44 +03:00
|
|
|
# operand is poly, index is integer
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=const",
|
|
|
|
lambda a: a[1],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2021-07-28 19:30:44 +03:00
|
|
|
# operand is poly, index is dim poly
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=dim",
|
|
|
|
lambda a: a[jnp.array(a.shape[0] - 2)],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2021-07-28 19:30:44 +03:00
|
|
|
# Both the operand and the index are poly
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=poly",
|
|
|
|
lambda a, i: a[i],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2021-10-07 11:30:07 +02:00
|
|
|
# op is poly and index is an entire slice
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=slice-all",
|
|
|
|
lambda a: a[:],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2021-10-07 11:30:07 +02:00
|
|
|
# op is poly and index is a partial slice
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=slice-ct-1",
|
|
|
|
lambda a: a[:2],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b + 2, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=slice-ct-2",
|
|
|
|
lambda a: a[:, :2],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=slice-None-1",
|
|
|
|
lambda a: a[:a.shape[0]],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("getitem", "op=poly_idx=slice-poly",
|
|
|
|
lambda a: a[:a.shape[0] - 1],
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("image_resize", "linear_0",
|
|
|
|
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
|
|
|
|
method="linear"),
|
|
|
|
arg_descriptors=[RandArg((3, 16, 32, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b1, b2, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("image_resize", "linear_to_fixed_dim",
|
|
|
|
lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]),
|
|
|
|
method="linear"),
|
|
|
|
arg_descriptors=[RandArg((3, 16, 32, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b1, b2, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("image_resize", "nearest_0",
|
|
|
|
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
|
|
|
|
method="nearest"),
|
|
|
|
arg_descriptors=[RandArg((3, 5, 7, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["_, b1, b2, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("index_in_dim", "0",
|
|
|
|
lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("index_in_dim", "idx=neg",
|
|
|
|
lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("index_in_dim", "idx=last",
|
|
|
|
lambda x: lax.index_in_dim(x, x.shape[0] - 1, axis=0, keepdims=False),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-03-02 14:31:32 +01:00
|
|
|
PolyHarness("jnp.insert", "insert=constant",
|
|
|
|
lambda x: jnp.insert(x, jnp.arange(3, dtype=_i32), np.array([3, 4, 5], dtype=_i32)),
|
|
|
|
arg_descriptors=[RandArg((12,), _i32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."],
|
2023-03-02 14:31:32 +01:00
|
|
|
expect_error=expect_error_associative_scan),
|
|
|
|
PolyHarness("jnp.insert", "insert=poly",
|
|
|
|
lambda x: jnp.insert(x, jnp.arange(x.shape[0], dtype=_i32), x, axis=0),
|
|
|
|
arg_descriptors=[RandArg((12, 3), _i32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1, ..."],
|
2023-03-02 14:31:32 +01:00
|
|
|
expect_error=expect_error_associative_scan),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("iota", "",
|
|
|
|
lambda x: x + lax.iota(_f32, x.shape[0]),
|
|
|
|
arg_descriptors=[RandArg((3,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("matmul", "0",
|
|
|
|
jnp.matmul,
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."],
|
2022-12-17 05:56:48 +02:00
|
|
|
tol=1e-5),
|
|
|
|
PolyHarness("matmul", "1",
|
|
|
|
jnp.matmul,
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None],
|
2022-12-17 05:56:48 +02:00
|
|
|
tol=1e-5),
|
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("mean",
|
|
|
|
f"{axis=}_{keepdims=}_where=None",
|
|
|
|
lambda x, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=None),
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."])
|
|
|
|
for keepdims in [False, True]
|
|
|
|
for axis in [None, (0,), (0, 1), (1,)]
|
|
|
|
],
|
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("mean",
|
|
|
|
f"{axis=}_{keepdims=}_where=Some",
|
|
|
|
lambda x, where, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=where),
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_),
|
|
|
|
StaticArg(axis), StaticArg(keepdims)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."])
|
|
|
|
for keepdims in [False, True]
|
|
|
|
for axis in [None, (0,), (0, 1), (1,)]
|
|
|
|
],
|
2023-03-02 14:31:32 +01:00
|
|
|
PolyHarness("jnp.nonzero", "size=constant",
|
|
|
|
lambda x: jnp.nonzero(x % 3, size=10, fill_value=100),
|
|
|
|
arg_descriptors=[RandArg((3, 2, 4), _i32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."],
|
2023-03-02 14:31:32 +01:00
|
|
|
expect_error=expect_error_associative_scan),
|
|
|
|
PolyHarness("jnp.nonzero", "size=poly",
|
|
|
|
lambda x: jnp.nonzero(x % 3, size=x.shape[0] * 2, fill_value=100),
|
|
|
|
arg_descriptors=[RandArg((3, 2, 4), _i32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."],
|
2023-03-02 14:31:32 +01:00
|
|
|
expect_error=expect_error_associative_scan),
|
2023-04-12 13:27:17 +03:00
|
|
|
PolyHarness("one_hot", "poly_num_classes",
|
|
|
|
lambda x, y: jax.nn.one_hot(x, y.shape[0]),
|
2024-12-19 07:11:31 -08:00
|
|
|
arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, ..."]),
|
2023-04-12 13:27:17 +03:00
|
|
|
PolyHarness("one_hot", "all_poly",
|
|
|
|
lambda x, y: jax.nn.one_hot(x, y.shape[0]),
|
2024-12-19 07:11:31 -08:00
|
|
|
arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("ones", "",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: jnp.ones(x.shape, dtype=_f32) + x,
|
2022-12-17 05:56:48 +02:00
|
|
|
arg_descriptors=[RandArg((3, 2, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("pad", "",
|
|
|
|
lax.pad,
|
|
|
|
arg_descriptors=[RandArg((3, 2, 5), _f32), np.float32(5.),
|
|
|
|
StaticArg(((0, 0, 0), (0, 0, 0), (1, 1, 1)))],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-29 03:00:03 -08:00
|
|
|
PolyHarness("pad", "poly_padding_config",
|
|
|
|
lambda x: lax.pad(x, _f32(0.),
|
|
|
|
((x.shape[0], x.shape[1], x.shape[0]),
|
|
|
|
(0, 0, 0))),
|
|
|
|
arg_descriptors=[RandArg((3, 2), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-29 03:00:03 -08:00
|
|
|
PolyHarness("jnp.pad", "mode=constant",
|
|
|
|
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
|
|
|
|
mode="constant"),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-29 03:00:03 -08:00
|
|
|
PolyHarness("jnp.pad", "mode=constant_bminus1",
|
|
|
|
# We first slice the unknown dimension to make it of size b - 1,
# which may be 0.
|
|
|
|
lambda x: jnp.pad(lax.dynamic_slice_in_dim(x, 1, x.shape[0] - 1,
|
|
|
|
axis=0),
|
|
|
|
[[x.shape[0], 0], [x.shape[1], 1]],
|
|
|
|
mode="constant"),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-29 03:00:03 -08:00
|
|
|
PolyHarness("jnp.pad", "mode=edge",
|
|
|
|
lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
|
|
|
|
mode="edge"),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-02-14 17:40:50 +01:00
|
|
|
PolyHarness("percentile", "axis=None",
|
|
|
|
lambda x: jnp.percentile(x, 50, axis=None),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-02-14 17:40:50 +01:00
|
|
|
PolyHarness("nanquantile", "axis=None",
|
|
|
|
lambda x: jnp.nanquantile(x, .5, axis=None),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-02-14 17:40:50 +01:00
|
|
|
PolyHarness("percentile", "axis=0",
|
|
|
|
lambda x: jnp.percentile(x, 50, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-02-14 17:40:50 +01:00
|
|
|
PolyHarness("nanquantile", "axis=0",
|
|
|
|
lambda x: jnp.nanquantile(x, .5, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-22 01:39:20 -07:00
|
|
|
[
|
|
|
|
PolyHarness(
|
|
|
|
"qr", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}",
|
|
|
|
lambda x, full_matrices: lax.linalg.qr(x, full_matrices=full_matrices),
|
|
|
|
arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)],
|
|
|
|
polymorphic_shapes=[poly],
|
2023-10-12 13:15:22 +01:00
|
|
|
tol=(None if config.jax2tf_default_native_serialization.value else 1e-5))
|
2023-11-19 08:59:23 -08:00
|
|
|
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
2023-06-22 01:39:20 -07:00
|
|
|
# m and n must be static for now
|
|
|
|
for shape, poly, full_matrices in [
|
|
|
|
((2, 0, 4), "b, ...", False), # m = 0
|
|
|
|
((2, 4, 0), "b, ...", False), # n = 0
|
|
|
|
((2, 3, 4, 4), "b1, b2, ...", False), # m == n
|
|
|
|
((2, 3, 4, 4), "b1, b2, ...", True),
|
|
|
|
((2, 3, 4, 5), "b1, b2, ...", False), # m < n
|
|
|
|
((2, 3, 4, 5), "b1, b2, ...", True),
|
|
|
|
((2, 3, 8, 4), "b1, b2, ...", False), # m > n
|
|
|
|
((2, 3, 8, 4), "b1, b2, ...", True),
|
|
|
|
]
|
|
|
|
],
|
2023-05-04 09:52:21 +02:00
|
|
|
[
|
|
|
|
# The random primitive tests, with threefry (both partitionable and
|
|
|
|
# non-partitionable), and unsafe_rbg.
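# Each flag combination below also supplies the raw key size it needs:
# 2 uint32 words for threefry2x32 and 4 for unsafe_rbg (see the key_size
# values in the parametrization at the end of this block).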
|
|
|
|
[
|
|
|
|
PolyHarness("random_gamma", f"{flags_name}",
|
2024-01-17 12:53:24 -08:00
|
|
|
lambda key, a: jax.vmap(jax.random.gamma)(key, a),
|
2023-05-04 09:52:21 +02:00
|
|
|
arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)],
|
2023-07-19 16:15:03 -07:00
|
|
|
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
# The product of the known dimensions must be even.
|
|
|
|
PolyHarness("random_categorical", f"axis=0_{flags_name}",
|
|
|
|
lambda key, a: jax.random.categorical(key, a, axis=0),
|
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 8), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
PolyHarness("random_categorical", f"axis=1_{flags_name}",
|
|
|
|
lambda key, a: jax.random.categorical(key, a, axis=1),
|
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 8), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, b1, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}",
|
2023-07-21 14:20:39 -04:00
|
|
|
lambda key, a: jax.random.categorical(key, a, axis=1).reshape(-1),
|
2023-05-04 09:52:21 +02:00
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 8), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, b1, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size
|
|
|
|
lambda key, a: jax.random.categorical(key, a, axis=1),
|
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 0), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, b1, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
PolyHarness("random_split", f"{flags_name}",
|
2023-05-25 15:16:46 -07:00
|
|
|
lambda key, a: jax.random.key_data(jax.random.split(key, 2 * a.shape[0])),
|
2023-05-04 09:52:21 +02:00
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32),
|
|
|
|
RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
# Works when the known dimensions are provably even or odd.
|
|
|
|
PolyHarness("random_uniform", f"even_1_{flags_name}",
|
|
|
|
lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
|
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
PolyHarness("random_uniform", f"even_2_{flags_name}",
|
|
|
|
lambda key, a: jax.random.uniform(key, (2 * a.shape[0], a.shape[1]),
|
|
|
|
dtype=_f32),
|
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, b1, ..."],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
|
|
|
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
|
|
|
|
lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
|
|
|
|
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)],
|
2024-12-21 13:24:28 +08:00
|
|
|
polymorphic_shapes=[None, "b0, b1"],
|
2023-05-04 09:52:21 +02:00
|
|
|
override_jax_config_flags=override_jax_config_flags) # type: ignore
|
|
|
|
]
|
|
|
|
for key_size, flags_name, override_jax_config_flags in [
|
|
|
|
(2, "threefry_non_partitionable",
|
|
|
|
dict(jax_default_prng_impl="threefry2x32", jax_threefry_partitionable=False)),
|
|
|
|
(2, "threefry_partitionable",
|
|
|
|
dict(jax_default_prng_impl="threefry2x32", jax_threefry_partitionable=True)),
|
|
|
|
(4, "unsafe_rbg",
|
|
|
|
dict(jax_default_prng_impl="unsafe_rbg"))
|
|
|
|
]
|
|
|
|
],
|
2023-06-17 10:33:29 -07:00
|
|
|
# For reduce_window we have one variant in which a reduction axis has a
# non-static shape, and one in which additionally the window dimension
# is non-static.
|
|
|
|
PolyHarness("reduce_window", "min_window_size=static",
|
|
|
|
# x: f32[b, 8]
|
2022-12-17 05:56:48 +02:00
|
|
|
lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
|
|
|
|
(2, 2), (1, 1), "VALID"),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "min_window_size=dynamic",
|
|
|
|
# x: f32[b, 8]
|
|
|
|
lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
|
|
|
|
(2, x.shape[0]), (1, 1), "VALID"),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "min_plus_max_window_size=static",
|
|
|
|
# x: f32[b, 8]
|
|
|
|
lambda x: (
|
|
|
|
# Test that the two reducers (min and max) do not get confused.
|
|
|
|
lax.reduce_window(x, np.array(1., _f32), lax.min,
|
|
|
|
(2, 2), (1, 1), "VALID") +
|
|
|
|
lax.reduce_window(x, np.array(1., _f32), lax.max,
|
|
|
|
(2, 2), (1, 1), "VALID")),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "min_plus_max_window_size=dynamic",
|
|
|
|
# x: f32[b, 8]
|
|
|
|
lambda x: (
|
|
|
|
# Test that the two reducers (min and max) do not get confused.
|
|
|
|
lax.reduce_window(x, np.array(1., _f32), lax.min,
|
|
|
|
(2, x.shape[0]), (1, 1), "VALID") +
|
|
|
|
lax.reduce_window(x, np.array(1., _f32), lax.max,
|
|
|
|
(2, x.shape[0]), (1, 1), "VALID")),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "add_monoid_base_window_size=static",
|
|
|
|
# x: f32[b, 8]
|
|
|
|
lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add,
|
|
|
|
(2, 2), (1, 1), "VALID"),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "add_monoid_base_window_size=dynamic",
|
|
|
|
# x: f32[b, 8]
|
|
|
|
lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add,
|
|
|
|
(2, x.shape[0]), (1, 1), "VALID"),
|
2022-12-17 05:56:48 +02:00
|
|
|
arg_descriptors=[RandArg((3, 8), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/11804
|
2022-08-16 03:01:17 -07:00
|
|
|
# Use the reshape trick to simulate a polymorphic dimension of 16*b.
|
|
|
|
# (See test "conv_general_dilated.1d_1" above for more details.)
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "add_monoid_strides_window_size=static",
|
|
|
|
# x: f32[1, 16*b, 1]
|
|
|
|
lambda x: lax.reduce_window(
|
|
|
|
jnp.reshape(x, (1, -1, 1)),
|
|
|
|
np.array(0., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
|
|
|
|
arg_descriptors=[RandArg((1, 128, 16), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["_, b1, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "add_generic_window_size=static",
|
|
|
|
# x: f32[1, 16*b, 1]
|
|
|
|
# Use an initial value of 1. to trigger the generic reduction path
|
2022-12-17 05:56:48 +02:00
|
|
|
lambda x: lax.reduce_window(
|
|
|
|
jnp.reshape(x, (1, -1, 1)),
|
2023-06-17 10:33:29 -07:00
|
|
|
np.array(1., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
|
2022-12-17 05:56:48 +02:00
|
|
|
arg_descriptors=[RandArg((1, 128, 16), _f32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["_, b1, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "variadic_generic_window_size=static",
|
|
|
|
# x: f32[b, 8] y: f32[b, 8]
|
|
|
|
lambda x, y: lax.reduce_window(
|
|
|
|
(x, y), (np.array(1., _f32), np.array(2, _i32)),
|
|
|
|
lambda xy0, xy1: (lax.add(xy0[0], xy1[0]),
|
|
|
|
lax.sub(xy0[1], xy1[1])),
|
|
|
|
(2, 2), (1, 1), "VALID"),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2023-06-17 10:33:29 -07:00
|
|
|
PolyHarness("reduce_window", "variadic_generic_window_size=dynamic",
|
|
|
|
# x: f32[b, 8] y: f32[b, 8]
|
|
|
|
lambda x, y: lax.reduce_window(
|
|
|
|
(x, y), (np.array(1., _f32), np.array(2, _i32)),
|
|
|
|
lambda xy0, xy1: (lax.add(xy0[0], xy1[0]),
|
|
|
|
lax.sub(xy0[1], xy1[1])),
|
|
|
|
(2, x.shape[0]), (1, 1), "VALID"),
|
|
|
|
arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2021-07-29 16:07:13 +03:00
|
|
|
# TODO(necula): not yet supported, but also unlikely to come up.
|
2022-12-17 05:56:48 +02:00
|
|
|
# PolyHarness("random_uniform", "odd",
|
2021-07-29 16:07:13 +03:00
|
|
|
# lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]),
|
|
|
|
# dtype=_f32),
|
|
|
|
# [RandArg((2,), np.uint32), RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
# polymorphic_shapes=[None, "b0, ..."]),
|
2021-07-28 19:30:44 +03:00
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reduce", reduce_op.__name__,
|
|
|
|
lambda x: reduce_op(x, axis=-1, keepdims=True), # type: ignore
|
|
|
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."])
|
2021-07-28 19:30:44 +03:00
|
|
|
for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]
|
|
|
|
],
|
2022-08-30 23:04:31 -07:00
|
|
|
# Repeat f32[b, 2] * 3
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("repeat", "repeats=int_axis=0",
|
|
|
|
lambda x: jnp.repeat(x, repeats=3, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 2), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-08-30 23:04:31 -07:00
|
|
|
# Repeat f32[b, 2] * b
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("repeat", "repeats=poly_axis=0",
|
|
|
|
lambda x: jnp.repeat(x, repeats=x.shape[0], axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 2), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-08-30 23:04:31 -07:00
|
|
|
# Repeat f32[b, 2] * b
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("repeat", "repeats=poly_axis=None",
|
|
|
|
lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None),
|
|
|
|
arg_descriptors=[RandArg((3, 2), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-08-30 23:04:31 -07:00
|
|
|
# Repeat f32 * b
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("repeat", "repeats=poly_axis=None_scalar",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x, y: jnp.repeat(x, repeats=y.shape[0], axis=None) + y,
|
|
|
|
arg_descriptors=[RandArg((), _f32), RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=[None, "b0, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("repeat", "repeats=poly_axis=None_total_repeat_length1",
|
|
|
|
lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None, total_repeat_length=8),
|
|
|
|
arg_descriptors=[RandArg((3, 2), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."],
|
2023-01-18 12:27:02 +02:00
|
|
|
expect_error=(ValueError, "jnp.repeat with a non-constant `repeats` is supported only .*")),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reshape", "0",
|
|
|
|
lambda x: x.reshape([x.shape[0], -1]),
|
|
|
|
arg_descriptors=[RandArg((3, 2, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reshape", "1",
|
|
|
|
lambda x: x.reshape([x.shape[0], -1]),
|
|
|
|
arg_descriptors=[RandArg((3, 2, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reshape", "2",
|
|
|
|
lambda x: x.reshape([x.shape[0], -1, x.shape[3], x.shape[2]]),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, _, b2, b3, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reshape", "3",
|
|
|
|
lambda x: jnp.reshape(x, [2, -1]),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, _, b2, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reshape", "_issue_9975",
|
|
|
|
# The newshape is a scalar
|
|
|
|
lambda x: jnp.reshape(x, x.shape[0] * x.shape[1]),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("reshape", "error",
|
|
|
|
lambda x: x.reshape([x.shape[0], -1, 3]),
|
|
|
|
arg_descriptors=[RandArg((3, 2, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."],
|
|
|
|
input_signature=[tf.TensorSpec([None, 2, 4], _f32)],
|
2022-12-17 05:56:48 +02:00
|
|
|
skip_jax_run=True,
|
|
|
|
expect_error=(core.InconclusiveDimensionOperation,
|
|
|
|
re.escape(
|
2023-06-16 12:50:50 +03:00
|
|
|
"Cannot divide evenly the sizes of shapes (b, 2, 4) and (b, -1, 3)"))),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("roll", "axis=0",
|
|
|
|
lambda x: jnp.roll(x, 2, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("roll", "axis=None",
|
|
|
|
lambda x: jnp.roll(x, 2),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("scatter_add", "",
|
|
|
|
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True),
|
2023-11-06 18:38:12 +01:00
|
|
|
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
2022-12-17 05:56:48 +02:00
|
|
|
np.array([[1], [2]], np.int32), # indices: [2, 1]
|
2023-11-06 18:38:12 +01:00
|
|
|
RandArg((7, 2), _f32), # updates: [b, 2]
|
2022-12-17 05:56:48 +02:00
|
|
|
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("scatter_add", "clip0",
|
|
|
|
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP),
|
2023-11-06 18:38:12 +01:00
|
|
|
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
2022-12-17 05:56:48 +02:00
|
|
|
np.array([[1], [2]], np.int32), # indices: [2, 1]
|
|
|
|
RandArg((7, 2), _f32), # updates: [b, 2]
|
|
|
|
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("scatter_add", "clip1",
|
|
|
|
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP),
|
2023-11-06 18:38:12 +01:00
|
|
|
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
|
|
|
# indices: [b, 2]
|
|
|
|
np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32),
|
2022-12-17 05:56:48 +02:00
|
|
|
RandArg((7, 1), _f32), # updates: [b, 1]
|
|
|
|
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
|
2023-11-06 18:38:12 +01:00
|
|
|
PolyHarness("scatter_grad", "",
|
|
|
|
lambda *args: jax.grad(
|
|
|
|
lambda *args:
|
2024-05-17 09:46:36 +01:00
|
|
|
jnp.sum(lax.scatter(
|
2023-11-06 18:38:12 +01:00
|
|
|
*args,
|
|
|
|
indices_are_sorted=False,
|
|
|
|
unique_indices=False,
|
|
|
|
))
|
|
|
|
)(*args),
|
|
|
|
arg_descriptors=[RandArg((7, 4), _f32), # : [b, 4]
|
|
|
|
np.array([[1], [2]], np.int32), # indices: [2, 1]
|
|
|
|
RandArg((7, 2), _f32), # updates: [b, 2]
|
|
|
|
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,))),
|
|
|
|
],
|
|
|
|
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
|
|
|
PolyHarness("scatter_grad", "poly_indices",
|
|
|
|
lambda *args: jax.grad(
|
|
|
|
lambda *args:
|
2024-05-17 09:46:36 +01:00
|
|
|
jnp.sum(lax.scatter(
|
2023-11-06 18:38:12 +01:00
|
|
|
*args,
|
|
|
|
indices_are_sorted=False,
|
|
|
|
unique_indices=False))
|
|
|
|
)(*args),
|
|
|
|
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
|
|
|
# indices: [b, 2]
|
|
|
|
np.array(
|
|
|
|
[[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0],
|
|
|
|
[3, 0], [0, 5]], np.int32),
|
|
|
|
RandArg((7, 1), _f32), # updates: [b, 1]
|
|
|
|
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1))),
|
|
|
|
],
|
|
|
|
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
|
2023-06-26 13:58:23 -07:00
|
|
|
[
|
|
|
|
PolyHarness("schur",
|
|
|
|
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}",
|
|
|
|
lambda a, compute_schur_vectors: lax.linalg.schur(
|
|
|
|
a, compute_schur_vectors=compute_schur_vectors),
|
|
|
|
arg_descriptors=[RandArg(shape, dtype),
|
|
|
|
StaticArg(compute_schur_vectors)],
|
|
|
|
polymorphic_shapes=[poly],
|
|
|
|
# In non-native serialization we cannot check for an exact match;
# we would have to check the invariants of the result instead.
|
2023-10-12 13:15:22 +01:00
|
|
|
check_result=config.jax2tf_default_native_serialization.value)
|
2023-11-19 08:59:23 -08:00
|
|
|
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
2023-06-26 13:58:23 -07:00
|
|
|
for compute_schur_vectors in [True, False]
|
|
|
|
for (shape, poly) in [
|
|
|
|
((3, 3), "w, w"),
|
|
|
|
((3, 4, 4), "b, w, w"),
|
|
|
|
]
|
|
|
|
],
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("select", "0",
|
|
|
|
# x.shape = (b, 3)
|
|
|
|
lambda x: lax.select(x > 5., x, x),
|
|
|
|
arg_descriptors=[RandArg((7, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("select", "1",
|
|
|
|
# x.shape = (b, 3); y.shape = (3,)
|
|
|
|
jax.vmap(lambda x, y: lax.select(x > 5., x, y), in_axes=[0, None]),
|
|
|
|
arg_descriptors=[RandArg((7, 3), _f32), RandArg((3,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("slice", "entire_axis",
|
|
|
|
lambda x: lax.slice(x, start_indices=(0, 1), limit_indices=(x.shape[0], 3)),
|
|
|
|
arg_descriptors=[RandArg((7, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("slice_in_dim", "entire_axis",
|
|
|
|
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("slice_in_dim", "start=neg",
|
|
|
|
lambda x: lax.slice_in_dim(x, -1, x.shape[0], stride=1, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("slice_in_dim", "limit=neg",
|
|
|
|
lambda x: lax.slice_in_dim(x, 0, -1, stride=1, axis=0),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-01-18 12:27:02 +02:00
|
|
|
PolyHarness("slice_in_dim", "stride=2_even",
|
|
|
|
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0),
|
|
|
|
arg_descriptors=[RandArg((12, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-01-18 12:27:02 +02:00
|
|
|
PolyHarness("slice_in_dim", "stride=2_odd",
|
|
|
|
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0),
|
|
|
|
arg_descriptors=[RandArg((13, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-01-18 12:27:02 +02:00
|
|
|
# Not yet supported: lax.slice_in_dim applies int() to the stride.
|
|
|
|
# PolyHarness("slice_in_dim", "stride=sym",
|
|
|
|
# lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=x.shape[0] // 4, axis=0),
|
|
|
|
# arg_descriptors=[RandArg((13, 4), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
# polymorphic_shapes=["b, ..."]),
|
2022-12-20 15:42:09 +02:00
|
|
|
PolyHarness("squeeze", "axis=empty",
|
2022-12-17 05:56:48 +02:00
|
|
|
jnp.squeeze,
|
|
|
|
arg_descriptors=[RandArg((5,), _f32), StaticArg(())],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-20 15:42:09 +02:00
|
|
|
PolyHarness("squeeze", "axis=None",
|
|
|
|
jnp.squeeze,
|
|
|
|
arg_descriptors=[RandArg((5,), _f32), StaticArg(None)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."],
|
2022-12-20 15:42:09 +02:00
|
|
|
expect_error=(ValueError, "jnp.squeeze with axis=None is not supported with shape polymorphism")),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("squeeze", "axis=1",
|
|
|
|
jnp.squeeze,
|
|
|
|
arg_descriptors=[RandArg((4, 1), _f32), StaticArg((1,))],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("squeeze", "axis=1_2",
|
|
|
|
jnp.squeeze,
|
|
|
|
arg_descriptors=[RandArg((4, 1, 1), _f32), StaticArg((1, 2))],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("squeeze", "error",
|
|
|
|
jnp.squeeze,
|
|
|
|
arg_descriptors=[RandArg((3, 33), _f32), StaticArg(-1)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b0, b1"],
|
|
|
|
input_signature=[tf.TensorSpec([None, None], _f32)],
|
2022-12-17 05:56:48 +02:00
|
|
|
skip_jax_run=True,
|
|
|
|
expect_error=(ValueError,
|
|
|
|
re.escape(
|
|
|
|
"cannot select an axis to squeeze out which has size not equal to one, got shape=(b0, b1) and dimensions=(1,)"))
|
|
|
|
),
|
|
|
|
PolyHarness("take", "",
|
|
|
|
lambda a, i: jnp.take(a, i, axis=1),
|
|
|
|
arg_descriptors=[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
|
2024-07-16 02:04:59 -07:00
|
|
|
polymorphic_shapes=["b, ...", None]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("take_along_axis", "0",
|
|
|
|
lambda x, y: jnp.take_along_axis(x, y, axis=0),
|
|
|
|
arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("take_along_axis", "1",
|
|
|
|
lambda x, y: jnp.take_along_axis(x, y, axis=1),
|
|
|
|
arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("tile", "0",
|
|
|
|
lambda x: jnp.tile(x, (1, 2)),
|
|
|
|
arg_descriptors=[RandArg((4, 3), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("tile", "1",
|
|
|
|
# The repetition counts are dimension polynomials
|
|
|
|
lambda x: jnp.tile(x, (1, x.shape[0])),
|
|
|
|
arg_descriptors=[RandArg((4, 2), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2023-04-12 14:08:12 +03:00
|
|
|
PolyHarness("lax_top_k", "",
|
|
|
|
lambda x: jax.lax.top_k(x, x.shape[-1] - 1),
|
|
|
|
arg_descriptors=[RandArg((16,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("tri", "N=poly_M=None",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: jnp.tri(x.shape[0]) + x,
|
|
|
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("tri", "N=poly_M=poly",
|
2023-02-13 02:47:35 -08:00
|
|
|
lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2) + x,
|
|
|
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."]),
|
2024-05-29 14:23:03 +03:00
|
|
|
PolyHarness("tril", "",
|
|
|
|
lambda x: jnp.tril(jnp.ones((x.shape[0], x.shape[0] + x.shape[1]),
|
|
|
|
dtype=_f32),
|
|
|
|
k=x.shape[1]),
|
|
|
|
arg_descriptors=[RandArg((3, 4), _f32)],
|
|
|
|
polymorphic_shapes=["m, n"]),
|
2023-06-26 12:12:15 -07:00
|
|
|
[
|
|
|
|
PolyHarness("triangular_solve",
|
|
|
|
f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}",
|
|
|
|
lambda a, b, left_side: lax.linalg.triangular_solve(
|
|
|
|
jnp.tril(a) + 5 * jnp.eye(a.shape[-1], dtype=a.dtype),
|
|
|
|
b, left_side=left_side,
|
|
|
|
lower=True, transpose_a=False, conjugate_a=False,
|
|
|
|
unit_diagonal=False),
|
|
|
|
arg_descriptors=[RandArg(a_shape, dtype),
|
|
|
|
RandArg(b_shape, dtype),
|
|
|
|
StaticArg(left_side)],
|
|
|
|
polymorphic_shapes=[a_poly, b_poly],
|
|
|
|
# In non-native serialization we cannot check for an exact match;
# we would have to check the invariants of the result instead.
|
2023-10-12 13:15:22 +01:00
|
|
|
check_result=config.jax2tf_default_native_serialization.value)
|
2023-11-19 08:59:23 -08:00
|
|
|
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
2023-06-26 12:12:15 -07:00
|
|
|
for (left_side, a_shape, b_shape, a_poly, b_poly) in [
|
|
|
|
(True, (3, 4, 4), (3, 4, 5), "b, ...", "b, ..."),
|
|
|
|
(True, (3, 4, 4), (3, 4, 5), "b, k, k", "b, k, m"),
|
|
|
|
(False, (3, 4, 4), (3, 5, 4), "b, ...", "b, ..."),
|
|
|
|
(False, (3, 4, 4), (3, 5, 4), "b, k, k", "b, m, k"),
|
|
|
|
# We use custom calls on CPU if not batched
|
|
|
|
(True, (4, 4), (4, 5), "k, k", "k, m"),
|
|
|
|
(False, (4, 4), (5, 4), "k, k", "m, k"),
|
|
|
|
]
|
|
|
|
],
|
2021-07-30 10:52:34 +03:00
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("var",
|
|
|
|
f"{axis=}_{keepdims=}_where=None",
|
|
|
|
lambda x, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=None),
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ..."])
|
2021-07-30 10:52:34 +03:00
|
|
|
for keepdims in [False, True]
|
|
|
|
for axis in [None, (0,), (0, 1), (1,)]
|
|
|
|
],
|
|
|
|
[
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("var",
|
|
|
|
f"{axis=}_{keepdims=}_where=Some",
|
|
|
|
lambda x, where, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=where),
|
|
|
|
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", "b, ..."])
|
2021-07-30 10:52:34 +03:00
|
|
|
for keepdims in [False, True]
|
|
|
|
for axis in [None, (0,), (0, 1), (1,)]
|
|
|
|
],
|
2022-12-17 05:56:48 +02:00
|
|
|
PolyHarness("where", "",
|
|
|
|
jnp.where,
|
|
|
|
arg_descriptors=[RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)],
|
2023-06-16 12:50:50 +03:00
|
|
|
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
2021-04-09 14:02:44 +03:00
|
|
|
]
|
|
|
|
|
2022-07-05 14:01:19 +02:00
|
|
|
def _get_jax2tf_limitations(
|
2023-11-09 13:57:30 -08:00
|
|
|
device, h: test_harnesses.Harness) -> Sequence[Jax2TfLimitation]:
|
2022-07-05 14:01:19 +02:00
|
|
|
# Keep only the jax2tf limitations that are applicable to this harness.
|
|
|
|
def applicable_jax2tf_limitation(l: Jax2TfLimitation) -> bool:
|
|
|
|
# The CheckShapePolymorphism uses tf.function, so we care about "graph" mode.
|
|
|
|
return l.filter(device=device, dtype=h.dtype, mode="graph")
|
|
|
|
|
|
|
|
limitations = Jax2TfLimitation.limitations_for_harness(h)
|
|
|
|
return tuple(filter(applicable_jax2tf_limitation, limitations))
|
|
|
|
|
2021-04-09 14:02:44 +03:00
|
|
|
### We add to the test harnesses some that are obtained from the
|
|
|
|
### primitive harnesses by applying vmap to the function and then asserting
|
|
|
|
### that we can convert the resulting function shape-polymorphically.
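### For example (illustrative shapes only): a primitive harness that applies its
### function to f32[7, 8] becomes a harness that applies jax.vmap of the same
### function to f32[3, 7, 8], declared with polymorphic_shapes=["b, ..."].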
|
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]:
|
2021-04-05 12:27:53 +03:00
|
|
|
"""For each harness group, pick a single dtype.
|
2020-10-11 19:48:36 +03:00
|
|
|
|
2022-12-17 05:56:48 +02:00
|
|
|
See PolyHarness for documentation.
|
|
|
|
|
2021-04-05 12:27:53 +03:00
|
|
|
Ignore harnesses that fail in graph mode in jax2tf.
|
|
|
|
"""
|
2023-11-09 13:57:30 -08:00
|
|
|
all_h = test_harnesses.all_harnesses
|
2022-07-05 14:01:19 +02:00
|
|
|
res = []
|
2021-03-16 11:38:57 +01:00
|
|
|
|
2021-04-05 10:13:02 +03:00
|
|
|
# Index by group
|
2023-06-23 15:11:37 -07:00
|
|
|
harness_groups: dict[
|
2023-11-09 13:57:30 -08:00
|
|
|
str, Sequence[test_harnesses.Harness]] = collections.defaultdict(list)
|
2021-04-05 10:13:02 +03:00
|
|
|
device = jtu.device_under_test()
|
2020-10-11 19:48:36 +03:00
|
|
|
|
2021-04-01 15:37:01 +03:00
|
|
|
for h in all_h:
|
2022-02-13 23:47:52 +09:00
|
|
|
# Drop the JAX limitations
|
2021-04-05 10:13:02 +03:00
|
|
|
if not h.filter(device_under_test=device, include_jax_unimpl=False):
|
|
|
|
continue
|
2021-04-09 14:02:44 +03:00
|
|
|
# And the jax2tf limitations that are known to result in TF error.
|
|
|
|
if any(l.expect_tf_error for l in _get_jax2tf_limitations(device, h)):
|
2021-04-01 15:37:01 +03:00
|
|
|
continue
|
|
|
|
harness_groups[h.group_name].append(h)
|
  selected_harnesses = []
  for _, hlist in harness_groups.items():
    # Pick the dtype with the most harnesses in this group. Some harness
    # groups only test different use cases at a few dtypes.
    c = collections.Counter([h.dtype for h in hlist])
    (_, max_count), = c.most_common(1)
    # Pick the first alphabetically among those with max_count, to ensure
    # that we generate deterministic tests.
    dtypes_with_max_count = (dtype for dtype, count in c.items()
                             if count == max_count)
    dtype, *_ = sorted(dtypes_with_max_count, key=str)
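    # Illustrative example (hypothetical counts): with 5 float32 harnesses,
    # 5 int32 harnesses, and 2 bool harnesses in a group, max_count is 5 and
    # we keep whichever of float32/int32 sorts first by str(), so the choice
    # does not depend on dict/Counter iteration order.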
    selected_harnesses.extend([h for h in hlist if h.dtype == dtype])

  batch_size = 3
  for h in selected_harnesses:
    if h.group_name in [
        "tridiagonal_solve",  # batching not implemented in JAX
    ]:
      continue

    def make_batched_arg_descriptor(
        ad: test_harnesses.ArgDescriptor) -> test_harnesses.ArgDescriptor | None:
      if isinstance(ad, RandArg):
        return RandArg((batch_size,) + ad.shape, ad.dtype)
      elif isinstance(ad, CustomArg):
        def wrap_custom(rng):
          arg = ad.make(rng)
          return np.stack([arg] * batch_size)

        return CustomArg(wrap_custom)
      else:
        assert isinstance(ad, np.ndarray), ad
        return np.stack([ad] * batch_size)

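    # For example (illustrative shapes only): with batch_size == 3, a
    # RandArg((7, 4), _f32) descriptor becomes RandArg((3, 7, 4), _f32), and a
    # concrete np.ndarray of shape (7, 4) becomes a (3, 7, 4) array with the
    # original stacked 3 times. StaticArg arguments are left out below.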
    new_args = [make_batched_arg_descriptor(ad)
                for ad in h.arg_descriptors
                if not isinstance(ad, StaticArg)]

    # This test does not make sense for nullary functions
    if not new_args:
      continue

    limitations = [
        l for l in _get_jax2tf_limitations(device, h)
        if not l.skip_comparison and (l.custom_assert or l.tol is not None)]

    vmap_harness = PolyHarness("vmap_" + h.group_name, h.name,
                               jax.vmap(h.dyn_fun, in_axes=0, out_axes=0),
                               arg_descriptors=new_args,
                               polymorphic_shapes=["b, ..."] * len(new_args),
                               limitations=limitations)
    vmap_harness.original_harness = h
    res.append(vmap_harness)
  return res


_POLY_SHAPE_TEST_HARNESSES.append(_make_vmap_primitive_harnesses())


def _flatten_harnesses(harnesses):
  res = []
  for h in harnesses:
    if isinstance(h, Sequence):
      res.extend(_flatten_harnesses(h))
    else:
      res.append(h)
  return res


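# Note: _POLY_SHAPE_TEST_HARNESSES mixes individual PolyHarness objects with
# nested lists of them (e.g. the list appended by
# _make_vmap_primitive_harnesses above), so the parameterization below
# flattens it first. Illustrative sketch (h1..h4 are hypothetical harnesses):
#   _flatten_harnesses([h1, [h2, h3], [h4]]) == [h1, h2, h3, h4]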
class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
  """Tests for primitives that take shape values as parameters."""

  # This test runs for all harnesses in _POLY_SHAPE_TEST_HARNESSES.
  # For each primitive "xxx" the test will be called "test_harness_xxx_...".
  # If you want to run this test for only one harness whose name (after
  # "test_harness") includes "foo", add the parameter `one_containing="foo"`
  # to the parameterized decorator below.
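  # For example (the substring is hypothetical): passing
  # one_containing="vmap_sin" to the decorator below would run only the
  # harnesses whose name contains "vmap_sin".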
  @test_harnesses.parameterized(
      _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
      #one_containing="",
  )
  def test_harness(self, harness: PolyHarness):
    if harness.expect_error == expect_error_associative_scan and (
        not config.jax2tf_default_native_serialization.value
        or jtu.test_device_matches(["tpu"])
    ):
      harness.expect_error = (None, None)

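    # Note: the associative-scan error is only expected under native
    # serialization on CPU and GPU; in graph mode, or on TPU, these harnesses
    # are expected to pass, so we clear the expected error above.
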
    # Exclude some harnesses that are known to fail for native serialization.
    # FOR NATIVE SERIALIZATION
    # Set of harness.group_name:platform pairs that are implemented with
    # custom calls.
    custom_call_harnesses = {
        "householder_product:gpu",
        "vmap_geqrf:gpu",  # used for linalg.qr
        "vmap_lu:gpu",
        # custom_linear_solve works as long as lu works.
        "vmap_custom_linear_solve:gpu",
        "vmap_qr:gpu", "qr:gpu",
        "vmap_svd:gpu",
    }
    if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
      raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
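    # E.g., "vmap_qr:gpu" above means that the vmapped qr harness is skipped
    # only when this test runs on GPU; the same group still runs on CPU and
    # TPU.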

    if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
      raise unittest.SkipTest("schur decomposition is only implemented on CPU.")

    if "fft_fft_type" in harness.fullname:
      if "nr_fft_lengths_2" in harness.fullname:
        raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")

    if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]):
      # For eigh on GPU with shape polymorphism under native serialization,
      # we use a different lowering for small matrices. See README.md.
      shape = harness.original_harness.params["shape"]
      if 0 < shape[-1] <= 32:
        harness.check_result = False

    if harness.group_name == "vmap_eigh":
      raise unittest.SkipTest(
          "Should not compare eigendecompositions for equality directly "
          "because eigenvalues are sorted.")

    if harness.group_name == "vmap_tan":
      # Tan (b/274462307) requires support for the custom call stablehlo.tan.
      raise unittest.SkipTest(
          "native lowering with shape polymorphism requires additional StableHLO feature support")

    if (jtu.test_device_matches(["cpu", "gpu"]) and
        harness.fullname in [
            "cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
            "cummin_reduce_axis_poly", "cummax_reduce_axis_poly",
            "cumlogsumexp_reduce_axis_poly",
            "jnp_insert_insert_constant", "jnp_insert_insert_poly",
            "jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
      # Need associative scan reductions on CPU and GPU. On TPU we use the
      # reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
      # a recursive associative scan that we cannot express with shape
      # polymorphism.
      raise unittest.SkipTest(
          "native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")

    # FOR BOTH NATIVE AND GRAPH SERIALIZATION
    if harness.group_name == "vmap_conv_general_dilated":
      # https://github.com/openxla/stablehlo/issues/1268
      raise unittest.SkipTest("Need more dynamism for DynamicConvOp")

    if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
      raise unittest.SkipTest("JAX implements eig only on CPU.")

    with jtu.thread_local_config_context(**harness.override_jax_config_flags):
      harness.run_test(self)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())