mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
[shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other public APIs related to shape polymorphism. The function used to be in _export.py for historical reasons, we now move it to shape_poly.py but we export the `symbolci_args_specs` from the public `jax.experimental.export`. The improvement is that for the case when the `args` passed in are TF arrays, we move the logic to extract the shapes and dtypes from this function to the callers. This achieves a better separation of the JAX and TF use cases.
This commit is contained in:
parent
909f5e8114
commit
3b7917a56e
@ -33,10 +33,13 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
was deprecated and `core.max_dim` and `core.min_dim` were introduced
|
||||
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
|
||||
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
|
||||
* the `shape_poly.is_poly_dim` is deprecated in favor if `export.is_symbolic_dim`
|
||||
* the `shape_poly.is_poly_dim` is deprecated in favor of `export.is_symbolic_dim`
|
||||
({jax-issue}`#19282`).
|
||||
* the `export.args_specs` is deprecated in favor of `export.symbolic_args_specs
|
||||
({jax-issue}`#19283`).
|
||||
* the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated, use
|
||||
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
|
||||
|
||||
* Refactored the API for `jax.experimental.export`. Instead of
|
||||
`from jax.experimental.export import export` you should use now
|
||||
`from jax.experimental import export`. The old way of importing will
|
||||
|
@ -289,7 +289,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
(d) the number of devices for which the module was serialized.
|
||||
"""
|
||||
# Use the native exporter, to make sure we get the proper serialization.
|
||||
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
|
||||
args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes)
|
||||
exported = export.export(
|
||||
jax.jit(func),
|
||||
lowering_platforms=(self.default_jax_backend(),),
|
||||
@ -306,7 +306,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
|
||||
def run_serialized(self, data: CompatTestData,
|
||||
polymorphic_shapes: Sequence[str] | None = None):
|
||||
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
|
||||
args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes)
|
||||
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
|
||||
return core.ShapedArray(a.shape, a.dtype)
|
||||
in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs)
|
||||
|
@ -23,11 +23,12 @@ from jax.experimental.export._export import (
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform,
|
||||
|
||||
args_specs, # TODO: move to shape_poly
|
||||
args_specs, # TODO: deprecate
|
||||
)
|
||||
from jax.experimental.export.shape_poly import (
|
||||
is_symbolic_dim,
|
||||
symbolic_shape,
|
||||
symbolic_args_specs,
|
||||
)
|
||||
from jax.experimental.export.serialization import (
|
||||
serialize,
|
||||
|
@ -345,7 +345,9 @@ def symbolic_shape(
|
||||
return shape_poly.symbolic_shape(shape_spec, like=like)
|
||||
|
||||
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
|
||||
"""Returns the shape and dtype of a jax.Array."""
|
||||
"""Returns the shape and dtype of a jax.Array or a j"""
|
||||
if isinstance(a, jax.ShapeDtypeStruct):
|
||||
return a.shape, a.dtype
|
||||
aval = core.raise_to_shaped(core.get_aval(a))
|
||||
return aval.shape, aval.dtype
|
||||
|
||||
@ -354,55 +356,15 @@ def args_specs(
|
||||
polymorphic_shapes, # prefix pytree of strings
|
||||
get_shape_and_dtype=shape_and_dtype_jax_array,
|
||||
):
|
||||
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
|
||||
|
||||
Args:
|
||||
args: a pytree of arguments
|
||||
polymorphic_shapes: should be `None` (all arguments are monomorphic),
|
||||
a single string (applies to all arguments), or a pytree matching a prefix
|
||||
of the `args`.
|
||||
See [how optional parameters are matched to
|
||||
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
||||
|
||||
Note that this function does not ensure that the provided `args` shapes
|
||||
are compatible with `polymorphic_shapes`. The `.shape` of the `args` are
|
||||
used only to fill-in placeholders from `polymorphic_shapes`.
|
||||
|
||||
See docstring of `symbolic_shape` and
|
||||
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||||
for more details.
|
||||
get_shape_and_dtype: a function that given an argument extracts a tuple
|
||||
of a shape and a dtype.
|
||||
|
||||
Returns: a pytree of jax.ShapeDTypeStruct matching `args`.
|
||||
"""
|
||||
args_flat, args_tree = tree_util.tree_flatten(args)
|
||||
|
||||
shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat))
|
||||
shapes, dtypes = util.unzip2(shapes_and_dtypes)
|
||||
|
||||
if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
|
||||
# TODO: Remove backward-compatibility workaround
|
||||
polymorphic_shapes_ = tuple(polymorphic_shapes)
|
||||
else:
|
||||
polymorphic_shapes_ = polymorphic_shapes
|
||||
|
||||
try:
|
||||
polymorphic_shapes_flat = tree_util.broadcast_prefix(
|
||||
polymorphic_shapes_, args,
|
||||
is_leaf=lambda x: x is None)
|
||||
except ValueError:
|
||||
e, *_ = tree_util.prefix_errors(
|
||||
polymorphic_shapes_, args,
|
||||
is_leaf=lambda x: x is None)
|
||||
raise e("jax_export polymorphic_shapes") from None
|
||||
|
||||
# Now add in the polymorphic shapes
|
||||
args_specs_flat = (
|
||||
jax.ShapeDtypeStruct(symbolic_shape(spec, like=s), t)
|
||||
for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))
|
||||
|
||||
return args_tree.unflatten(args_specs_flat)
|
||||
# TODO: deprecated in January 2024, to be removed.
|
||||
warnings.warn(
|
||||
"export.args_specs is deprecated in favor of export.symbolic_args_specs",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
if get_shape_and_dtype is not shape_and_dtype_jax_array:
|
||||
# This was needed in some older jax2tf implementations
|
||||
args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)),
|
||||
args)
|
||||
return shape_poly.symbolic_args_specs(args, polymorphic_shapes)
|
||||
|
||||
|
||||
def _keep_main_tokens(serialization_version: int) -> bool:
|
||||
|
@ -1199,6 +1199,72 @@ def symbolic_shape(shape_spec: str | None,
|
||||
f"Found {shape_spec_repr}.")
|
||||
return _Parser(shape_spec, like, shape_spec_repr).parse()
|
||||
|
||||
def symbolic_args_specs(
|
||||
args, # pytree of arguments
|
||||
polymorphic_shapes, # prefix pytree of strings
|
||||
):
|
||||
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
|
||||
|
||||
Note that this function does not ensure that the provided `args` shapes
|
||||
are compatible with `polymorphic_shapes`. The `.shape` of the `args` are
|
||||
used only to fill-in placeholders from `polymorphic_shapes`.
|
||||
|
||||
See docstring of `symbolic_shape` and
|
||||
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||||
for more details.
|
||||
|
||||
Args:
|
||||
args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec.
|
||||
This is used to learn the pytree structure of the arguments, their dtypes,
|
||||
and to fill-in the actual shapes where the `polymorphic_shapes` contains
|
||||
placeholders. Note that only the shape dimensions for which
|
||||
`polymorphic_shapes` is a placeholder are used from `args`.
|
||||
The unused dimensions can be `None`, which jax2tf uses when the TF
|
||||
shapes are not known.
|
||||
polymorphic_shapes: should be `None` (all arguments have static shapes),
|
||||
a single string (applies to all arguments), or a pytree matching a prefix
|
||||
of the `args`.
|
||||
See [how optional parameters are matched to
|
||||
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
||||
|
||||
Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes
|
||||
replaced with symbolic dimensions as specified by `polymorphic_shapes`.
|
||||
"""
|
||||
args_flat, args_tree = tree_util.tree_flatten(args)
|
||||
|
||||
shapes_and_dtypes = tuple(map(shape_and_dtype_jax_array, args_flat))
|
||||
shapes, dtypes = util.unzip2(shapes_and_dtypes)
|
||||
|
||||
if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
|
||||
# TODO: Remove backward-compatibility workaround
|
||||
polymorphic_shapes_ = tuple(polymorphic_shapes)
|
||||
else:
|
||||
polymorphic_shapes_ = polymorphic_shapes
|
||||
|
||||
try:
|
||||
polymorphic_shapes_flat = tree_util.broadcast_prefix(
|
||||
polymorphic_shapes_, args,
|
||||
is_leaf=lambda x: x is None)
|
||||
except ValueError:
|
||||
e, *_ = tree_util.prefix_errors(
|
||||
polymorphic_shapes_, args,
|
||||
is_leaf=lambda x: x is None)
|
||||
raise e("jax_export polymorphic_shapes") from None
|
||||
|
||||
# Now add in the polymorphic shapes
|
||||
args_specs_flat = (
|
||||
jax.ShapeDtypeStruct(symbolic_shape(spec, like=s), t)
|
||||
for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))
|
||||
|
||||
return args_tree.unflatten(args_specs_flat)
|
||||
|
||||
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
|
||||
"""Returns the shape and dtype of a jax.Array or a j"""
|
||||
if isinstance(a, jax.ShapeDtypeStruct):
|
||||
return a.shape, a.dtype
|
||||
aval = core.raise_to_shaped(core.get_aval(a))
|
||||
return aval.shape, aval.dtype
|
||||
|
||||
class _Parser:
|
||||
def __init__(self,
|
||||
shape_spec: str,
|
||||
|
@ -341,7 +341,8 @@ def convert(fun_jax: Callable,
|
||||
"for all platforms without native_serialization.")
|
||||
|
||||
if (not isinstance(native_serialization_platforms, (list, tuple)) or
|
||||
not all(p in ["cpu", "cuda", "rocm", "tpu"] for p in native_serialization_platforms)):
|
||||
not all(p in ["cpu", "cuda", "rocm", "tpu"]
|
||||
for p in native_serialization_platforms)):
|
||||
raise ValueError(
|
||||
"native_serialization_platforms must be a sequence "
|
||||
"containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
|
||||
@ -364,22 +365,26 @@ def convert(fun_jax: Callable,
|
||||
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
|
||||
_has_registered_tf_source_path = True
|
||||
|
||||
def shape_and_dtype_tf(a: TfVal) -> tuple[Sequence[int | None], DType]:
|
||||
def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct:
|
||||
# The shape and JAX dtype for a TF argument
|
||||
tf_arg_shape = np.shape(a)
|
||||
# Fix the shape for TF1
|
||||
tf_arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
|
||||
tf_arg_shape = tuple(d.value
|
||||
if isinstance(d, tf.compat.v1.Dimension) else d
|
||||
for d in tf_arg_shape)
|
||||
_, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
|
||||
return tf_arg_shape, a_jax_dtype
|
||||
# We count on the fact that jax.ShapeDtypeStruct allows shapes that
|
||||
# contain None.
|
||||
return jax.ShapeDtypeStruct(tf_arg_shape, a_jax_dtype)
|
||||
|
||||
args_specs = export.args_specs(args_tf,
|
||||
polymorphic_shapes=polymorphic_shapes,
|
||||
get_shape_and_dtype=shape_and_dtype_tf)
|
||||
args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf)
|
||||
args_specs = export.symbolic_args_specs(
|
||||
args_jax_specs, polymorphic_shapes=polymorphic_shapes)
|
||||
# The polymorphic_shapes argument refers to positional arguments only.
|
||||
# We assume None for the kwargs.
|
||||
kwargs_specs = export.args_specs(kwargs_tf,
|
||||
polymorphic_shapes=None,
|
||||
get_shape_and_dtype=shape_and_dtype_tf)
|
||||
kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf)
|
||||
kwargs_specs = export.symbolic_args_specs(
|
||||
kwargs_jax_specs, polymorphic_shapes=None)
|
||||
combined_args_tf = (args_tf, kwargs_tf)
|
||||
args_flat_tf: Sequence[TfVal]
|
||||
args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
|
||||
@ -657,7 +662,7 @@ def eval_polymorphic_shape(fun_jax: Callable,
|
||||
(c, a)
|
||||
"""
|
||||
def do_eval_polymorphic_shape(*args_specs) -> Any:
|
||||
args_poly_specs = export.args_specs(
|
||||
args_poly_specs = export.symbolic_args_specs(
|
||||
args_specs, polymorphic_shapes=polymorphic_shapes)
|
||||
res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
|
||||
# TODO(necula): For now we export the polymorphic shapes using `str`.
|
||||
|
@ -166,7 +166,7 @@ class PolyHarness(Harness):
|
||||
len(polymorphic_shapes), len(args),
|
||||
f"polymorphic_shapes {polymorphic_shapes} of length "
|
||||
f"{len(polymorphic_shapes)} must match number of arguments {len(args)}")
|
||||
args_specs = export.args_specs(args, polymorphic_shapes)
|
||||
args_specs = export.symbolic_args_specs(args, polymorphic_shapes)
|
||||
input_signature = [
|
||||
tf.TensorSpec(
|
||||
[d if isinstance(d, int) else None for d in a.shape],
|
||||
|
@ -868,7 +868,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
|
||||
return lax.ppermute(b, "x", perm=perm)
|
||||
|
||||
args_specs = export.args_specs((a,), polymorphic_shapes=poly)
|
||||
args_specs = export.symbolic_args_specs((a,), polymorphic_shapes=poly)
|
||||
exp = get_exported(f_jax)(*args_specs)
|
||||
|
||||
# Test JAX native execution
|
||||
|
@ -696,7 +696,7 @@ class PolyHarness(Harness):
|
||||
f_jax = self.dyn_fun
|
||||
args = self.dyn_args_maker(tst.rng())
|
||||
args = tree_util.tree_map(jnp.array, args)
|
||||
args_specs = export.args_specs(args, self.polymorphic_shapes)
|
||||
args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes)
|
||||
|
||||
if self.expect_error is not None:
|
||||
with tst.assertRaisesRegex(self.expect_error[0], self.expect_error[1]):
|
||||
|
Loading…
x
Reference in New Issue
Block a user