1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

[shape_poly] Improve and rename export.args_specs.

We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
This commit is contained in:
George Necula 2024-01-10 09:44:31 +02:00
parent 909f5e8114
commit 3b7917a56e
9 changed files with 105 additions and 68 deletions

@ -33,10 +33,13 @@ Remember to align the itemized text with the first line of an item within a list
was deprecated and `core.max_dim` and `core.min_dim` were introduced
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
* the `shape_poly.is_poly_dim` is deprecated in favor if `export.is_symbolic_dim`
* the `shape_poly.is_poly_dim` is deprecated in favor of `export.is_symbolic_dim`
({jax-issue}`#19282`).
* the `export.args_specs` is deprecated in favor of `export.symbolic_args_specs
({jax-issue}`#19283`).
* the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated, use
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will

@ -289,7 +289,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
(d) the number of devices for which the module was serialized.
"""
# Use the native exporter, to make sure we get the proper serialization.
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes)
exported = export.export(
jax.jit(func),
lowering_platforms=(self.default_jax_backend(),),
@ -306,7 +306,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
def run_serialized(self, data: CompatTestData,
polymorphic_shapes: Sequence[str] | None = None):
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes)
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
return core.ShapedArray(a.shape, a.dtype)
in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs)

@ -23,11 +23,12 @@ from jax.experimental.export._export import (
DisabledSafetyCheck,
default_lowering_platform,
args_specs, # TODO: move to shape_poly
args_specs, # TODO: deprecate
)
from jax.experimental.export.shape_poly import (
is_symbolic_dim,
symbolic_shape,
symbolic_args_specs,
)
from jax.experimental.export.serialization import (
serialize,

@ -345,7 +345,9 @@ def symbolic_shape(
return shape_poly.symbolic_shape(shape_spec, like=like)
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
"""Returns the shape and dtype of a jax.Array."""
"""Returns the shape and dtype of a jax.Array or a j"""
if isinstance(a, jax.ShapeDtypeStruct):
return a.shape, a.dtype
aval = core.raise_to_shaped(core.get_aval(a))
return aval.shape, aval.dtype
@ -354,55 +356,15 @@ def args_specs(
polymorphic_shapes, # prefix pytree of strings
get_shape_and_dtype=shape_and_dtype_jax_array,
):
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
Args:
args: a pytree of arguments
polymorphic_shapes: should be `None` (all arguments are monomorphic),
a single string (applies to all arguments), or a pytree matching a prefix
of the `args`.
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
Note that this function does not ensure that the provided `args` shapes
are compatible with `polymorphic_shapes`. The `.shape` of the `args` are
used only to fill-in placeholders from `polymorphic_shapes`.
See docstring of `symbolic_shape` and
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
get_shape_and_dtype: a function that given an argument extracts a tuple
of a shape and a dtype.
Returns: a pytree of jax.ShapeDTypeStruct matching `args`.
"""
args_flat, args_tree = tree_util.tree_flatten(args)
shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat))
shapes, dtypes = util.unzip2(shapes_and_dtypes)
if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
# TODO: Remove backward-compatibility workaround
polymorphic_shapes_ = tuple(polymorphic_shapes)
else:
polymorphic_shapes_ = polymorphic_shapes
try:
polymorphic_shapes_flat = tree_util.broadcast_prefix(
polymorphic_shapes_, args,
is_leaf=lambda x: x is None)
except ValueError:
e, *_ = tree_util.prefix_errors(
polymorphic_shapes_, args,
is_leaf=lambda x: x is None)
raise e("jax_export polymorphic_shapes") from None
# Now add in the polymorphic shapes
args_specs_flat = (
jax.ShapeDtypeStruct(symbolic_shape(spec, like=s), t)
for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))
return args_tree.unflatten(args_specs_flat)
# TODO: deprecated in January 2024, to be removed.
warnings.warn(
"export.args_specs is deprecated in favor of export.symbolic_args_specs",
DeprecationWarning, stacklevel=2)
if get_shape_and_dtype is not shape_and_dtype_jax_array:
# This was needed in some older jax2tf implementations
args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)),
args)
return shape_poly.symbolic_args_specs(args, polymorphic_shapes)
def _keep_main_tokens(serialization_version: int) -> bool:

@ -1199,6 +1199,72 @@ def symbolic_shape(shape_spec: str | None,
f"Found {shape_spec_repr}.")
return _Parser(shape_spec, like, shape_spec_repr).parse()
def symbolic_args_specs(
args, # pytree of arguments
polymorphic_shapes, # prefix pytree of strings
):
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
Note that this function does not ensure that the provided `args` shapes
are compatible with `polymorphic_shapes`. The `.shape` of the `args` are
used only to fill-in placeholders from `polymorphic_shapes`.
See docstring of `symbolic_shape` and
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
Args:
args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec.
This is used to learn the pytree structure of the arguments, their dtypes,
and to fill-in the actual shapes where the `polymorphic_shapes` contains
placeholders. Note that only the shape dimensions for which
`polymorphic_shapes` is a placeholder are used from `args`.
The unused dimensions can be `None`, which jax2tf uses when the TF
shapes are not known.
polymorphic_shapes: should be `None` (all arguments have static shapes),
a single string (applies to all arguments), or a pytree matching a prefix
of the `args`.
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes
replaced with symbolic dimensions as specified by `polymorphic_shapes`.
"""
args_flat, args_tree = tree_util.tree_flatten(args)
shapes_and_dtypes = tuple(map(shape_and_dtype_jax_array, args_flat))
shapes, dtypes = util.unzip2(shapes_and_dtypes)
if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
# TODO: Remove backward-compatibility workaround
polymorphic_shapes_ = tuple(polymorphic_shapes)
else:
polymorphic_shapes_ = polymorphic_shapes
try:
polymorphic_shapes_flat = tree_util.broadcast_prefix(
polymorphic_shapes_, args,
is_leaf=lambda x: x is None)
except ValueError:
e, *_ = tree_util.prefix_errors(
polymorphic_shapes_, args,
is_leaf=lambda x: x is None)
raise e("jax_export polymorphic_shapes") from None
# Now add in the polymorphic shapes
args_specs_flat = (
jax.ShapeDtypeStruct(symbolic_shape(spec, like=s), t)
for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))
return args_tree.unflatten(args_specs_flat)
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
"""Returns the shape and dtype of a jax.Array or a j"""
if isinstance(a, jax.ShapeDtypeStruct):
return a.shape, a.dtype
aval = core.raise_to_shaped(core.get_aval(a))
return aval.shape, aval.dtype
class _Parser:
def __init__(self,
shape_spec: str,

@ -341,7 +341,8 @@ def convert(fun_jax: Callable,
"for all platforms without native_serialization.")
if (not isinstance(native_serialization_platforms, (list, tuple)) or
not all(p in ["cpu", "cuda", "rocm", "tpu"] for p in native_serialization_platforms)):
not all(p in ["cpu", "cuda", "rocm", "tpu"]
for p in native_serialization_platforms)):
raise ValueError(
"native_serialization_platforms must be a sequence "
"containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
@ -364,22 +365,26 @@ def convert(fun_jax: Callable,
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
_has_registered_tf_source_path = True
def shape_and_dtype_tf(a: TfVal) -> tuple[Sequence[int | None], DType]:
def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct:
# The shape and JAX dtype for a TF argument
tf_arg_shape = np.shape(a)
# Fix the shape for TF1
tf_arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
tf_arg_shape = tuple(d.value
if isinstance(d, tf.compat.v1.Dimension) else d
for d in tf_arg_shape)
_, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
return tf_arg_shape, a_jax_dtype
# We count on the fact that jax.ShapeDtypeStruct allows shapes that
# contain None.
return jax.ShapeDtypeStruct(tf_arg_shape, a_jax_dtype)
args_specs = export.args_specs(args_tf,
polymorphic_shapes=polymorphic_shapes,
get_shape_and_dtype=shape_and_dtype_tf)
args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf)
args_specs = export.symbolic_args_specs(
args_jax_specs, polymorphic_shapes=polymorphic_shapes)
# The polymorphic_shapes argument refers to positional arguments only.
# We assume None for the kwargs.
kwargs_specs = export.args_specs(kwargs_tf,
polymorphic_shapes=None,
get_shape_and_dtype=shape_and_dtype_tf)
kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf)
kwargs_specs = export.symbolic_args_specs(
kwargs_jax_specs, polymorphic_shapes=None)
combined_args_tf = (args_tf, kwargs_tf)
args_flat_tf: Sequence[TfVal]
args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
@ -657,7 +662,7 @@ def eval_polymorphic_shape(fun_jax: Callable,
(c, a)
"""
def do_eval_polymorphic_shape(*args_specs) -> Any:
args_poly_specs = export.args_specs(
args_poly_specs = export.symbolic_args_specs(
args_specs, polymorphic_shapes=polymorphic_shapes)
res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
# TODO(necula): For now we export the polymorphic shapes using `str`.

@ -166,7 +166,7 @@ class PolyHarness(Harness):
len(polymorphic_shapes), len(args),
f"polymorphic_shapes {polymorphic_shapes} of length "
f"{len(polymorphic_shapes)} must match number of arguments {len(args)}")
args_specs = export.args_specs(args, polymorphic_shapes)
args_specs = export.symbolic_args_specs(args, polymorphic_shapes)
input_signature = [
tf.TensorSpec(
[d if isinstance(d, int) else None for d in a.shape],

@ -868,7 +868,7 @@ class JaxExportTest(jtu.JaxTestCase):
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
return lax.ppermute(b, "x", perm=perm)
args_specs = export.args_specs((a,), polymorphic_shapes=poly)
args_specs = export.symbolic_args_specs((a,), polymorphic_shapes=poly)
exp = get_exported(f_jax)(*args_specs)
# Test JAX native execution

@ -696,7 +696,7 @@ class PolyHarness(Harness):
f_jax = self.dyn_fun
args = self.dyn_args_maker(tst.rng())
args = tree_util.tree_map(jnp.array, args)
args_specs = export.args_specs(args, self.polymorphic_shapes)
args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes)
if self.expect_error is not None:
with tst.assertRaisesRegex(self.expect_error[0], self.expect_error[1]):