From 660a01565246844ddf35d44280587378786ac57d Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 5 Sep 2023 22:15:22 -0700 Subject: [PATCH] [export] Move jax_export and shape_poly out of jax2tf. Those modules have been developed initially for jax2tf but they do not depend on TF anymore. They are used for JAX native serialization. We move them under jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF. We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later. PiperOrigin-RevId: 562988740 --- jax/experimental/export/BUILD | 44 + jax/experimental/export/__init__.py | 14 + jax/experimental/export/export.py | 1046 +++++++++++ jax/experimental/export/shape_poly.py | 1593 +++++++++++++++++ jax/experimental/jax2tf/BUILD | 19 +- jax/experimental/jax2tf/__init__.py | 4 + jax/experimental/jax2tf/jax2tf.py | 24 +- jax/experimental/jax2tf/jax_export.py | 1037 +---------- jax/experimental/jax2tf/shape_poly.py | 1588 +--------------- .../jax2tf/tests/back_compat_test.py | 4 +- .../jax2tf/tests/back_compat_test_util.py | 18 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 4 +- .../jax2tf/tests/shape_poly_test.py | 7 +- jax/experimental/jax2tf/tests/tf_test_util.py | 14 +- tests/BUILD | 12 + .../export_test.py | 150 +- 16 files changed, 2843 insertions(+), 2735 deletions(-) create mode 100644 jax/experimental/export/BUILD create mode 100644 jax/experimental/export/__init__.py create mode 100644 jax/experimental/export/export.py create mode 100644 jax/experimental/export/shape_poly.py rename jax/experimental/jax2tf/tests/jax_export_test.py => tests/export_test.py (84%) diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD new file mode 100644 index 000000000..fcf9d20f0 --- /dev/null +++ b/jax/experimental/export/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX-export provides APIs for exporting StableHLO for serialization purposes. + +load( + "//jaxlib:jax.bzl", + "py_deps", +) +load("@rules_python//python:defs.bzl", "py_library") + +licenses(["notice"]) + +# Please add new users to :australis_users. +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +py_library( + name = "export", + srcs = [ + "export.py", + "shape_poly.py", + ], + srcs_version = "PY3", + # TODO: b/255503696: enable pytype + tags = ["pytype_unchecked_annotations"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + ] + py_deps("numpy"), +) diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py new file mode 100644 index 000000000..e9a5c3c22 --- /dev/null +++ b/jax/experimental/export/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py new file mode 100644 index 000000000..62d01d044 --- /dev/null +++ b/jax/experimental/export/export.py @@ -0,0 +1,1046 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX APIs for exporting JAX functions for interoperation. + +""" + +from collections.abc import Sequence +import copy +import dataclasses +import functools +import itertools +import re +from typing import Any, Callable, Optional, Union + +from absl import logging + +import numpy as np + +import jax +from jax import config +from jax import sharding + +from jax._src import core +from jax._src import dispatch +from jax._src import pjit +from jax._src import sharding_impls +from jax._src import source_info_util +from jax._src.interpreters import mlir +from jax._src.interpreters import pxla +from jax._src.lib import xla_client +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src import tree_util +from jax._src import util +from jax._src import xla_bridge as xb + +from jax.experimental.export import shape_poly + +map = util.safe_map +zip = util.safe_zip + +DType = Any + +class DisabledSafetyCheck: + """A safety check should be skipped on (de)serialization. + + Most of these checks are performed on serialization, but some are deferred to + deserialization. The list of disabled checks is attached to the serialization, + e.g., as a sequence of string attributes to `jax_export.Exported` or of + `tf.XlaCallModuleOp`. + + You can disable more deserialization safety checks by passing + `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. + """ + _impl: str + + @classmethod + def platform(cls) -> "DisabledSafetyCheck": + """Allows the execution platform to differ from the serialization platform. + + Has effect only on deserialization. + """ + return DisabledSafetyCheck("platform") + + @classmethod + def custom_call(cls, target_name: str) -> "DisabledSafetyCheck": + """Allows the serialization of a call target not known to be stable. + + Has effect only on serialization. + Args: + target_name: the name of the custom call target to allow. + """ + return DisabledSafetyCheck(f"custom_call:{target_name}") + + @classmethod + def shape_assertions(cls) -> "DisabledSafetyCheck": + """Allows invocations with shapes that do not meet the constraints. + + Has effect on serialization (to suppress the generation of the assertions) + and also on deserialization (to suppress the checking of the assertions). + """ + return DisabledSafetyCheck("shape_assertions") + + def is_custom_call(self) -> Optional[str]: + """Returns the custom call target allowed by this directive.""" + m = re.match(r'custom_call:(.+)$', self._impl) + return m.group(1) if m else None + + def __init__(self, _impl:str): + # Do not use directly, use builders `platform`, `custom_call`. + self._impl = _impl + + def __str__(self): + return self._impl + __repr__ = __str__ + + def __eq__(self, other) -> bool: + return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl + + def __hash__(self) -> int: + return hash(self._impl) + + +minimum_supported_serialization_version = 6 +maximum_supported_serialization_version = 8 + +@dataclasses.dataclass(frozen=True) +class Exported: + """A JAX function lowered to StableHLO. + + Attributes: + fun_name: the name of the exported function, for error messages. + in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX + function. The actual lowering does not depend on the `in_tree`, but this + can be used to invoke the exported function using the same argument + structure. + in_avals: the flat tuple of input abstract values. May contain dimension + expressions in the shapes. + out_tree: a PyTreeDef describing the result of the lowered JAX function. + out_avals: the flat tuple of output abstract values. May contain dimension + expressions in the shapes, with dimension variables among those in + `in_avals. + in_shardings: the flattened input shardings. Only for the inputs that are + specified in `module_kept_var_idx`. If `None` then it is equivalent + to unspecified shardings. + out_shardings: the flattened output shardings, as long as `in_avals`. + lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', + 'cuda', 'rocm'. See below for the calling convention for when + there are multiple lowering platforms. + mlir_module_serialized: the serialized lowered VHLO module. + serialization_version: a version number for the serialized module. + See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. + module_kept_var_idx: the sorted indices of the arguments among `in_avals` that + must be passed to the module. The other arguments have been dropped + because they are not used. Same length as `in_shardings`. + uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape + polymorphism. This may be because `in_avals` contains dimension + variables, but also from inner calls of shape-polymorphic + Exported modules. + disabled_checks: a list of descriptors of safety checks that have been + disabled at export time. See docstring for `DisabledSafetyCheck`. + _get_vjp: an optional function that takes the current exported function and + returns the exported VJP function. + The VJP function takes a flat list of arguments, + starting with the primal arguments and followed by a cotangent argument + for each primal output. It returns a tuple with the cotangents + corresponding to the flattened primal inputs. + + Calling convention for the exported module: + + The `mlir_module` has a `main` function that takes an optional first + platform index argument if the module supports multiple platforms + (`len(lowering_platforms) > 1`), followed by the kept array arguments + (corresponding to `module_kept_var_idx` and `in_avals`). + The platform index is a i32 scalar encoding the index of the current + compilation platform into the `lowering_platforms` sequence. + + Inner functions use a different calling convention: an optional + platform index argument, optional dimension variable arguments specified + using scalar tensors of type i32 or i64, + followed by optional token arguments (in presence of side effects), + followed by the regular array arguments. + The dimension arguments correspond to the dimension variables appearing in + the `args_avals`, in sorted order of their names. + + Consider the lowering of a function with one array argument of type "f32[w, + 2 * h]", where "w" and "h" are two dimension variables. + Assume that we use multi-platform lowering, and we have + ordered effects. The `main` function will be as follows: + + func public main(platform_index: i32, arg: f32[?, ?]) { + arg_w = hlo.get_dimension_size(arg, 0) + dim1 = hlo.get_dimension_size(arg, 1) + arg_h = hlo.floordiv(dim1, 2) + call _check_shape_assertions(arg) # See below + token = new_token() + token_out, res = call _wrapped_jax_export_main(platform_index, arg_h, arg_w, token_in, arg) + return res + } + + The actual computation is in `_wrapped_jax_export_main`, taking also + the values of `h` and `w` and the token. Proper exporting of + functions with side-effects and tokens is still work-in-progress. + + Note that `main` contains a call to `_check_shape_assertions. + JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` + have values >= 1. We must check these constraints when we invoke the + module. We use a special custom call `@shape_assertion` that takes + a boolean first operand, a string `error_message` attribute that may contain + format specifiers `{0}`, `{1}`, ..., and a variadic number of integer + scalar operands corresponding to the format specifiers. + + func private _check_shape_assertions(arg: f32[?, ?]) { + # Check that w is >= 1 + arg_w = hlo.get_dimension_size(arg, 0) + custom_call @shape_assertion(arg_w >= 1, arg_w, + error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") + # Check that dim1 is even + dim1 = hlo.get_dimension_size(arg, 1) + custom_call @shape_assertion(dim1 % 2 == 0, dim1, + error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}") + # Check that h >= 1 + arg_h = hlo.floordiv(dim1, 2) + custom_call @shape_assertion(arg_h >= 1, arg_h, + error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") + + If we `call_exported` with this module we perform these checks + statically (in `call_exported_abstract_eval`). + """ + fun_name: str + in_tree: tree_util.PyTreeDef + in_avals: tuple[core.AbstractValue, ...] + out_tree: tree_util.PyTreeDef + out_avals: tuple[core.AbstractValue, ...] + + in_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]] + out_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]] + lowering_platform: str # For backwards compatibility + lowering_platforms: tuple[str, ...] + disabled_checks: Sequence[DisabledSafetyCheck] + + mlir_module_serialized: bytes + serialization_version: int + module_kept_var_idx: tuple[int, ...] + uses_shape_polymorphism: bool + + _get_vjp: Optional[Callable[["Exported"], "Exported"]] + + def mlir_module(self) -> ir.Module: + return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) + + def __str__(self): + # This is called to make a MLIR source location when we call an Exported, and we + # do not want the entire serialized module to end up in locations. + return f"Exported(fun_name={self.fun_name}, ...)" + + def vjp(self) -> "Exported": + """Gets the exported VJP. + + Returns None if not available, which can happen if the Exported has been + loaded from an external format, without a VJP.""" + if self._get_vjp is None: + raise ValueError("No VJP is available") + return self._get_vjp(self) + + +def default_lowering_platform() -> str: + # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' + return xb.canonicalize_platform(jax.default_backend()) + +def poly_spec( + arg_shape: Sequence[Optional[int]], + arg_dtype: DType, + polymorphic_shape: Optional[str]) -> jax.ShapeDtypeStruct: + """Constructs a jax.ShapeDtypeStruct with polymorphic shapes. + + Args: + arg_shape: the shape, with possibly some unspecified dimensions. + arg_dtype: the jax dtype. + polymorphic_shape: a string specifying the polymorphic shape. + + .. warning:: The shape-polymorphic lowering is an experimental feature. + It is meant to be sound, but it is known to reject some JAX programs + that are shape polymorphic. The details of this feature can change. + + It should be either `None` (all dimensions are constant), or a string of + specification for one axis, and can be either a constant, `_` denoting + a constant dimension given by the `arg_shape`, or the name of a + dimension variable assumed to range over dimension greater than 0. For + convenience, zero or more trailing `_` can be abbreviated with `...`, and + the surrounding parentheses may be missing. + + Note that this function does not ensure that the provided `arg_shape` + is compatible with `polymorphic_shape`. The `arg_shape` is used only + to fill-in placeholders from `polymorphic_shape`. + + See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + for more details. + + Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic + expressions involving dimension variables. + """ + aval_shape = shape_poly._parse_spec(polymorphic_shape, arg_shape) + return jax.ShapeDtypeStruct(aval_shape, arg_dtype) + +def shape_and_dtype_jax_array(a) -> tuple[Sequence[Optional[int]], DType]: + """Returns the shape and dtype of a jax.Array.""" + aval = core.raise_to_shaped(core.get_aval(a)) + return aval.shape, aval.dtype + +def poly_specs( + args, # pytree of arguments + polymorphic_shapes, # prefix pytree of strings + get_shape_and_dtype=shape_and_dtype_jax_array, +): + """Constructs a pytree of jax.ShapeDtypeSpec. + + Args: + args: a pytree of arguments + polymorphic_shapes: should be `None` (all arguments are monomorphic), + a single string (applies to all arguments), or a pytree matching a prefix + of the `args`. + See [how optional parameters are matched to + arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + + Note that this function does not ensure that the provided `args` shapes + are compatible with `polymorphic_shapes`. The `args.shape` are used only + to fill-in placeholders from `polymorphic_shapes`. + + See docstring of `poly_spec` and + [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + for more details. + + Returns: a pytree of jax.ShapeDTypeStruct matching `args`. + """ + args_flat, args_tree = tree_util.tree_flatten(args) + + shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat)) + shapes, dtypes = util.unzip2(shapes_and_dtypes) + + if isinstance(args, tuple) and isinstance(polymorphic_shapes, list): + # TODO: Remove backward-compatibility workaround + polymorphic_shapes_ = tuple(polymorphic_shapes) + else: + polymorphic_shapes_ = polymorphic_shapes + + try: + polymorphic_shapes_flat = tree_util.broadcast_prefix( + polymorphic_shapes_, args, + is_leaf=lambda x: x is None) + except ValueError: + e, *_ = tree_util.prefix_errors( + polymorphic_shapes_, args, + is_leaf=lambda x: x is None) + raise e("jax_export polymorphic_shapes") from None + + # Now add in the polymorphic shapes + args_specs_flat = tuple( + map(poly_spec, shapes, dtypes, polymorphic_shapes_flat)) + + return args_tree.unflatten(args_specs_flat) + + +def export(fun_jax: Callable, + *, + lowering_platform: Optional[str] = None, + lowering_platforms: Optional[Sequence[str]] = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + ) -> Callable[..., Exported]: + """Exports native serialization for a JAX function. + + Args: + fun_jax: the function to lower and serialize. + lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use + the default JAX backend. + lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL). + Optional sequence containing a subset of 'tpu', 'cpu', + 'cuda', 'rocm'. If more than one platform is specified, then + the lowered code takes an argument specifying the platform. + If None, then use the default JAX backend. + The calling convention for multiple platforms is explained in the + `jax_export.Exported` docstring. + disabled_checks: the safety checks to disable. See docstring + of `DisabledSafetyCheck`. + + Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. + + Usage: + + def f_jax(*args, **kwargs): ... + exported = jax_export.export(f_jax)(*args, **kwargs) + """ + fun_name = getattr(fun_jax, "__name__", "unknown") + version = config.jax_serialization_version + if (version < minimum_supported_serialization_version or + version > maximum_supported_serialization_version): + raise ValueError( + f"The requested jax_serialization version {version} is outside the " + f"range of supported versions [{minimum_supported_serialization_version}" + f"..{maximum_supported_serialization_version}]") + + def do_export(*args_specs, **kwargs_specs) -> Exported: + if not hasattr(fun_jax, "lower"): + # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also + # convert(f_jax), in which case a "jit" is implied. In that case we raise + # an error if the lowered function contains non-replicated sharding annotations. + wrapped_fun_jax = jax.jit(fun_jax) + allow_non_replicated_sharding = False + else: + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + wrapped_fun_jax = fun_jax # type: ignore + allow_non_replicated_sharding = True + + nonlocal lowering_platforms + if lowering_platforms is not None: + lowering_platforms = tuple(lowering_platforms) + else: + lowering_platforms = (lowering_platform or default_lowering_platform(),) + + # Do not include shape assertions if the version is < 7. + enable_shape_assertions = ( + DisabledSafetyCheck.shape_assertions() not in disabled_checks and + version >= 7) # type: ignore + try: + prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions + shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions + lowered = wrapped_fun_jax.lower( + *args_specs, **kwargs_specs, + _experimental_lowering_platform=lowering_platforms) + + lowering = lowered._lowering # type: ignore + _check_lowering(lowering) + mlir_module = lowering.stablehlo() + + args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) + if "kept_var_idx" in lowering.compile_args: + module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals_flat))) + shape_poly_state = lowering.compile_args["shape_poly_state"] + if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) + or lowering.compile_args.get("ordered_effects", [])): + # All arguments are kept if we have dimension variables. + assert len(module_kept_var_idx) == len(args_avals_flat) + mlir_module = _wrap_main_func( + mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, + has_platform_index_argument=shape_poly_state.has_platform_index_argument + ) + finally: + shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions + + with mlir_module.context: + mlir_module_attrs = mlir_module.operation.attributes + mlir_module_attrs["jax.uses_shape_polymorphism"] = ( + mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) + + mlir_module_serialized = _serialize_module(mlir_module) + + # Figure out the result types and shapes + if "global_out_avals" in lowering.compile_args: + # This is currently the case for pjit + out_avals_flat = lowering.compile_args["global_out_avals"] + elif "shards" in lowering.compile_args: # for PmapComputation + out_avals_flat = lowering.compile_args["shards"].out_sharded_avals + else: + out_avals_flat = lowered.compile_args["out_avals"] + + # Log and then check the module. + if logging.vlog_is_on(3): + mlir_module_text = mlir.module_to_string(mlir_module) + logmsg = (f"version={version} " + f"lowering_platforms={lowering_platforms} " + f"disabled_checks={disabled_checks}") + logging.info("Lowered JAX module: %s\n", logmsg) + for l in mlir_module_text.splitlines(): + logging.info(l) + + _check_module(mlir_module, + allow_non_replicated_sharding=allow_non_replicated_sharding, + disabled_checks=disabled_checks) + + return Exported( + fun_name=fun_name, + in_tree=lowered.in_tree, + out_tree=lowered.out_tree, + in_avals=tuple(args_avals_flat), + out_avals=tuple(out_avals_flat), + in_shardings=lowering.compile_args["in_shardings"], + out_shardings=lowering.compile_args["out_shardings"], + lowering_platform=lowering_platforms[0], # TODO: remove + lowering_platforms=lowering_platforms, + disabled_checks=tuple(disabled_checks), + mlir_module_serialized=mlir_module_serialized, + module_kept_var_idx=module_kept_var_idx, + uses_shape_polymorphism=shape_poly_state.uses_dim_vars, + serialization_version=version, # type: ignore + _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported)) + + return do_export + + +def _serialize_module(module: ir.Module) -> bytes: + mlir_str = mlir.module_to_bytecode(module) + if hlo.get_api_version() < 4: + target_version = hlo.get_earliest_forward_compatible_version() + else: + # `target_version` is used to manage situations when a StableHLO producer + # (in this case, jax2tf) and a StableHLO consumer were built using + # different versions of StableHLO. + # + # Each StableHLO version `producer_version` has a compatibility window, + # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], + # where StableHLO portable artifacts serialized by `producer_version` + # can be deserialized by `consumer_version` within the window. + # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md + # for the exact extent of these compatibility guarantees. + # + # `hlo.get_minimum_version()` returns `consumer_version_min` + # for the current version of StableHLO. We are using it here to maximize + # forward compatibility, i.e. to maximize how far into the past we can go + # and still have the payloads produced by `serialize_portable_artifact` + # compatible with potential consumers from the past. + target_version = hlo.get_minimum_version() + module_serialized = xla_client._xla.mlir.serialize_portable_artifact( + mlir_str, target_version) + return module_serialized + + +def _wrap_main_func( + module: ir.Module, + args_avals_flat: Sequence[core.ShapedArray], + *, + args_kwargs_tree: tree_util.PyTreeDef, + has_platform_index_argument: bool, +) -> ir.Module: + """Wraps the lowered module with a new "main" handling dimension arguments. + + See calling convention documentation for `jax_export.Exported`. + + Args: + module: the HLO module as obtained from lowering. See the calling convention + for inner functions in `jax_export.Exported`. + args_avals_flat: the avals for all the arguments of the lowered function, + which correspond to the array arguments of the `module`. + args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error + messages. + + Returns the wrapped module, without dimension and token arguments. + """ + dim_vars = shape_poly.all_dim_vars(args_avals_flat) + context = mlir.make_ir_context() + with context, ir.Location.unknown(context): + # Make a copy, do not mutate because it may be cached + wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) + symbol_table = ir.SymbolTable(wrapped_module.operation) + orig_main = symbol_table["main"] + orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") + symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main") + orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value + + def is_token(attrs): + try: + return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value + except KeyError: + return False + + orig_input_types = orig_main.type.inputs + arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) + # The order of args: platform_index_arg, dim args, token args, array args. + nr_platform_index_args = 1 if has_platform_index_argument else 0 + nr_dim_args = len(dim_vars) + nr_token_args = sum(1 for attrs in arg_attrs if is_token(attrs)) + nr_array_args = len(orig_input_types) - nr_platform_index_args - nr_dim_args - nr_token_args + assert nr_array_args >= 0 + assert not any(is_token(attrs) for attrs in arg_attrs[-nr_array_args:]) + (platform_input_types, dim_var_input_types, + token_input_types, array_input_types) = util.split_list( + orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) + new_main_input_types = platform_input_types + array_input_types + orig_output_types = orig_main.type.results + result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) + nr_token_results = sum(1 for attrs in result_attrs if is_token(attrs)) + nr_array_results = len(orig_output_types) - nr_token_results + assert nr_array_results >= 0 + assert not any( + is_token(attrs) for attrs in result_attrs[-nr_array_results:]) + new_main_output_types = orig_output_types[-nr_array_results:] + new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types) + new_main_op = func_dialect.FuncOp( + "main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body)) + new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public") + try: + new_main_op.arg_attrs = ir.ArrayAttr.get(arg_attrs[0:nr_platform_index_args] + arg_attrs[-nr_array_args:]) + except KeyError: + pass # TODO: better detection if orig_main.arg_attrs does not exist + try: + new_main_op.result_attrs = ir.ArrayAttr.get( + result_attrs[-nr_array_results:]) + except KeyError: + pass + symbol_table.insert(new_main_op) + entry_block = new_main_op.add_entry_block() + with ir.InsertionPoint(entry_block): + module_context = mlir.ModuleContext( + "cpu", "cpu", sharding_impls.ShardingContext([]), + source_info_util.new_name_stack(), + [], itertools.count(1), [], module=wrapped_module, context=context) + ctx = mlir.LoweringRuleContext( + module_context=module_context, primitive=None, + avals_in=args_avals_flat, avals_out=None, + tokens_in=mlir.TokenSet(), tokens_out=None) + new_main_op_array_args = new_main_op.arguments[nr_platform_index_args:] + dim_values = mlir.lower_fun( + functools.partial(shape_poly.compute_dim_vars_from_arg_shapes, + args_avals_flat, args_kwargs_tree=args_kwargs_tree), + multiple_results=True)(ctx, *new_main_op_array_args) + # The arguments to pass to the call to orig_main + orig_main_args: list[ir.Value] = [] + # The platform index and the dimension variables + for arg, arg_type in zip( + list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values), + platform_input_types + dim_var_input_types): + if arg.type != arg_type: + orig_main_args.append(hlo.ConvertOp(arg_type, arg).result) + else: + orig_main_args.append(arg) + # Then the token arguments + orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args) + # Then the array arguments. We insert a ConvertOp as the only use of + # an input argument. This helps the downstream shape refinement because + # it will set the type of input arguments to static shapes, and this + # can invalidate the module if the argument is used as the result of a + # function, or if it appears as the input to a custom_call with + # output_operand_alias attribute. See b/287386268. + for a in new_main_op_array_args: + orig_main_args.append(hlo.ConvertOp(a.type, a).result) + call = func_dialect.CallOp(orig_output_types, + ir.FlatSymbolRefAttr.get(orig_main_name), + orig_main_args) + func_dialect.ReturnOp(call.results[-nr_array_results:]) + symbol_table.set_symbol_name(new_main_op, "main") + + return wrapped_module + +def _check_lowering(lowering) -> None: + if not isinstance(lowering, pxla.MeshComputation): + raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") + + if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: + raise NotImplementedError("serialization of host_callbacks is not yet implemented") + # Check that we do not see new compile_args. When we add a compile_args it is + # safe to add it to the allowed_compile_args if it does not change the semantics + # or the calling convention of the lowered module. + allowed_compile_args = [ + "backend", "mesh", "global_in_avals", + "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", + "spmd_lowering", "auto_spmd_lowering", + "tuple_args", "ordered_effects", "unordered_effects", + "keepalive", "host_callbacks", "pmap_nreps", "committed", + "device_assignment", "jaxpr_debug_info", "shape_poly_state"] + for compile_arg in lowering.compile_args.keys(): + if compile_arg not in allowed_compile_args: + raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") + + # We have not implemented support for some of the compile_args. Check here that + # the compile_args have the values that have been implemented. + not_implemented_msgs = [] + for compile_arg, check_value, err_msg in ( + ("spmd_lowering", lambda v: v, "True"), + ("auto_spmd_lowering", lambda v: not v, "False"), + # tuple_args is a compilation flag, does not affect lowering. + ("tuple_args", lambda v: True, "N/A"), + # unordered_effects do not change the calling convention. Those from + # jax.debug will also result in keepalive being non-empty and unsupported + # custom calls. The CallTfEffect is an exception, but we want to allow + # that one. + ("unordered_effects", lambda v: True, "N/A"), + # ordered_effects are allowed and we ensure that the calling convention is + # unmodified by passing dummy tokens in the main function wrapper. + ("ordered_effects", lambda v: True, "N/A"), + # used for TPU jax.debug, send/recv. Not supported yet. + ("host_callbacks", lambda v: not v, "empty"), + # used on all platforms for callbacks. Not supported yet. + ("keepalive", lambda v: not v, "empty"), + ("pmap_nreps", lambda v: v == 1, "1"), + ("shape_poly_state", lambda v: True, "N/A"), + ): + if compile_arg in lowering.compile_args: + if not check_value(lowering.compile_args[compile_arg]): + not_implemented_msgs.append( + f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}") + if not_implemented_msgs: + raise NotImplementedError( + "serialization error, unimplemented lowered.compile_args:\n" + + "\n".join(not_implemented_msgs)) + +# These are the JAX custom call target names that are guaranteed to be stable. +# Their backwards compatibility is tested by back_compat_test.py. +_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { + "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", + # cholesky on CPU + "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", + # eigh on CPU + "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", + # eigh on GPU + "cusolver_syevj", "cusolver_syevd", + # eigh on TPU + "Eigh", + # eig on CPU + "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", + # qr on CPU + "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", + # householder product on CPU + "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", + # svd on CPU + "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", + # qr on GPU + "cusolver_geqrf", "cublas_geqrf_batched", + "cusolver_geqrf", "cusolver_orgqr", + # qr and svd on TPU + "Qr", "ProductOfElementaryHouseholderReflectors", + # triangular_solve on CPU + "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", + # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU + # # lu on CPU + "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", + # schur on CPU + "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", + # # lu on GPU + # "cublas_getrf_batched", "cusolver_getrf", + # "hipblas_getrf_batched", "hipsolver_getrf", + # lu on TPU + "LuDecomposition", + # ApproxTopK on TPU + "ApproxTopK", + "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) + "tpu_custom_call", # Pallas/TPU kernels + # TODO(burmako): maintain backwards compatibility for these, until they + # are upstreamed to StableHLO. + # See https://github.com/openxla/stablehlo/issues/8. + "stablehlo.dynamic_reduce_window", + "stablehlo.dynamic_rng_bit_generator", + "stablehlo.dynamic_top_k", + "shape_assertion", # Used by shape_poly to evaluate assertions +} + + +def _check_module(mod: ir.Module, *, + allow_non_replicated_sharding: bool, + disabled_checks: Sequence[DisabledSafetyCheck]) -> None: + """Run a number of checks on the module. + + Args: + allow_non_replicated_sharding: whether the module is allowed to contain + non_replicated sharding annotations. + disabled_checks: the safety checks that are disabled. + """ + sharding_attr = ir.StringAttr.get("Sharding", mod.context) + shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context) + allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + for dc in disabled_checks: + target = dc.is_custom_call() + if target is not None: + allowed_custom_call_targets.add(target) + + allowed_custom_call_targets_attrs = { + ir.StringAttr.get(target, mod.context) + for target in allowed_custom_call_targets} + disallowed_custom_call_ops: list[str] = [] + def check_sharding(op: ir.Operation, loc: ir.Location): + if not allow_non_replicated_sharding: + try: + sharding = op.attributes["mhlo.sharding"] + except KeyError: + pass + else: + if ir.StringAttr(sharding).value not in ["{replicated}", ""]: + raise ValueError( + "Lowered function does not have a top-level pjit but it has" + f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee" + " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning" + " for a discussion." + ) + + def check_op(op: ir.Operation): + op_name = op.operation.name + if op_name == "func.func": + check_sharding(op.operation, op.location) + + elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call": + call_target_name_attr = op.operation.attributes["call_target_name"] + if (call_target_name_attr not in allowed_custom_call_targets_attrs): + disallowed_custom_call_ops.append(f"{op} at {op.location}") + if call_target_name_attr == sharding_attr: + check_sharding(op, op.location) + elif call_target_name_attr == shape_assertion_attr: + assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks) + + def walk_operations(op): + check_op(op) + for region in op.operation.regions: + for block in region: + for op in block: + walk_operations(op) + + walk_operations(mod) + if disallowed_custom_call_ops: + disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) + msg = ("Cannot serialize code with custom calls whose targets have no " + "compatibility guarantees. Examples are:\n" + f"{disallowed_custom_call_ops_str}.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") + raise ValueError(msg) + + +def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: + # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp + + # Since jax.vjp does not handle kwargs, it is easier to do all the work + # here with flattened functions. + def fun_vjp_jax(*args_and_out_cts_flat_jax): + # Takes a flat list of primals and output cotangents + def flattened_primal_fun_jax(*args_flat): + args, kwargs = primal.in_tree.unflatten(args_flat) + res = primal_fun_jax(*args, **kwargs) + res_flat, res_tree = tree_util.tree_flatten(res) + assert res_tree == primal.out_tree + return res_flat + + args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, + [len(primal.in_avals)]) + _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) + return pullback_jax(out_cts_flat_jax) + + vjp_in_avals = list( + itertools.chain(primal.in_avals, + map(lambda a: a.at_least_vspace(), primal.out_avals))) + + # Expand in_shardings to all in_avals even not kept ones. + all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals) + for idx, in_s in zip(sorted(primal.module_kept_var_idx), + primal.in_shardings): # type: ignore + all_in_shardings[idx] = in_s # type: ignore + all_shardings = all_in_shardings + list(primal.out_shardings) # type: ignore + # Cannot mix unspecified and specified shardings. Make the unspecified + # ones replicated. + specified_shardings = [ + s for s in all_shardings if not sharding_impls.is_unspecified(s)] + + vjp_in_shardings: Any # The primal inputs followed by output cotangents + vjp_out_shardings: Any # The primal output cotangents + if 0 == len(specified_shardings): + vjp_in_shardings = sharding_impls.UNSPECIFIED + vjp_out_shardings = sharding_impls.UNSPECIFIED + else: + if len(specified_shardings) < len(all_shardings): + # There are some specified, but not all; pjit front-end does not liwk + in_s = specified_shardings[0] # pjit will enforce that all have same devices + assert isinstance(in_s, sharding.XLACompatibleSharding) + replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) + all_shardings = [ + s if not sharding_impls.is_unspecified(s) else replicated_s + for s in all_shardings] + + vjp_in_shardings = tuple(all_shardings) + vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)]) + if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings): + vjp_out_shardings = sharding_impls.UNSPECIFIED + + fun_vjp_jax = pjit.pjit(fun_vjp_jax, + in_shardings=vjp_in_shardings, + out_shardings=vjp_out_shardings) + + return export(fun_vjp_jax, + lowering_platform=primal.lowering_platform, + disabled_checks=primal.disabled_checks)(*vjp_in_avals) + +### Importing + +def call_exported(exported: Exported) -> Callable[..., jax.Array]: + + @jax.custom_vjp + def f_flat(*args_flat): + return call_exported_p.bind(*args_flat, exported=exported) + + def f_flat_vjp_fwd(*args_flat): + # Return the primal arguments as the residual + # TODO: keep as residuals only the arguments that are needed + return f_flat(*args_flat), args_flat + + def f_flat_vjp_bwd(residual, ct_res_flat): + args_flat = residual # residual is the primal argument flat tuple + exp_vjp = exported.vjp() + in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_flat) + return in_ct_flat + + f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd) + + def f_imported(*args, **kwargs): + # since custom_vjp does not support kwargs, flatten the function first. + args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) + if in_tree != exported.in_tree: + # Give errors with the precise tree difference; use fake leaves so we can + # use tree_util.equality_errors. + in_args = in_tree.unflatten([0] * in_tree.num_leaves) + exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves) + + msg = ( + "The invocation args and kwargs must have the same pytree structure " + f"as when the function '{exported.fun_name}' was exported, but they " + "have the following structural differences:\n" + + ("\n".join( + f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " + f"{thing2} when exported, so {explanation}.\n" + for path, thing1, thing2, explanation + in tree_util.equality_errors(in_args, exp_in_args)))) + raise ValueError(msg) + + res_flat = f_flat(*args_flat) + return exported.out_tree.unflatten(res_flat) + return f_imported + + +# A JAX primitive for invoking a serialized JAX function. +call_exported_p = core.Primitive("call_exported") +call_exported_p.multiple_results = True + +@util.cache() +def _call_exported_abstract_eval(*in_avals: core.AbstractValue, + exported: Exported) -> tuple[core.AbstractValue, ...]: + exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals) + assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure + # Check that the expected shapes match the actual ones + for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): + def pp_arg_dim(dim_idx: Optional[int]) -> str: + return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, + arg_idx, dim_idx) + if len(exp_aval.shape) != len(actual_aval.shape): + raise ValueError( + f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} " + f"and called with {actual_aval.shape}") + if exp_aval.dtype != actual_aval.dtype: + raise ValueError( + f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} " + f"and called with {actual_aval.dtype}") + for dim_idx, aval_d in enumerate(exp_aval.shape): + # If the exp_aval has a constant dimension then the actual argument must have + # a matching constant dimension. + if core.is_constant_dim(aval_d): + if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or + aval_d != actual_aval.shape[dim_idx]): + raise ValueError( + f"Shape mismatch for {pp_arg_dim(dim_idx)} " + "(expected same constant): " + f"expected {exp_aval.shape} and called with {actual_aval.shape}") + + # Must express the exported_dim_vars in terms of the shapes in in_avals. + solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( + exported.in_avals, args_kwargs_tree=exported.in_tree) + synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] + for (vname, arg_idx, dim_idx) in synth_dim_vars} + synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) + # We discharge all the constraints statically. This results in much simpler + # composability (because we do not have to worry about the constraints of the + # Exported called recursively; we only need to worry about entry-point + # constraints). This also makes sense from a composibility point of view, + # because we get the same errors if we invoke the exported module, or if we + # trace the exported function. Consider for example, an exported module with + # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument + # `f32[c, d]` it is better to fail because `c == d` is inconclusive, than + # succeed and add a compile-time check that `c == d`. In the latter case, + # it would be ambiguous whether we should continue tracing with a result + # a type `f32[c]` or `f32[d]`. + shape_constraints.check_statically(synthetic_eval) + exported_dim_values = [synthetic_eval.evaluate(solution[var]) + for var in exported_dim_vars] + return tuple( + core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, + *exported_dim_values), + dtype=out_aval.dtype, weak_type=out_aval.weak_type, + named_shape=out_aval.named_shape) + for out_aval in exported.out_avals) + + +call_exported_p.def_abstract_eval(_call_exported_abstract_eval) + +def _call_exported_impl(*args, exported: Exported): + return dispatch.apply_primitive(call_exported_p, *args, exported=exported) + +call_exported_p.def_impl(_call_exported_impl) + +def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, + platform: str, + exported: Exported): + # TODO: implement true multi-platform lowering for call_exported + if (platform not in exported.lowering_platforms and + DisabledSafetyCheck.platform() not in exported.disabled_checks): + raise ValueError( + f"The exported function '{exported.fun_name}' was lowered for " + f"platforms '{exported.lowering_platforms}' but it is used " + f"on '{platform}'.") + + if exported.uses_shape_polymorphism: + ctx.module_context.shape_poly_state.uses_dim_vars = True + + submodule = ir.Module.parse(exported.mlir_module()) + symtab = ir.SymbolTable(submodule.operation) + # The called function may have been exported with polymorphic shapes and called + # now with more refined shapes. We insert hlo.ConvertOp to ensure the module + # is valid. + def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value: + new_ir_type = mlir.aval_to_ir_type(new_aval) + if x.type != new_ir_type: + return mlir.convert_hlo(ctx, x, x_aval, new_aval) + else: + return x + + callee_type = symtab["main"].type + # TODO: maybe cache multiple calls + fn = mlir.merge_mlir_modules(ctx.module_context.module, + f"call_exported_{exported.fun_name}", + submodule) + kept_args = [ + convert_shape(a, a_aval, exported_in_aval) + for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals)) + if i in exported.module_kept_var_idx] + if len(exported.lowering_platforms) > 1: + # The exported module takes a platform index argument + # TODO: implement proper handling of the platform_index when we are + # in a multi-platform lowering context. + platform_index = exported.lowering_platforms.index(platform) + arg_width = callee_type.inputs[0].element_type.width + assert arg_width in [32, 64] + platform_index = np.int32(platform_index) if arg_width == 32 else np.int64(platform_index) # type: ignore + kept_args = [mlir.ir_constant(platform_index)] + kept_args + call = func_dialect.CallOp(callee_type.results, + ir.FlatSymbolRefAttr.get(fn), + kept_args) + # The ctx.avals_out already contain the abstract values refined by + # _call_exported_abstract_eval. + return tuple( + convert_shape(out, out_aval, refined_out_aval) + for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out)) + + +for _p in ("cpu", "tpu", "cuda", "rocm"): + mlir.register_lowering(call_exported_p, + functools.partial(_call_exported_lowering, platform=_p), + platform=_p) diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py new file mode 100644 index 000000000..a66497b74 --- /dev/null +++ b/jax/experimental/export/shape_poly.py @@ -0,0 +1,1593 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shape polymorphism support. + +We introduce a set of dimension variables at the top-level of a `jit` function. +They are introduced implicitly by way of specifying for each dimension of each +argument a symbolic dimension expression in terms of some dimension variables. +All dimension variables are assumed to range over integers greater or equal to 1. + +Symbolic dimensions overload some integer operations, such as +add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been +touched up to be sensitive to handling shapes that contain symbolic dimensions. +This enables many JAX programs to be traced with symbolic dimensions +in some dimensions. A priority has been to enable the batch +dimension in neural network examples to be polymorphic. + +This was built initially for jax2tf, but it is now customizeable to be +independent of TF. The best documentation at the moment is in the +jax2tf.convert docstring, and the +[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +""" + +import collections +from collections.abc import Iterable, Sequence +import dataclasses +from enum import Enum +import functools +import itertools +import io +import math +import operator as op +import threading +import tokenize +from typing import Any, Optional, Union + +import numpy as np +import opt_einsum + +import jax +from jax import config +from jax.interpreters import xla + +from jax._src import core +from jax._src import dtypes +from jax._src import effects +from jax._src.lax import lax +from jax._src.lib import version as jaxlib_version +from jax._src.interpreters import mlir +from jax._src.numpy import lax_numpy +from jax._src import tree_util +from jax._src import util +from jax._src.typing import DimSize, Shape + + +TfVal = Any +DimVarEnv = dict[str, jax.Array] +DType = Any + +class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): + """Raised when we cannot conclusively compute with symbolic dimensions.""" + + _help_msg = """ +This error arises for comparison operations with shapes that +are non-constant, and the result of the operation cannot be represented as +a boolean value for all values of the symbolic dimensions involved. + +Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +for more details. +""" + + def __init__(self, message: str): + error_msg = f"{message}\n{InconclusiveDimensionOperation._help_msg}" + # https://github.com/python/mypy/issues/5887 + super().__init__(error_msg) # type: ignore + +class _ShapePolyThreadLocalState(threading.local): + + def __init__(self): + # TODO(necula): this does not play well with some lowering caches, because + # this state is not part of the cache key. + self.enable_shape_assertions = True + +thread_local_state = _ShapePolyThreadLocalState() + +class _DimAtom: + """Represents an atom in a symbolic dimension expression. + + Atoms are either variables, or expressions of the form floordiv(E1, E2) or + mod(E1, E2). Atoms are multiplied to form monomials (see _DimMon), and + monomials are added to form symbolic expressions (see _DimExpr). + + Args: + * var: if specified then the atom is a dimension variable. `operation` + must be `None`. + * operation: if specified then the atom is an operation applied to + `operands`. One of `FLOORDIR` or `MOD` or `NON_NEGATIVE`. `var` must be `None` + * operands: the operands to which the operation is applied. + """ + # The supported operations + FLOORDIV = "floordiv" + MOD = "mod" + NON_NEGATIVE = "non_negative" # The max of the operand and 0 + + def __init__(self, *operands: '_DimExpr', + var: Optional[str] = None, + operation: Optional[str] = None): + if var is not None: + assert operation is None + assert not operands + else: + assert operation is not None + self.var = var + self.operation = operation + self.operands = operands + + @classmethod + def from_var(cls, v: str) -> '_DimAtom': + return _DimAtom(var=v) + + def to_var(self) -> Optional[str]: + return self.var + + def get_vars(self) -> set[str]: + # All the vars that appear + if self.var is not None: + return {self.var} + else: + acc = set() + for opnd in self.operands: + acc.update(opnd.get_vars()) + return acc + + @classmethod + def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimAtom': + return _DimAtom(*operands, operation=operation) + + def __str__(self): + if self.var is not None: + return self.var + opnd_str = ", ".join([str(opnd) for opnd in self.operands]) + return f"{self.operation}({opnd_str})" + __repr__ = __str__ + + def __hash__(self): + return hash((self.var, self.operation, *self.operands)) + + def __eq__(self, other: Any): + # Used only for hashing + if not isinstance(other, _DimAtom): return False + if (self.var is None) != (other.var is None): return False + if self.var is not None: + return self.var == other.var + else: + def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool: + try: + return e1 == e2 + except InconclusiveDimensionOperation: + return False + return (self.operation == other.operation and + all(symbolic_equal(self_o, other_o) + for self_o, other_o in zip(self.operands, other.operands))) + + def __lt__(self, other: '_DimAtom'): + """ + Comparison to another atom in graded reverse lexicographic order. + Used only for determining a sorting order, does not relate to the + comparison of the values of the atom. + """ + if self.var is not None and other.var is not None: + return self.var < other.var + elif self.var is not None: + return True + elif other.var is not None: + return True + elif self.operation != other.operation: + return self.operation < other.operation # type: ignore + else: + return id(self) < id(other) + + def bounds(self) -> tuple[float, float]: + """Returns the lower and upper bounds, or -+ inf.""" + if self.var is not None: + return (1, np.inf) # variables are assumed to be >= 1 + opnd_bounds = [opnd.bounds() for opnd in self.operands] + if self.operation == _DimAtom.FLOORDIV: # a // b + (a_l, a_u), (b_l, b_u) = opnd_bounds + def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf + assert b != 0 + if not np.isinf(b): # divisor is finite + return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf + elif not np.isinf(a): # dividend is finite and divisor is infinite + return -1 if (a >= 0) != (b >= 0) else 0 + else: # both dividend and divisor are infinite + return -np.inf if (a >= 0) != (b >= 0) else np.inf + + # Same reasoning as for multiplication: the bounds are among the cross-product + # of the bounds. + bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u), + math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)] + return (min(*bound_candidates), max(*bound_candidates)) + + elif self.operation == _DimAtom.MOD: + _, (b_l, b_u) = opnd_bounds + if b_l > 0: # positive divisor + return (0, b_u - 1) + elif b_u < 0: # negative divisor + return (b_l + 1, 0) + else: + return (-np.inf, np.inf) + + elif self.operation == _DimAtom.NON_NEGATIVE: + (b_l, b_h), = opnd_bounds + return (max(0, b_l), max(0, b_h)) + + else: + assert False + + def evaluate(self, env: DimVarEnv): + if self.var is not None: + try: + return env[self.var] + except KeyError: + err_msg = ( + f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" + "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + raise KeyError(err_msg) + else: + operand_values = [opnd.evaluate(env) for opnd in self.operands] + if self.operation == _DimAtom.FLOORDIV: + return divmod(*operand_values)[0] # type: ignore + elif self.operation == _DimAtom.MOD: + return divmod(*operand_values)[1] # type: ignore + elif self.operation == _DimAtom.NON_NEGATIVE: + return lax.max(operand_values[0], 0) + else: + assert False, self.operation + +class _DimMon(dict): + """Represents a multiplication of atoms. + + The representation is a dictionary mapping _DimAtom to exponent. + The exponents are integers >= 1. + """ + def __hash__(self): + return hash(frozenset(self.items())) + + def __str__(self): + return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key) + for key, exponent in sorted(self.items())) + + @classmethod + def from_var(cls, v: str) -> '_DimMon': + return _DimMon({_DimAtom.from_var(v): 1}) + + @classmethod + def from_atom(clscls, a: _DimAtom, aexp: int): + return _DimMon({a: aexp}) + + def to_var(self) -> Optional[str]: + """Extract the variable name "x", from a monomial "x". + Return None, if the monomial is not a single variable.""" + items = self.items() + if len(items) != 1: + return None + (a, aexp), = items + if aexp != 1: + return None + return a.to_var() + + def get_vars(self) -> set[str]: + # All the vars that appear in the monomial + acc = set() + for a in self.keys(): + acc.update(a.get_vars()) + return acc + + @classmethod + def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimMon': + return _DimMon({_DimAtom.from_operation(operation, *operands): 1}) + + @property + def degree(self): + return sum(self.values()) + + def __lt__(self, other: '_DimMon'): + """ + Comparison to another monomial in graded reverse lexicographic order. + Used only for determining a sorting order, does not relate to the + comparison of the values of the monomial. + """ + self_key = -self.degree, tuple(sorted(self)) + other_key = -other.degree, tuple(sorted(other)) + return self_key > other_key + + def mul(self, other: '_DimMon') -> '_DimMon': + """ + Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m. + """ + return _DimMon(collections.Counter(self) + collections.Counter(other)) + + def divide(self, divisor: '_DimMon') -> '_DimMon': + """ + Divides by another monomial. Raises a InconclusiveDimensionOperation + if the result is not a monomial. + For example, (n^3 * m) // n == n^2*m, but n // m fails. + """ + d = collections.Counter(self) + for key, exponent in divisor.items(): + diff = self.get(key, 0) - exponent + if diff < 0: + raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") + elif diff == 0: del d[key] + elif diff > 0: d[key] = diff + return _DimMon(d) + + def bounds(self) -> tuple[float, float]: + """Returns the lower and upper bounds, or -+inf.""" + # The bounds of a product are among the product of bounds. + bounds = [] + for a, exp in self.items(): + a_l, a_u = a.bounds() + assert a_l <= a_u + bounds.append((a_l ** exp, a_u ** exp)) + + candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)] + return (min(*candidates), max(*candidates)) # type: ignore + + + def evaluate(self, env: DimVarEnv): + prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) + def pow_opt(v, p: int): + return v if p == 1 else prod([v] * p) + return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()]) + + +class _DimExpr(): + """Symbolic expression in terms of dimension variables. + + A dimension expression is an addition of products (_DimMon) + of atoms (_DimAtom). + + We overload integer operations, but we do that soundly, raising + :class:`InconclusiveDimensionOperation` when the result is not + representable as a _DimExpr. + + The representation of a _DimExpr is as a dictionary mapping _DimMon to + integer coefficients. The special monomial `_DimMon()` is mapped to the + free integer coefficient of the expression. + """ + + __array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray + def __init__(self, coeffs: dict[_DimMon, int]): + # Do not construct _DimExpr directly, unless you are sure that coeffs is + # normalized; Use _DimExpr.normalize. + # Takes ownership of coeffs + self._coeffs = coeffs or {_DimMon(): 0} + + def monomials(self) -> Iterable[tuple[_DimMon, int]]: + return self._coeffs.items() + + @classmethod + def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int): + """Do `coeffs[mon] += coeff` but remove 0 coefficients.""" + old_c = coeffs.get(mon) + if old_c is None: + if coeff != 0: coeffs[mon] = coeff + else: + new_c = old_c + coeff + if new_c == 0: + del coeffs[mon] + else: + coeffs[mon] = new_c + + @classmethod + def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize: + """The main constructor for _DimExpr. + + Ensures that the symbolic dimension is normalized, e.g., + it is represented as a Python int if it is known to be a constant. + """ + # TODO(necula): profile and optimize this + has_non_zero_degree = False + free_const = 0 + new_coeffs: dict[_DimMon, int] = {} + for mon, coeff in coeffs.items(): + if coeff == 0: continue + if mon.degree == 0: # A constant, there can be a single one + free_const = coeff + else: + has_non_zero_degree = True + + new_coeffs[mon] = new_coeffs.get(mon, 0) + coeff + + if has_non_zero_degree: + return _DimExpr(new_coeffs) + else: + return int(free_const) + + @classmethod + def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize: + # Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes + # up when handling strided convolution. + for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV): + # e = factor * floordiv(operands)^exp * rest_monomial + rest_expr + if dec.exp != 1: + continue + if dec.rest_monomial == 1 and dec.factor == 1: + continue + m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1]) + if m_remainder == 0: + return m_trimmed * (dec.operands[0] - _DimExpr.from_operation(_DimAtom.MOD, *dec.operands)) + dec.rest_expr + return _DimExpr.normalize(coeffs) + + @classmethod + def from_monomial(cls, mon: _DimMon, exp: int): + return _DimExpr.normalize({mon: exp}) + + @classmethod + def from_var(cls, v: str) -> '_DimExpr': + return _DimExpr({_DimMon.from_var(v): 1}) + + @classmethod + def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimExpr': + return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1) + + def to_var(self) -> Optional[str]: + """Extract the variable name "x", from a symbolic expression.""" + items = self.monomials() + if len(items) != 1: # type: ignore + return None + (mon, mon_count), = items + if mon_count != 1: + return None + return mon.to_var() + + def get_vars(self) -> set[str]: + """The variables that appear in a symbolic dimension.""" + acc = set() + for mon, _ in self.monomials(): + acc.update(mon.get_vars()) + return acc + + def eq(self, other: DimSize) -> bool: + lb, ub = _ensure_poly(self - other, "eq").bounds() + if lb == ub == 0: + return True + if lb > 0 or ub < 0: + return False + # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported + return False + + def inconclusive_comparison(self, operation: str, op: Any) -> Exception: + return InconclusiveDimensionOperation( + f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.") + + def ge(self, other: DimSize) -> bool: + lb, ub = _ensure_poly(self - other, "ge").bounds() + if lb >= 0: + return True + if ub < 0: + return False + raise self.inconclusive_comparison(">=", other) + + def __hash__(self): + return hash(tuple(sorted(self.monomials()))) + + def __str__(self): + def _one_monomial(mon, c): + if mon.degree == 0: + return str(c) + if c == 1: + return str(mon) + return f"{c}*{mon}" + return " + ".join(_one_monomial(mon, c) + for mon, c in sorted(self.monomials(), reverse=True)) + + def __repr__(self): + return str(self) + + # We overload +, -, *, because they are fully defined for _DimExpr. + def __add__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__add__(other) + + other = _ensure_poly(other, "add") + coeffs = self._coeffs.copy() + for mon, coeff in other.monomials(): + _DimExpr._add_coeffs(coeffs, mon, coeff) + return _DimExpr.normalize_floordiv_times_divisor(coeffs) + + def __radd__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__radd__(other) + return _ensure_poly(other, "add").__add__(self) + + def __sub__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__sub__(other) + return self + -_ensure_poly(other, "sub") + + def __rsub__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__rsub__(other) + return _ensure_poly(other, "sub").__sub__(self) + + def __neg__(self) -> '_DimExpr': + return _DimExpr({mon: -coeff for mon, coeff in self.monomials()}) + + def __mul__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__mul__(other) + other = _ensure_poly(other, "mul") + coeffs: dict[_DimMon, int] = {} + for mon1, coeff1 in self.monomials(): + for mon2, coeff2 in other.monomials(): + mon = mon1.mul(mon2) + _DimExpr._add_coeffs(coeffs, mon, coeff1 * coeff2) + return _DimExpr.normalize_floordiv_times_divisor(coeffs) + + def __rmul__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__rmul__(other) + return _ensure_poly(other, "mul").__mul__(self) + + def __pow__(self, power, modulo=None): + assert modulo is None + try: + power = int(power) + except: + raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") + return functools.reduce(op.mul, [self] * power) + + def __floordiv__(self, divisor): + if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): + return self.__jax_array__().__floordiv__(divisor) + return self.divmod(_ensure_poly(divisor, "floordiv"))[0] + + def __rfloordiv__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__rfloordiv__(other) + return _ensure_poly(other, "floordiv").__floordiv__(self) + + def __truediv__(self, divisor): + # Used for "/", which always returns a float + return self.__jax_array__().__truediv__(divisor) + + def __rtruediv__(self, dividend): + # Used for "/", when dividend is not a _DimExpr + return self.__jax_array__().__rtruediv__(dividend) + + def __mod__(self, divisor): + if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): + return self.__jax_array__().__mod__(divisor) + return self.divmod(_ensure_poly(divisor, "mod"))[1] + + def __rmod__(self, dividend): + if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): + return self.__jax_array__().__rmod__(dividend) + return _ensure_poly(dividend, "mod").__mod__(self) + + def __divmod__(self, divisor): + if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): + return self.__jax_array__().__divmod__(divisor) + return self.divmod(_ensure_poly(divisor, "divmod")) + + def __rdivmod__(self, dividend): + if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): + return self.__jax_array__().__rdivmod__(dividend) + return _ensure_poly(dividend, "divmod").__divmod__(self) + + def __int__(self): + if self.is_constant: + return op.index(next(iter(self._coeffs.values()))) + else: + raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant") + + # We must overload __eq__ and __ne__, or else we get unsound defaults. + __eq__ = eq + def __ne__(self, other: DimSize) -> bool: + return not self.eq(other) + + __ge__ = ge + + def __le__(self, other: DimSize): + try: + return _ensure_poly(other, "le").__ge__(self) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison("<=", other) from e + + def __gt__(self, other: DimSize): + try: + return not _ensure_poly(other, "gt").__ge__(self) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison(">", other) from e + + def __lt__(self, other: DimSize): + try: + return not self.__ge__(other) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison("<", other) from e + + def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]: + """ + Floor division with remainder (divmod) generalized to polynomials. + If the `divisor` is not a constant, the remainder must be 0. + If the `divisor` is a constant, the remainder may be non 0, for consistency + with integer divmod. + + :return: Quotient resulting from polynomial division and integer remainder. + """ + assert isinstance(divisor, _DimExpr) + try: + dmon, dcount = divisor.leading_term + dividend, quotient = self, 0 + # invariant: self = dividend + divisor * quotient + # quotient and dividend are changed in the loop; the leading term of + # dividend decreases at each iteration. + while is_poly_dim(dividend) and not dividend.is_constant: + mon, count = dividend.leading_term + try: + qmon = mon.divide(dmon) + except InconclusiveDimensionOperation: + raise InconclusiveDimensionOperation("") + qcount, rcount = divmod(count, dcount) + if rcount != 0: + raise InconclusiveDimensionOperation("") + + q = _DimExpr.from_monomial(qmon, qcount) + quotient += q + dividend -= q * divisor # type: ignore[assignment] + + dividend = int(dividend) # type: ignore[assignment] + if divisor.is_constant: + q, r = divmod(dividend, int(divisor)) # type: ignore + quotient += q + remainder = r + else: + if dividend != 0: + raise InconclusiveDimensionOperation("") + remainder = 0 + + if config.jax_enable_checks: + assert self == divisor * quotient + remainder + return quotient, remainder + except InconclusiveDimensionOperation: + return (_DimExpr.from_operation(_DimAtom.FLOORDIV, self, divisor), # type: ignore + _DimExpr.from_operation(_DimAtom.MOD, self, divisor)) + + def bounds(self) -> tuple[float, float]: + """Returns the lower and upper bounds, or -+inf.""" + lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient + for mon, coeff in self.monomials(): + if mon.degree == 0: continue # We already included the free coefficient + m_l, m_u = mon.bounds() + assert m_l <= m_u and coeff != 0 + item_l, item_u = coeff * m_l, coeff * m_u + lb = lb + min(item_l, item_u) # type: ignore + ub = ub + max(item_l, item_u) # type: ignore + + if lb != -np.inf or ub != np.inf: + return lb, ub + # Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0 + # TODO(necula): add more principled support for floordiv and mod + # For example, this will miss "1 + a - mod(b, a)" + for dec in _decompose_expr(self, _DimAtom.MOD): + # E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr + if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]: + try: + if dec.operands[1] <= 0: + continue + except InconclusiveDimensionOperation: + continue + if dec.factor > 0: + return (-np.inf, -1) + else: + return (1, np.inf) + + return lb, ub + + @property + def is_constant(self): + return len(self._coeffs) == 1 and next(iter(self._coeffs)).degree == 0 + + @property + def leading_term(self) -> tuple[_DimMon, int]: + """Returns the highest degree term that comes first lexicographically.""" + return max(self.monomials()) + + def evaluate(self, env: DimVarEnv): + # Evaluates as a value of dtype=core.dim_value_dtype() + terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) + for mon, coeff in self.monomials()] + return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] + + def non_negative(self) -> "_DimExpr": + return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self) + + @staticmethod + def get_aval(dim: "_DimExpr"): + return core.dim_value_aval() + + def dimension_as_value(self): + """Turns a dimension size into a Jax value that we can compute with.""" + return _dim_as_value(self) + + def __jax_array__(self): + # Used for implicit coercions of polynomials as JAX arrays + return _dim_as_value(self) + +@dataclasses.dataclass +class _Decomposition: + """Decomposition of an expression around an operation atom. + + E = factor * mod(*operands)^exp * rest_monomial + rest_expr + """ + factor: int + operands: Sequence[_DimExpr] + exp: int + rest_monomial: _DimExpr + rest_expr: _DimExpr + + +def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]: + for m, m_factor in e.monomials(): + atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation] + if atoms: + e_minus_m_coeffs = e._coeffs.copy() + del e_minus_m_coeffs[m] + for a, aexp in atoms: + yield _Decomposition( + factor=m_factor, + operands=a.operands, + exp=aexp, + rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}), + rest_expr=_DimExpr(e_minus_m_coeffs)) + +core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval +xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval +dtypes._weak_types.append(_DimExpr) + +def _convertible_to_int(p: DimSize) -> bool: + try: + op.index(p) + return True + except: + return False + +def _ensure_poly(p: DimSize, + operation_name: str) -> _DimExpr: + if isinstance(p, _DimExpr): return p + if _convertible_to_int(p): + return _DimExpr({_DimMon(): op.index(p)}) + raise TypeError(f"Symnbolic dimension {operation_name} not supported for {p}.") + +def _convertible_to_poly(p: DimSize) -> bool: + return isinstance(p, _DimExpr) or _convertible_to_int(p) + +def is_poly_dim(p: DimSize) -> bool: + return isinstance(p, _DimExpr) + +dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] + +def _einsum_contract_path(*operands, **kwargs): + """Like opt_einsum.contract_path, with support for DimExpr shapes. + + We use opt_einsum.contract_path to compute the schedule, using a fixed + constant for all dimension variables. This is safe because we throw an + error if there are more than 1 contractions. Essentially, we just use + opt_einsum.contract_path to parse the specification. + """ + + # Replace the polymorphic shapes with some concrete shapes for calling + # into opt_einsum.contract_path, because the latter wants to compute the + # sizes of operands and intermediate results. + fake_ops = [] + for operand in operands: + # We replace only array operands + if not hasattr(operand, "dtype"): + fake_ops.append(operand) + else: + shape = np.shape(operand) + def fake_dim(d): + if core.is_constant_dim(d): + return d + else: + if not isinstance(d, _DimExpr): + raise TypeError(f"Encountered unexpected shape dimension {d}") + # It is Ok to replace all polynomials with the same value. We may miss + # here some errors due to non-equal dimensions, but we catch them + # later. + return 8 + fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), + operand.dtype)) + + contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, + **kwargs) + contract_operands = [] + for operand in contract_fake_ops: + idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op) + assert len(idx) == 1 + contract_operands.append(operands[idx[0]]) + return contract_operands, contractions + +lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path + +# To implement shape-constraint checking we use a shape assertion primitive. +# shape_assertion_p.bind(assert_what: bool, *error_message_inputs, +# error_message="...{0}...{1}") +# where "{0}" refers to error_message_inputs[0], etc. +shape_assertion_p = core.Primitive("shape_assertion") +shape_assertion_p.multiple_results = True +shape_assertion_p.def_effectful_abstract_eval( + lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore + +def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext, + assert_what: mlir.ir.Value, + *error_message_inputs: mlir.ir.Value, + error_message: str): + op = mlir.custom_call( + "shape_assertion", + result_types=[], # No results + operands=[assert_what, *error_message_inputs], + has_side_effect=True, + extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message)) + ) + return op.results + +mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule) + +class ShapeAssertionEffect(effects.Effect): + __str__ = lambda _: "ShapeAssertionEffect" + +shape_assertion_effect = ShapeAssertionEffect() + +effects.lowerable_effects.add_type(ShapeAssertionEffect) +effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect) +effects.remat_allowed_effects.add_type(ShapeAssertionEffect) +effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) + +def shape_assertion(assert_what: jax.Array, + *error_message_inputs: jax.Array, + error_message: str) -> None: + """Adds a shape assertion in the code. + + Args: + assert_what: a boolean asserted to be true. Must be computed based only + on dimension expressions, so that it can be evaluated after shape + refinement. + error_message_inputs: integers expressions whose values can be referenced + in the `error_message`. Must be computed based only + on dimension expressions, so that they can be evaluated after shape + refinement. + error_message: an error message, possibly containing format specifiers + {0}, {1}, ..., referencing the values of the `error_message_inputs`. + The format specifiers are sometimes processed with Python's + `string::format` method, and sometimes with `llvm::formatv`. + """ + if thread_local_state.enable_shape_assertions: + shape_assertion_p.bind(assert_what, *error_message_inputs, + error_message=error_message) + +# A JAX primitive with no array arguments but with a dimension parameter +# that is a DimExpr. The value of the primitive is the value of the dimension, +# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype()) +dim_as_value_p = core.Primitive("dim_as_value") +dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval()) + +def dim_as_value_impl(dim: DimSize): + raise NotImplementedError( + "Evaluation rule for 'dim_as_value' is not implemented. " + "It seems that you are using shape polymorphism outside jax2tf.") + +dim_as_value_p.def_impl(dim_as_value_impl) +def _dim_as_value(dim: DimSize): + return dim_as_value_p.bind(dim=dim) + +def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, + dim): + res, = mlir.eval_dynamic_shape(ctx, (dim,)) + out_type = mlir.aval_to_ir_type(ctx.avals_out[0]) + if out_type != res.type: # type: ignore + return mlir.hlo.ConvertOp(out_type, res).results + else: + return [res] + +mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering) + + +class PolyShape(tuple): + """Tuple of polymorphic dimension specifications. + + See docstring of :func:`jax2tf.convert`. + """ + + def __init__(self, *dim_specs): + tuple.__init__(dim_specs) + + def __new__(cls, *dim_specs): + for ds in dim_specs: + if not isinstance(ds, (int, str)) and ds != ...: + msg = (f"Invalid polymorphic shape element: {repr(ds)}; must be a string " + "representing a dimension variable, or an integer, or ...") + raise ValueError(msg) + return tuple.__new__(PolyShape, dim_specs) + + def __str__(self): + return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" + + +def _parse_spec(shape_spec: Union[str, PolyShape, None], + arg_shape: Sequence[Optional[int]]) -> Sequence[DimSize]: + """Parses the shape polymorphic specification for one array argument. + + We have to be able to parse all strings produced by str(_DimExpr) because + sometimes the output polymorphic shapes of one function become the input + polymorphic shapes of another. + + Args: + shape_spec: a shape polymorphic specification. None stands for "...". + arg_shape: an actual shape, possibly containing unknown dimensions (None). + We use `arg_shape` to fill-in the placeholders `_` and `...` in + the `shape_spec`. The dimensions of `arg_shape` that are used for filling + must be known (not `None`). If a dimension in `arg_shape` is known and + the corresponding dimension in `shape_spec` is a constant then they + must be equal. + + See the README.md for usage. + """ + shape_spec_repr = repr(shape_spec) + if shape_spec is None: + shape_spec = "..." + elif isinstance(shape_spec, PolyShape): + shape_spec = str(shape_spec) + elif not isinstance(shape_spec, str): + raise ValueError("polymorphic shape spec should be None or a string. " + f"Found {shape_spec_repr}.") + return _Parser(shape_spec, arg_shape, shape_spec_repr).parse() + +class _Parser: + def __init__(self, + shape_spec: str, + arg_shape: Sequence[Optional[int]], + shape_spec_repr: str): + self.shape_spec = shape_spec + self.shape_spec_repr = shape_spec_repr # For error messages + self.arg_shape = arg_shape + self.dimensions: list[DimSize] = [] # dimensions we have parsed + + def parse(self) -> Sequence[DimSize]: + self.tokstream = tokenize.tokenize( + io.BytesIO(self.shape_spec.encode("utf-8")).readline) + tok = self.consume_token(self.next_tok(), tokenize.ENCODING) # Always 1st + sh, tok = self.shape(tok) + self.expect_token(tok, [tokenize.ENDMARKER]) + return sh + + def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo): + if expr is None: + raise self.parse_err(tok, + ("unexpected placeholder for unknown dimension " + f"for argument shape {self.arg_shape}")) + arg_shape_dim = self.arg_shape[len(self.dimensions)] + if core.is_constant_dim(expr) and arg_shape_dim is not None: + if expr != arg_shape_dim: + raise self.parse_err(tok, + (f"different size {expr} for known dimension " + f"for argument shape {self.arg_shape}")) + self.dimensions.append(expr) + + def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception: + msg = ( + f"syntax error in polymorphic shape {self.shape_spec_repr} " + f"in dimension {len(self.dimensions)}: {detail}. ") + if tok is not None: + msg += f"Parsed '{tok.line[:tok.start[1]]}', remaining '{tok.line[tok.start[1]:]}'." + return ValueError(msg) + + def next_tok(self) -> tokenize.TokenInfo: + while True: + try: + t = next(self.tokstream) + except StopIteration: + raise self.parse_err(None, "unexpected end of string") + if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]: + return t + + def expect_token(self, tok: tokenize.TokenInfo, expected: Sequence[int]) -> None: + if tok.exact_type not in expected: + msg = ("expecting one of {" + + ", ".join(tokenize.tok_name[t] for t in expected) + "} but found " + + tokenize.tok_name[tok.exact_type]) + raise self.parse_err(tok, msg) + + def consume_token(self, tok: tokenize.TokenInfo, expected: int) -> tokenize.TokenInfo: + self.expect_token(tok, [expected]) + return self.next_tok() + + def integer(self, tok: tokenize.TokenInfo) -> tuple[int, tokenize.TokenInfo]: + self.expect_token(tok, [tokenize.NUMBER]) + try: + val = int(tok.string) + except Exception: + raise self.parse_err(tok, f"expecting integer, found {tok.string}") + return val, self.next_tok() + + # What can follow a shape? + FOLLOW_SHAPE = [tokenize.ENDMARKER, tokenize.RPAR] + def shape(self, tok: tokenize.TokenInfo) -> tuple[Sequence[DimSize], tokenize.TokenInfo]: + # A comma-separated list of _DimExpr, or "_", possibly ended with ... + if tok.exact_type == tokenize.LPAR: + res, tok = self.shape(self.next_tok()) + tok = self.consume_token(tok, tokenize.RPAR) + return res, tok + + while True: + if tok.exact_type in self.FOLLOW_SHAPE: + break + if tok.exact_type == tokenize.ELLIPSIS: + to_add = self.arg_shape[len(self.dimensions):] + for ad in to_add: + self.add_dim(ad, tok) + tok = self.next_tok() + break + if len(self.dimensions) >= len(self.arg_shape): + raise self.parse_err(tok, + f"too many dimensions, arg_shape has {len(self.arg_shape)}") + if tok.exact_type == tokenize.NAME and tok.string == "_": + e = self.arg_shape[len(self.dimensions)] + tok = self.next_tok() + else: + e, tok = self.expr(tok) + self.add_dim(e, tok) + if tok.exact_type in self.FOLLOW_SHAPE: + break + tok = self.consume_token(tok, tokenize.COMMA) + + return tuple(self.dimensions), tok + + # What token can follow a _DimExpr + FOLLOW_EXPR = FOLLOW_SHAPE + [tokenize.COMMA] + + def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: + # A sum of monomials + next_m_negated = False + acc = 0 + while True: + m, tok = self.mon(tok) + acc = acc + (- m if next_m_negated else m) + if tok.exact_type in self.FOLLOW_EXPR: + return acc, tok + next_m_negated = (tok.exact_type == tokenize.MINUS) + self.expect_token(tok, [tokenize.PLUS, tokenize.MINUS]) + tok = self.next_tok() + + FOLLOW_MON = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS] + def mon(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: + # A monomial is product of atoms. Each atom may be raised to an integer power. + acc = 1 + while True: + a, tok = self.atom(tok) + if tok.exact_type == tokenize.CIRCUMFLEX: + tok = self.next_tok() + self.expect_token(tok, [tokenize.NUMBER]) + power, tok = self.integer(tok) + a = a ** power + + acc = acc * a + if tok.exact_type in self.FOLLOW_MON: + return acc, tok + tok = self.consume_token(tok, tokenize.STAR) + + def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: + if tok.exact_type == tokenize.NAME: + if tok.string == _DimAtom.MOD: + return self.binary_op(_DimAtom.MOD, self.next_tok()) + if tok.string == _DimAtom.FLOORDIV: + return self.binary_op(_DimAtom.FLOORDIV, self.next_tok()) + if tok.string == _DimAtom.NON_NEGATIVE: + return self.unary_op(_DimAtom.NON_NEGATIVE, self.next_tok()) + return _DimExpr.from_var(tok.string), self.next_tok() + number_sign = 1 + if tok.exact_type == tokenize.MINUS: # -k are negative constants + number_sign = -1 + tok = self.next_tok() + self.expect_token(tok, [tokenize.NUMBER]) + if tok.exact_type == tokenize.NUMBER: + v, tok = self.integer(tok) + return v * number_sign, tok + self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER]) + assert False + + def unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: + tok = self.consume_token(tok, tokenize.LPAR) + e1, tok = self.expr(tok) + tok = self.consume_token(tok, tokenize.RPAR) + return _DimExpr.from_operation(op, e1), tok # type: ignore + + def binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: + tok = self.consume_token(tok, tokenize.LPAR) + e1, tok = self.expr(tok) + tok = self.consume_token(tok, tokenize.COMMA) + e2, tok = self.expr(tok) + tok = self.consume_token(tok, tokenize.RPAR) + return _DimExpr.from_operation(op, e1, e2), tok # type: ignore + + +def _evaluate_add(v1, v2): + try: + if op.index(v1) == 0: + return v2 + except: + pass + try: + if op.index(v2) == 0: + return v1 + except: + pass + return v1 + v2 + +def _evaluate_multiply(v1, v2): + try: + if op.index(v1) == 1: + return v2 + except: + pass + try: + if op.index(v2) == 1: + return v1 + except: + pass + return v1 * v2 + +# dimension_size(operand, dimension=i) get the operand.shape[i] as a +# value of type shape_poly.dim_as_value_dtype(). +dimension_size_p = core.Primitive("dimension_size") +def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue: + return core.dim_value_aval() + +dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval) + +def _dimension_size_impl(arg, *, dimension): + return core.dim_constant(arg.shape[dimension]) +dimension_size_p.def_impl(_dimension_size_impl) + +def _dimension_size_lowering_rule(ctx, arg, *, dimension): + dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension) + dim_type = mlir.aval_to_ir_type(core.dim_value_aval()) + if dim_size.result.type != dim_type: + dim_size = mlir.hlo.ConvertOp(dim_type, dim_size) + return dim_size.results + +mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) + + +def arg_aval( + arg_shape: Sequence[Optional[int]], + arg_jax_dtype: DType, + polymorphic_shape: Optional[Union[str, PolyShape]]) -> core.ShapedArray: + """Computes abstract values. + + Args: + arg_shape: the shape for the argument, possibly having None dimensions. + arg_dtype: the inferred JAX dtype for the arg. + polymorphic_shape: the polymorphic specification for the argument. + Returns: the JAX abstract value for the argument. + """ + aval_shape = _parse_spec(polymorphic_shape, arg_shape) + return core.ShapedArray(aval_shape, arg_jax_dtype) + +def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]: + dim_vars: set[str] = set() + for a in args_avals: + for d in a.shape: + if is_poly_dim(d): + dim_vars = dim_vars.union(d.get_vars()) + return sorted(tuple(dim_vars)) + + +class CachingShapeEvaluator: + def __init__(self, **env): + self.env = env + + @functools.lru_cache(128) + def evaluate(self, e: DimSize): + if core.is_constant_dim(e): + res = op.index(e) + else: + res = e.evaluate(self.env) # type: ignore + return res + + +@dataclasses.dataclass(frozen=True) +class ShapeConstraint: + class Comparator(Enum): + EQ = 1 + GEQ = 2 + + comp: Comparator + left: DimSize + right: DimSize + # `error_message_pieces` is a list of strings and DimSize. The error message + # is formed by evaluating the DimSize and concatenating the sequence. + error_message_pieces: Sequence[Union[str, DimSize]] + + def check_statically(self, eval: CachingShapeEvaluator) -> None: + """Evaluates a constraint statically.""" + left, right = eval.evaluate(self.left), eval.evaluate(self.right) + try: + if self.comp == ShapeConstraint.Comparator.EQ: + ok = (left == right) + elif self.comp == ShapeConstraint.Comparator.GEQ: + ok = (left >= right) + else: + assert False # We are in a context where we know we can evaluate + # all symbolic expressions to constants. + except InconclusiveDimensionOperation as e: + raise self.make_error(eval) from e + if not ok: + raise self.make_error(eval) + + def compute(self, eval: CachingShapeEvaluator) -> Optional[jax.Array]: + """Computes if the constraint is satisfied. + + If the constraint can be resolved statically returns None + or raises ValueError otherwise. If the constraint cannot be + resolved statically, returns a value representing if the + constraint is satisfied. + """ + left, right = eval.evaluate(self.left), eval.evaluate(self.right) + # Try to evaluate the constraint statically. + if core.is_constant_shape((left, right)): + left_int, right_int = op.index(left), op.index(right) + if self.comp == ShapeConstraint.Comparator.EQ: + if not (left_int == right_int): + raise self.make_error(eval) + elif self.comp == ShapeConstraint.Comparator.GEQ: + if not (left_int >= right_int): + raise self.make_error(eval) + else: assert False + return None + + if self.comp == ShapeConstraint.Comparator.EQ: + is_ok = lax.eq(left, right) + elif self.comp == ShapeConstraint.Comparator.GEQ: + is_ok = lax.ge(left, right) + else: assert False + return is_ok + + def __str__(self): + return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}" + f" ({self.error_message_pieces})") + __repr__ = __str__ + + def error_message_and_inputs( + self, + eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]: + """Forms the error_message and error message_inputs. + See shape_assertion. + """ + # There is currenly a limitation in the shape assertion checker that + # it supports at most 32 error_message_inputs. We try to stay within the + # limit, reusing a format specifier if possible. + if jaxlib_version <= (0, 4, 14): + max_error_message_inputs = 4 + else: + max_error_message_inputs = 32 + format_specifiers: dict[DimSize, str] = {} + error_message_inputs: list[Any] = [] + error_message_strings: list[str] = [] + for e in self.error_message_pieces: + if isinstance(e, str): + error_message_strings.append(e) + continue + cached_spec = format_specifiers.get(e) + if cached_spec is not None: + error_message_strings.append(cached_spec) + continue + if len(error_message_inputs) >= max_error_message_inputs: + error_message_strings.append("N/A") + continue + spec = "{" + str(len(error_message_inputs)) + "}" + format_specifiers[e] = spec + error_message_strings.append(spec) + error_message_inputs.append(eval.evaluate(e)) + return ("".join(error_message_strings), + error_message_inputs) + + def make_error(self, eval: CachingShapeEvaluator) -> Exception: + error_message, error_message_inputs = self.error_message_and_inputs(eval) + return ValueError(error_message.format(*error_message_inputs)) + + +class ShapeConstraints: + def __init__(self): + self.constraints: list[ShapeConstraint] = [] + + def add_constraint(self, + comp: ShapeConstraint.Comparator, + left: DimSize, right: DimSize, + error_message_pieces: Sequence[Union[str, DimSize]]): + c = ShapeConstraint(comp, left, right, error_message_pieces) + self.constraints.append(c) + + def check_statically(self, eval: CachingShapeEvaluator) -> None: + """Evaluates all the constraints statically. + + If the static checking of any constraint fails, raises ValueError. + """ + for constraint in self.constraints: + constraint.check_statically(eval) + + def shape_assertions(self, eval: CachingShapeEvaluator) -> None: + """Computes the shape assertions for the set of constraints. + + See jax_export._wrap_main_func docstring. + """ + # We want to report the errors in the same order as `check_statically`. + # So, we process them in order, in case some fail statically, and we + # generate the shape assertions in the same order. + for constraint in self.constraints: + is_ok = constraint.compute(eval) + if is_ok is None: continue # Was resolved statically + error_message, error_message_inputs = constraint.error_message_and_inputs(eval) + shape_assertion( + is_ok, *error_message_inputs, + error_message=error_message) + +@dataclasses.dataclass +class _DimEquation: + # Encodes that `aval_dim_expr`, which is a symbolic expressions containing + # unknown dimension variables from the abstract values, is the specification + # for dimension named `dim_name` (e.g., "args[0].field.shape[2]"). + aval_dim_expr: _DimExpr + dim_name: str + + def __str__(self): + return f"Dimension size of {self.dim_name} with specification '{self.aval_dim_expr}'" + __repr__ = __str__ + + +def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str: + # String description of `args` or `kwargs`, assuming the path for a tree for + # the tuple `(args, kwargs)`. + if path[0] == tree_util.SequenceKey(0): + return f"args{tree_util.keystr(path[1:])}" + elif path[0] == tree_util.SequenceKey(1): + return f"kwargs{tree_util.keystr(path[1:])}" + else: + assert False + +@functools.lru_cache(128) +def _cached_pretty_print_dimension_descriptor( + args_kwargs_tree: tree_util.PyTreeDef, + flat_arg_idx: int) -> str: + args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path( + args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves)) + arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0]) + return arg_str + +def pretty_print_dimension_descriptor( + args_kwargs_tree: tree_util.PyTreeDef, + flat_arg_idx: int, dim_idx: Optional[int]) -> str: + arg_str = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx) + if dim_idx is not None: + arg_str += f".shape[{dim_idx}]" + return arg_str + +@util.cache() +def solve_dim_vars( + args_avals: Sequence[core.AbstractValue], + args_kwargs_tree: tree_util.PyTreeDef, + ) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]: + """Solves dimension variables in a called function's avals in terms of actual argument shapes. + + For example, given: + + args_avals = [ShapedArray((3, a, a + b), f32)] + + we introduce fresh "synthetic" dimension variables to represent the actual + dimension size of actual arguments for each non-constant dimension. + Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.: + + synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)] + + and then we express the solution for the unknown dimension variables {a, b} + as symbolic expressions in terms of the synthetic variables: + + dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1]) + + Not all equations are solvable. For now, we solve first the linear + uni-variate equations, then the solved variables are used to simplify the + remaining equations to linear uni-variate equations, and the process + continues until all dimension variables are solved. + + Args: + args_avals: the abstract values of the `args`, with shapes that may + include unknown dimension variables. + args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` + from which the flat sequence `args_avals` is extracted. Used for + describing args and kwargs in synthetic variable names and in + error messages. + + Returns: a 3-tuple with: (a) the solution for the unknown dimension variables + (b) a list of constraints that must be satisfied for the solution to be a + valid one, and (c) and the list of synthetic variables that may appear in + the solution and the constraints. + + Raises ValueError if it cannot solve some dimension variable. + """ + dim_equations: list[_DimEquation] = [] + synth_dimension_vars: list[tuple[str, int, int]] = [] + # tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b')) + polymorphic_shape_specs: list[tuple[str, str]] = [] + for arg_idx, aval in enumerate(args_avals): + if all(not is_poly_dim(d) for d in aval.shape): + continue + polymorphic_shape_specs.append( + (pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None), + str(aval.shape))) + for dim_idx, aval_d in enumerate(aval.shape): + if is_poly_dim(aval_d): + synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree, + arg_idx, dim_idx) + synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx)) + dim_equations.append( + _DimEquation(aval_dim_expr=_ensure_poly(aval_d, "solve_dim_vars"), + dim_name=synth_dim_var)) + + solution, shape_constraints = _solve_dim_equations(dim_equations, + polymorphic_shape_specs) + return solution, shape_constraints, synth_dimension_vars + + +def compute_dim_vars_from_arg_shapes( + args_avals: Sequence[core.AbstractValue], + *actual_args: jax.Array, + args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: + """Computes values of dimension variables to unify args_avals with actual arguments. + + Like `solve_dim_vars` except that here we express the solution as + JAX arrays that reference the `actual_args`. This function can be used to + generate the code for computing the dimension variables. It also generates + the shape assertions. + + Returns: the values of the dimension variables, in the order determined by + `all_dim_vars(args_avals)`. + """ + dim_vars = all_dim_vars(args_avals) + solution, shape_constraints, synth_dim_vars = solve_dim_vars( + tuple(args_avals), args_kwargs_tree=args_kwargs_tree) + + # Replace the synthetic vars with the dynamic shape of the actual arg + synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], + dimension=dim_idx) + for (vname, arg_idx, dim_idx) in synth_dim_vars} + synthetic_eval = CachingShapeEvaluator(**synthetic_env) + shape_constraints.shape_assertions(synthetic_eval) + dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] + return tuple(dim_values) + +def _solve_dim_equations( + eqns: list[_DimEquation], + polymorphic_shape_specs: Sequence[tuple[str, str]] +) -> tuple[DimVarEnv, ShapeConstraints]: + # Returns a shape environment and the shape constraints if it can solve all + # dimension variables. Raises an exception if it cannot. + shapeenv: DimVarEnv = {} + solution_error_message_pieces: list[Union[str, _DimExpr]] = [ + " Obtained dimension variables: " + ] # Error message describing the solution + # Prepare error message piece describing the polymorphic shape specs + poly_specs_err_msg = ( + " Using the following polymorphic shapes specifications: " + + ",".join(f"{arg_name}.shape = {arg_spec}" + for arg_name, arg_spec in polymorphic_shape_specs)) + "." + solution_err_msg_trailer_errors = ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + + shape_constraints = ShapeConstraints() # accumulate shape constraints + + def process_one_eqn(eqn: _DimEquation) -> bool: + # We start with a DimEquation of the form `dim_expr = dim_value` + # Try to rewrite the equation as `var * factor_var = dim_value_2` (a linear + # uni-variate equation). Returns `False` if this rewrite fails. + # Otherwise, compute the `var` value as `dim_value_2 // factor`, add it to + # `shapeenv` and return `True`. + # + # Invariant: + # var * factor_var + remaining_monomials_from_dim_expr = dim_value + var, factor_var = None, None + dim_value = _DimExpr.from_var(eqn.dim_name) + + for mon, factor in eqn.aval_dim_expr.monomials(): + # Perhaps we can already evaluate this monomial (all vars solved) + try: + mon_value = mon.evaluate(shapeenv) + except KeyError: + # `mon` still uses some variables not yet solved. We handle only the + # case when `mon` is a single variable. + v = mon.to_var() + if v is not None and var is None: + var, factor_var = v, factor + continue + else: + dim_value = dim_value + core.dim_constant(-1) * _evaluate_multiply(mon_value, core.dim_constant(factor)) + continue + return False # This equation cannot yet be used to solve a variable + + if var is not None: + if factor_var == 1: + var_value = dim_value + else: + var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore + shape_constraints.add_constraint( + ShapeConstraint.Comparator.EQ, var_remainder, 0, + error_message_pieces=([ + "Input shapes do not match the polymorphic shapes specification. " + "Division had remainder ", var_remainder, + f" when computing the value of '{var}'." + poly_specs_err_msg + ] + solution_error_message_pieces + [ + solution_err_msg_trailer_errors])) + + if not isinstance(var_value, _DimExpr): + assert var_value.dtype == core.dim_value_dtype() + shapeenv[var] = var_value # type: ignore + solution_error_message_pieces.extend([ + f"'{var}' = ", var_value, + f" from specification '{eqn.aval_dim_expr}' " + f"for dimension {eqn.dim_name} (= ", _DimExpr.from_var(eqn.dim_name), + "), "]) + + shape_constraints.add_constraint( + ShapeConstraint.Comparator.GEQ, var_value, 1, + error_message_pieces=[ + "Input shapes do not match the polymorphic shapes specification. " + f"Expected value >= 1 for dimension variable '{var}'." + + poly_specs_err_msg + ] + solution_error_message_pieces + [ + solution_err_msg_trailer_errors]) + + return True + else: + # All variables are resolved for this equation, we emit an assertion + shape_constraints.add_constraint( + ShapeConstraint.Comparator.EQ, + _DimExpr.from_var(eqn.dim_name), + eqn.aval_dim_expr.evaluate(shapeenv), + error_message_pieces=([ + "Input shapes do not match the polymorphic shapes specification. " + f"Found inconsistency between dimension size {eqn.dim_name} (= ", + _DimExpr.from_var(eqn.dim_name), + f") and the specification '{eqn.aval_dim_expr}' (= ", + eqn.aval_dim_expr.evaluate(shapeenv), + ")." + poly_specs_err_msg] + solution_error_message_pieces + + [solution_err_msg_trailer_errors]) + ) + return True + + while True: + nr_eqns = len(eqns) + eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)] + if not eqns: + return shapeenv, shape_constraints # SUCCESS + elif len(eqns) >= nr_eqns: + break + + # We have some equations that we cannot solve further + unsolved_vars: set[str] = set() + unsolved_polys: list[_DimExpr] = [] + for eqn in eqns: + unsolved_vars = unsolved_vars.union(eqn.aval_dim_expr.get_vars()) + unsolved_polys.append(eqn.aval_dim_expr) + unsolved_vars = unsolved_vars.difference(shapeenv.keys()) + err_msg = ( + f"Cannot solve for values of dimension variables {unsolved_vars}. " + "We can only solve linear uni-variate constraints." + poly_specs_err_msg + + " Unprocessed specifications: " + + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" + for eqn in eqns) + + ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ) + raise ValueError(err_msg) diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index 5120d6b46..6fb7b8eb3 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -35,34 +35,21 @@ py_library( deps = [":jax2tf_internal"], ) -py_library( - name = "jax_export", - srcs = [ - "jax_export.py", - "shape_poly.py", - ], - srcs_version = "PY3", - # TODO: b/255503696: enable pytype - tags = ["pytype_unchecked_annotations"], - visibility = ["//visibility:public"], - deps = [ - "//jax", - ] + py_deps("numpy"), -) - py_library( name = "jax2tf_internal", srcs = [ "call_tf.py", "impl_no_xla.py", "jax2tf.py", + "jax_export.py", # TODO(necula): remove stub + "shape_poly.py", # TODO(necula): remove stub ], srcs_version = "PY3", # TODO: b/255503696: enable pytype tags = ["pytype_unchecked_annotations"], visibility = jax_visibility("jax2tf_internal"), deps = [ - ":jax_export", "//jax", + "//jax/experimental/export", ] + py_deps("numpy") + py_deps("tensorflow_core") + jax2tf_deps, ) diff --git a/jax/experimental/jax2tf/__init__.py b/jax/experimental/jax2tf/__init__.py index c087bc4e5..c82715ac3 100644 --- a/jax/experimental/jax2tf/__init__.py +++ b/jax/experimental/jax2tf/__init__.py @@ -21,3 +21,7 @@ from jax.experimental.jax2tf.jax2tf import ( PolyShape as PolyShape ) from jax.experimental.jax2tf.call_tf import call_tf as call_tf +# TODO(necula): remove stub. Needed by SAX +from jax.experimental.jax2tf import jax_export +# Needed by maths.qec. +from jax.experimental.jax2tf import shape_poly diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 0b5427605..234b0b069 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -36,9 +36,9 @@ from jax import numpy as jnp from jax import tree_util from jax import sharding from jax.experimental import maps -from jax.experimental.jax2tf import shape_poly +from jax.experimental.export import shape_poly +from jax.experimental.export import export from jax.experimental.jax2tf import impl_no_xla -from jax.experimental.jax2tf import jax_export from jax.interpreters import xla from jax._src import ad_checkpoint @@ -86,7 +86,7 @@ NameStack = source_info_util.NameStack PolyShape = shape_poly.PolyShape DType = Any -DisabledSafetyCheck = jax_export.DisabledSafetyCheck +DisabledSafetyCheck = export.DisabledSafetyCheck # A temporary internal flag, to enable the wrapping of jax.jit functions # with tf.function(jit_compile=True). See #7389. This change has triggered a @@ -370,14 +370,14 @@ def convert(fun_jax: Callable, _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a) return tf_arg_shape, a_jax_dtype - args_specs = jax_export.poly_specs(args_tf, - polymorphic_shapes=polymorphic_shapes, - get_shape_and_dtype=shape_and_dtype_tf) + args_specs = export.poly_specs(args_tf, + polymorphic_shapes=polymorphic_shapes, + get_shape_and_dtype=shape_and_dtype_tf) # The polymorphic_shapes argument refers to positional arguments only. # We assume None for the kwargs. - kwargs_specs = jax_export.poly_specs(kwargs_tf, - polymorphic_shapes=None, - get_shape_and_dtype=shape_and_dtype_tf) + kwargs_specs = export.poly_specs(kwargs_tf, + polymorphic_shapes=None, + get_shape_and_dtype=shape_and_dtype_tf) combined_args_tf = (args_tf, kwargs_tf) args_flat_tf: Sequence[TfVal] args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf) @@ -503,7 +503,7 @@ class NativeSerializationImpl(SerializationImpl): _thread_local_state.call_tf_concrete_function_list = _prev_func_list self._restore_context = _restore_context - self.exported = jax_export.export( + self.exported = export.export( self.fun_jax, lowering_platform=self.lowering_platform, disabled_checks=self.native_serialization_disabled_checks @@ -669,7 +669,7 @@ def eval_polymorphic_shape(fun_jax: Callable, (c, a) """ def do_eval_polymorphic_shape(*args_specs) -> Any: - args_poly_specs = jax_export.poly_specs( + args_poly_specs = export.poly_specs( args_specs, polymorphic_shapes=polymorphic_shapes) res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs) # TODO(necula): For now we export the polymorphic shapes using `str`. @@ -803,7 +803,7 @@ def _interpret_fun_jax( def _run_exported_as_tf(args_flat_tf: Sequence[TfVal], - exported: jax_export.Exported, + exported: export.Exported, ) -> Sequence[TfVal]: """Runs the `exported` as an XlaCallModule TF op. diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 1755cf960..aaccf196d 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -11,1037 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""JAX APIs for exporting JAX functions for interoperation. +"""This is just a backwards-compatibility stub. -This module is used with jax2tf, but has no TensorFlow dependencies. +The main functionality has been moved to jax.experimental.export.export """ -from collections.abc import Sequence -import copy -import dataclasses -import functools -import itertools -import re -from typing import Any, Callable, Optional, Union - -from absl import logging - -import numpy as np - -import jax -from jax import config -from jax import sharding - -from jax._src import core -from jax._src import dispatch -from jax._src import pjit -from jax._src import sharding_impls -from jax._src import source_info_util -from jax._src.interpreters import mlir -from jax._src.interpreters import pxla -from jax._src.lib import xla_client -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo -from jax._src.lib.mlir.dialects import func as func_dialect -from jax._src import tree_util -from jax._src import util -from jax._src import xla_bridge as xb - -from jax.experimental.jax2tf import shape_poly - -map = util.safe_map -zip = util.safe_zip - -DType = Any - -class DisabledSafetyCheck: - """A safety check should be skipped on (de)serialization. - - Most of these checks are performed on serialization, but some are deferred to - deserialization. The list of disabled checks is attached to the serialization, - e.g., as a sequence of string attributes to `jax_export.Exported` or of - `tf.XlaCallModuleOp`. - - You can disable more deserialization safety checks by passing - `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. - """ - _impl: str - - @classmethod - def platform(cls) -> "DisabledSafetyCheck": - """Allows the execution platform to differ from the serialization platform. - - Has effect only on deserialization. - """ - return DisabledSafetyCheck("platform") - - @classmethod - def custom_call(cls, target_name: str) -> "DisabledSafetyCheck": - """Allows the serialization of a call target not known to be stable. - - Has effect only on serialization. - Args: - target_name: the name of the custom call target to allow. - """ - return DisabledSafetyCheck(f"custom_call:{target_name}") - - @classmethod - def shape_assertions(cls) -> "DisabledSafetyCheck": - """Allows invocations with shapes that do not meet the constraints. - - Has effect on serialization (to suppress the generation of the assertions) - and also on deserialization (to suppress the checking of the assertions). - """ - return DisabledSafetyCheck("shape_assertions") - - def is_custom_call(self) -> Optional[str]: - """Returns the custom call target allowed by this directive.""" - m = re.match(r'custom_call:(.+)$', self._impl) - return m.group(1) if m else None - - def __init__(self, _impl:str): - # Do not use directly, use builders `platform`, `custom_call`. - self._impl = _impl - - def __str__(self): - return self._impl - __repr__ = __str__ - - def __eq__(self, other) -> bool: - return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl - - def __hash__(self) -> int: - return hash(self._impl) - - -minimum_supported_serialization_version = 6 -maximum_supported_serialization_version = 8 - -@dataclasses.dataclass(frozen=True) -class Exported: - """A JAX function lowered to StableHLO. - - Attributes: - fun_name: the name of the exported function, for error messages. - in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX - function. The actual lowering does not depend on the `in_tree`, but this - can be used to invoke the exported function using the same argument - structure. - in_avals: the flat tuple of input abstract values. May contain dimension - expressions in the shapes. - out_tree: a PyTreeDef describing the result of the lowered JAX function. - out_avals: the flat tuple of output abstract values. May contain dimension - expressions in the shapes, with dimension variables among those in - `in_avals. - in_shardings: the flattened input shardings. Only for the inputs that are - specified in `module_kept_var_idx`. If `None` then it is equivalent - to unspecified shardings. - out_shardings: the flattened output shardings, as long as `in_avals`. - lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', - 'cuda', 'rocm'. See below for the calling convention for when - there are multiple lowering platforms. - mlir_module_serialized: the serialized lowered VHLO module. - serialization_version: a version number for the serialized module. - See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. - module_kept_var_idx: the sorted indices of the arguments among `in_avals` that - must be passed to the module. The other arguments have been dropped - because they are not used. Same length as `in_shardings`. - uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape - polymorphism. This may be because `in_avals` contains dimension - variables, but also from inner calls of shape-polymorphic - Exported modules. - disabled_checks: a list of descriptors of safety checks that have been - disabled at export time. See docstring for `DisabledSafetyCheck`. - _get_vjp: an optional function that takes the current exported function and - returns the exported VJP function. - The VJP function takes a flat list of arguments, - starting with the primal arguments and followed by a cotangent argument - for each primal output. It returns a tuple with the cotangents - corresponding to the flattened primal inputs. - - Calling convention for the exported module: - - The `mlir_module` has a `main` function that takes an optional first - platform index argument if the module supports multiple platforms - (`len(lowering_platforms) > 1`), followed by the kept array arguments - (corresponding to `module_kept_var_idx` and `in_avals`). - The platform index is a i32 scalar encoding the index of the current - compilation platform into the `lowering_platforms` sequence. - - Inner functions use a different calling convention: an optional - platform index argument, optional dimension variable arguments specified - using scalar tensors of type i32 or i64, - followed by optional token arguments (in presence of side effects), - followed by the regular array arguments. - The dimension arguments correspond to the dimension variables appearing in - the `args_avals`, in sorted order of their names. - - Consider the lowering of a function with one array argument of type "f32[w, - 2 * h]", where "w" and "h" are two dimension variables. - Assume that we use multi-platform lowering, and we have - ordered effects. The `main` function will be as follows: - - func public main(platform_index: i32, arg: f32[?, ?]) { - arg_w = hlo.get_dimension_size(arg, 0) - dim1 = hlo.get_dimension_size(arg, 1) - arg_h = hlo.floordiv(dim1, 2) - call _check_shape_assertions(arg) # See below - token = new_token() - token_out, res = call _wrapped_jax_export_main(platform_index, arg_h, arg_w, token_in, arg) - return res - } - - The actual computation is in `_wrapped_jax_export_main`, taking also - the values of `h` and `w` and the token. Proper exporting of - functions with side-effects and tokens is still work-in-progress. - - Note that `main` contains a call to `_check_shape_assertions. - JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` - have values >= 1. We must check these constraints when we invoke the - module. We use a special custom call `@shape_assertion` that takes - a boolean first operand, a string `error_message` attribute that may contain - format specifiers `{0}`, `{1}`, ..., and a variadic number of integer - scalar operands corresponding to the format specifiers. - - func private _check_shape_assertions(arg: f32[?, ?]) { - # Check that w is >= 1 - arg_w = hlo.get_dimension_size(arg, 0) - custom_call @shape_assertion(arg_w >= 1, arg_w, - error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") - # Check that dim1 is even - dim1 = hlo.get_dimension_size(arg, 1) - custom_call @shape_assertion(dim1 % 2 == 0, dim1, - error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}") - # Check that h >= 1 - arg_h = hlo.floordiv(dim1, 2) - custom_call @shape_assertion(arg_h >= 1, arg_h, - error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") - - If we `call_exported` with this module we perform these checks - statically (in `call_exported_abstract_eval`). - """ - fun_name: str - in_tree: tree_util.PyTreeDef - in_avals: tuple[core.AbstractValue, ...] - out_tree: tree_util.PyTreeDef - out_avals: tuple[core.AbstractValue, ...] - - in_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]] - out_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]] - lowering_platform: str # For backwards compatibility - lowering_platforms: tuple[str, ...] - disabled_checks: Sequence[DisabledSafetyCheck] - - mlir_module_serialized: bytes - serialization_version: int - module_kept_var_idx: tuple[int, ...] - uses_shape_polymorphism: bool - - _get_vjp: Optional[Callable[["Exported"], "Exported"]] - - def mlir_module(self) -> ir.Module: - return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) - - def __str__(self): - # This is called to make a MLIR source location when we call an Exported, and we - # do not want the entire serialized module to end up in locations. - return f"Exported(fun_name={self.fun_name}, ...)" - - def vjp(self) -> "Exported": - """Gets the exported VJP. - - Returns None if not available, which can happen if the Exported has been - loaded from an external format, without a VJP.""" - if self._get_vjp is None: - raise ValueError("No VJP is available") - return self._get_vjp(self) - - -def default_lowering_platform() -> str: - # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' - return xb.canonicalize_platform(jax.default_backend()) - -def poly_spec( - arg_shape: Sequence[Optional[int]], - arg_dtype: DType, - polymorphic_shape: Optional[str]) -> jax.ShapeDtypeStruct: - """Constructs a jax.ShapeDtypeStruct with polymorphic shapes. - - Args: - arg_shape: the shape, with possibly some unspecified dimensions. - arg_dtype: the jax dtype. - polymorphic_shape: a string specifying the polymorphic shape. - - .. warning:: The shape-polymorphic lowering is an experimental feature. - It is meant to be sound, but it is known to reject some JAX programs - that are shape polymorphic. The details of this feature can change. - - It should be either `None` (all dimensions are constant), or a string of - specification for one axis, and can be either a constant, `_` denoting - a constant dimension given by the `arg_shape`, or the name of a - dimension variable assumed to range over dimension greater than 0. For - convenience, zero or more trailing `_` can be abbreviated with `...`, and - the surrounding parentheses may be missing. - - Note that this function does not ensure that the provided `arg_shape` - is compatible with `polymorphic_shape`. The `arg_shape` is used only - to fill-in placeholders from `polymorphic_shape`. - - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) - for more details. - - Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic - expressions involving dimension variables. - """ - aval_shape = shape_poly._parse_spec(polymorphic_shape, arg_shape) - return jax.ShapeDtypeStruct(aval_shape, arg_dtype) - -def shape_and_dtype_jax_array(a) -> tuple[Sequence[Optional[int]], DType]: - """Returns the shape and dtype of a jax.Array.""" - aval = core.raise_to_shaped(core.get_aval(a)) - return aval.shape, aval.dtype - -def poly_specs( - args, # pytree of arguments - polymorphic_shapes, # prefix pytree of strings - get_shape_and_dtype=shape_and_dtype_jax_array, -): - """Constructs a pytree of jax.ShapeDtypeSpec. - - Args: - args: a pytree of arguments - polymorphic_shapes: should be `None` (all arguments are monomorphic), - a single string (applies to all arguments), or a pytree matching a prefix - of the `args`. - See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). - - Note that this function does not ensure that the provided `args` shapes - are compatible with `polymorphic_shapes`. The `args.shape` are used only - to fill-in placeholders from `polymorphic_shapes`. - - See docstring of `poly_spec` and - [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) - for more details. - - Returns: a pytree of jax.ShapeDTypeStruct matching `args`. - """ - args_flat, args_tree = tree_util.tree_flatten(args) - - shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat)) - shapes, dtypes = util.unzip2(shapes_and_dtypes) - - if isinstance(args, tuple) and isinstance(polymorphic_shapes, list): - # TODO: Remove backward-compatibility workaround - polymorphic_shapes_ = tuple(polymorphic_shapes) - else: - polymorphic_shapes_ = polymorphic_shapes - - try: - polymorphic_shapes_flat = tree_util.broadcast_prefix( - polymorphic_shapes_, args, - is_leaf=lambda x: x is None) - except ValueError: - e, *_ = tree_util.prefix_errors( - polymorphic_shapes_, args, - is_leaf=lambda x: x is None) - raise e("jax_export polymorphic_shapes") from None - - # Now add in the polymorphic shapes - args_specs_flat = tuple( - map(poly_spec, shapes, dtypes, polymorphic_shapes_flat)) - - return args_tree.unflatten(args_specs_flat) - - -def export(fun_jax: Callable, - *, - lowering_platform: Optional[str] = None, - lowering_platforms: Optional[Sequence[str]] = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - ) -> Callable[..., Exported]: - """Exports native serialization for a JAX function. - - Args: - fun_jax: the function to lower and serialize. - lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use - the default JAX backend. - lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL). - Optional sequence containing a subset of 'tpu', 'cpu', - 'cuda', 'rocm'. If more than one platform is specified, then - the lowered code takes an argument specifying the platform. - If None, then use the default JAX backend. - The calling convention for multiple platforms is explained in the - `jax_export.Exported` docstring. - disabled_checks: the safety checks to disable. See docstring - of `DisabledSafetyCheck`. - - Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. - - Usage: - - def f_jax(*args, **kwargs): ... - exported = jax_export.export(f_jax)(*args, **kwargs) - """ - fun_name = getattr(fun_jax, "__name__", "unknown") - version = config.jax_serialization_version - if (version < minimum_supported_serialization_version or - version > maximum_supported_serialization_version): - raise ValueError( - f"The requested jax_serialization version {version} is outside the " - f"range of supported versions [{minimum_supported_serialization_version}" - f"..{maximum_supported_serialization_version}]") - - def do_export(*args_specs, **kwargs_specs) -> Exported: - if not hasattr(fun_jax, "lower"): - # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also - # convert(f_jax), in which case a "jit" is implied. In that case we raise - # an error if the lowered function contains non-replicated sharding annotations. - wrapped_fun_jax = jax.jit(fun_jax) - allow_non_replicated_sharding = False - else: - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. - wrapped_fun_jax = fun_jax # type: ignore - allow_non_replicated_sharding = True - - nonlocal lowering_platforms - if lowering_platforms is not None: - lowering_platforms = tuple(lowering_platforms) - else: - lowering_platforms = (lowering_platform or default_lowering_platform(),) - - # Do not include shape assertions if the version is < 7. - enable_shape_assertions = ( - DisabledSafetyCheck.shape_assertions() not in disabled_checks and - version >= 7) # type: ignore - try: - prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions - shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions - lowered = wrapped_fun_jax.lower( - *args_specs, **kwargs_specs, - _experimental_lowering_platform=lowering_platforms) - - lowering = lowered._lowering # type: ignore - _check_lowering(lowering) - mlir_module = lowering.stablehlo() - - args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) - if "kept_var_idx" in lowering.compile_args: - module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals_flat))) - shape_poly_state = lowering.compile_args["shape_poly_state"] - if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) - or lowering.compile_args.get("ordered_effects", [])): - # All arguments are kept if we have dimension variables. - assert len(module_kept_var_idx) == len(args_avals_flat) - mlir_module = _wrap_main_func( - mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, - has_platform_index_argument=shape_poly_state.has_platform_index_argument - ) - finally: - shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions - - with mlir_module.context: - mlir_module_attrs = mlir_module.operation.attributes - mlir_module_attrs["jax.uses_shape_polymorphism"] = ( - mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) - - mlir_module_serialized = _serialize_module(mlir_module) - - # Figure out the result types and shapes - if "global_out_avals" in lowering.compile_args: - # This is currently the case for pjit - out_avals_flat = lowering.compile_args["global_out_avals"] - elif "shards" in lowering.compile_args: # for PmapComputation - out_avals_flat = lowering.compile_args["shards"].out_sharded_avals - else: - out_avals_flat = lowered.compile_args["out_avals"] - - # Log and then check the module. - if logging.vlog_is_on(3): - mlir_module_text = mlir.module_to_string(mlir_module) - logmsg = (f"version={version} " - f"lowering_platforms={lowering_platforms} " - f"disabled_checks={disabled_checks}") - logging.info("Lowered JAX module: %s\n", logmsg) - for l in mlir_module_text.splitlines(): - logging.info(l) - - _check_module(mlir_module, - allow_non_replicated_sharding=allow_non_replicated_sharding, - disabled_checks=disabled_checks) - - return Exported( - fun_name=fun_name, - in_tree=lowered.in_tree, - out_tree=lowered.out_tree, - in_avals=tuple(args_avals_flat), - out_avals=tuple(out_avals_flat), - in_shardings=lowering.compile_args["in_shardings"], - out_shardings=lowering.compile_args["out_shardings"], - lowering_platform=lowering_platforms[0], # TODO: remove - lowering_platforms=lowering_platforms, - disabled_checks=tuple(disabled_checks), - mlir_module_serialized=mlir_module_serialized, - module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=shape_poly_state.uses_dim_vars, - serialization_version=version, # type: ignore - _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported)) - - return do_export - - -def _serialize_module(module: ir.Module) -> bytes: - mlir_str = mlir.module_to_bytecode(module) - if hlo.get_api_version() < 4: - target_version = hlo.get_earliest_forward_compatible_version() - else: - # `target_version` is used to manage situations when a StableHLO producer - # (in this case, jax2tf) and a StableHLO consumer were built using - # different versions of StableHLO. - # - # Each StableHLO version `producer_version` has a compatibility window, - # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], - # where StableHLO portable artifacts serialized by `producer_version` - # can be deserialized by `consumer_version` within the window. - # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md - # for the exact extent of these compatibility guarantees. - # - # `hlo.get_minimum_version()` returns `consumer_version_min` - # for the current version of StableHLO. We are using it here to maximize - # forward compatibility, i.e. to maximize how far into the past we can go - # and still have the payloads produced by `serialize_portable_artifact` - # compatible with potential consumers from the past. - target_version = hlo.get_minimum_version() - module_serialized = xla_client._xla.mlir.serialize_portable_artifact( - mlir_str, target_version) - return module_serialized - - -def _wrap_main_func( - module: ir.Module, - args_avals_flat: Sequence[core.ShapedArray], - *, - args_kwargs_tree: tree_util.PyTreeDef, - has_platform_index_argument: bool, -) -> ir.Module: - """Wraps the lowered module with a new "main" handling dimension arguments. - - See calling convention documentation for `jax_export.Exported`. - - Args: - module: the HLO module as obtained from lowering. See the calling convention - for inner functions in `jax_export.Exported`. - args_avals_flat: the avals for all the arguments of the lowered function, - which correspond to the array arguments of the `module`. - args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error - messages. - - Returns the wrapped module, without dimension and token arguments. - """ - dim_vars = shape_poly.all_dim_vars(args_avals_flat) - context = mlir.make_ir_context() - with context, ir.Location.unknown(context): - # Make a copy, do not mutate because it may be cached - wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) - symbol_table = ir.SymbolTable(wrapped_module.operation) - orig_main = symbol_table["main"] - orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") - symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main") - orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value - - def is_token(attrs): - try: - return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value - except KeyError: - return False - - orig_input_types = orig_main.type.inputs - arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) - # The order of args: platform_index_arg, dim args, token args, array args. - nr_platform_index_args = 1 if has_platform_index_argument else 0 - nr_dim_args = len(dim_vars) - nr_token_args = sum(1 for attrs in arg_attrs if is_token(attrs)) - nr_array_args = len(orig_input_types) - nr_platform_index_args - nr_dim_args - nr_token_args - assert nr_array_args >= 0 - assert not any(is_token(attrs) for attrs in arg_attrs[-nr_array_args:]) - (platform_input_types, dim_var_input_types, - token_input_types, array_input_types) = util.split_list( - orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) - new_main_input_types = platform_input_types + array_input_types - orig_output_types = orig_main.type.results - result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) - nr_token_results = sum(1 for attrs in result_attrs if is_token(attrs)) - nr_array_results = len(orig_output_types) - nr_token_results - assert nr_array_results >= 0 - assert not any( - is_token(attrs) for attrs in result_attrs[-nr_array_results:]) - new_main_output_types = orig_output_types[-nr_array_results:] - new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types) - new_main_op = func_dialect.FuncOp( - "main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body)) - new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public") - try: - new_main_op.arg_attrs = ir.ArrayAttr.get(arg_attrs[0:nr_platform_index_args] + arg_attrs[-nr_array_args:]) - except KeyError: - pass # TODO: better detection if orig_main.arg_attrs does not exist - try: - new_main_op.result_attrs = ir.ArrayAttr.get( - result_attrs[-nr_array_results:]) - except KeyError: - pass - symbol_table.insert(new_main_op) - entry_block = new_main_op.add_entry_block() - with ir.InsertionPoint(entry_block): - module_context = mlir.ModuleContext( - "cpu", "cpu", sharding_impls.ShardingContext([]), - source_info_util.new_name_stack(), - [], itertools.count(1), [], module=wrapped_module, context=context) - ctx = mlir.LoweringRuleContext( - module_context=module_context, primitive=None, - avals_in=args_avals_flat, avals_out=None, - tokens_in=mlir.TokenSet(), tokens_out=None) - new_main_op_array_args = new_main_op.arguments[nr_platform_index_args:] - dim_values = mlir.lower_fun( - functools.partial(shape_poly.compute_dim_vars_from_arg_shapes, - args_avals_flat, args_kwargs_tree=args_kwargs_tree), - multiple_results=True)(ctx, *new_main_op_array_args) - # The arguments to pass to the call to orig_main - orig_main_args: list[ir.Value] = [] - # The platform index and the dimension variables - for arg, arg_type in zip( - list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values), - platform_input_types + dim_var_input_types): - if arg.type != arg_type: - orig_main_args.append(hlo.ConvertOp(arg_type, arg).result) - else: - orig_main_args.append(arg) - # Then the token arguments - orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args) - # Then the array arguments. We insert a ConvertOp as the only use of - # an input argument. This helps the downstream shape refinement because - # it will set the type of input arguments to static shapes, and this - # can invalidate the module if the argument is used as the result of a - # function, or if it appears as the input to a custom_call with - # output_operand_alias attribute. See b/287386268. - for a in new_main_op_array_args: - orig_main_args.append(hlo.ConvertOp(a.type, a).result) - call = func_dialect.CallOp(orig_output_types, - ir.FlatSymbolRefAttr.get(orig_main_name), - orig_main_args) - func_dialect.ReturnOp(call.results[-nr_array_results:]) - symbol_table.set_symbol_name(new_main_op, "main") - - return wrapped_module - -def _check_lowering(lowering) -> None: - if not isinstance(lowering, pxla.MeshComputation): - raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") - - if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: - raise NotImplementedError("serialization of host_callbacks is not yet implemented") - # Check that we do not see new compile_args. When we add a compile_args it is - # safe to add it to the allowed_compile_args if it does not change the semantics - # or the calling convention of the lowered module. - allowed_compile_args = [ - "backend", "mesh", "global_in_avals", - "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", - "spmd_lowering", "auto_spmd_lowering", - "tuple_args", "ordered_effects", "unordered_effects", - "keepalive", "host_callbacks", "pmap_nreps", "committed", - "device_assignment", "jaxpr_debug_info", "shape_poly_state"] - for compile_arg in lowering.compile_args.keys(): - if compile_arg not in allowed_compile_args: - raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") - - # We have not implemented support for some of the compile_args. Check here that - # the compile_args have the values that have been implemented. - not_implemented_msgs = [] - for compile_arg, check_value, err_msg in ( - ("spmd_lowering", lambda v: v, "True"), - ("auto_spmd_lowering", lambda v: not v, "False"), - # tuple_args is a compilation flag, does not affect lowering. - ("tuple_args", lambda v: True, "N/A"), - # unordered_effects do not change the calling convention. Those from - # jax.debug will also result in keepalive being non-empty and unsupported - # custom calls. The CallTfEffect is an exception, but we want to allow - # that one. - ("unordered_effects", lambda v: True, "N/A"), - # ordered_effects are allowed and we ensure that the calling convention is - # unmodified by passing dummy tokens in the main function wrapper. - ("ordered_effects", lambda v: True, "N/A"), - # used for TPU jax.debug, send/recv. Not supported yet. - ("host_callbacks", lambda v: not v, "empty"), - # used on all platforms for callbacks. Not supported yet. - ("keepalive", lambda v: not v, "empty"), - ("pmap_nreps", lambda v: v == 1, "1"), - ("shape_poly_state", lambda v: True, "N/A"), - ): - if compile_arg in lowering.compile_args: - if not check_value(lowering.compile_args[compile_arg]): - not_implemented_msgs.append( - f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}") - if not_implemented_msgs: - raise NotImplementedError( - "serialization error, unimplemented lowered.compile_args:\n" + - "\n".join(not_implemented_msgs)) - -# These are the JAX custom call target names that are guaranteed to be stable. -# Their backwards compatibility is tested by back_compat_test.py. -_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { - "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", - # cholesky on CPU - "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", - # eigh on CPU - "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", - # eigh on GPU - "cusolver_syevj", "cusolver_syevd", - # eigh on TPU - "Eigh", - # eig on CPU - "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # qr on CPU - "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", - # householder product on CPU - "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", - # svd on CPU - "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", - # qr on GPU - "cusolver_geqrf", "cublas_geqrf_batched", - "cusolver_geqrf", "cusolver_orgqr", - # qr and svd on TPU - "Qr", "ProductOfElementaryHouseholderReflectors", - # triangular_solve on CPU - "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", - # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU - # # lu on CPU - "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", - # schur on CPU - "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", - # # lu on GPU - # "cublas_getrf_batched", "cusolver_getrf", - # "hipblas_getrf_batched", "hipsolver_getrf", - # lu on TPU - "LuDecomposition", - # ApproxTopK on TPU - "ApproxTopK", - "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) - "tpu_custom_call", # Pallas/TPU kernels - # TODO(burmako): maintain backwards compatibility for these, until they - # are upstreamed to StableHLO. - # See https://github.com/openxla/stablehlo/issues/8. - "stablehlo.dynamic_reduce_window", - "stablehlo.dynamic_rng_bit_generator", - "stablehlo.dynamic_top_k", - "shape_assertion", # Used by shape_poly to evaluate assertions -} - - -def _check_module(mod: ir.Module, *, - allow_non_replicated_sharding: bool, - disabled_checks: Sequence[DisabledSafetyCheck]) -> None: - """Run a number of checks on the module. - - Args: - allow_non_replicated_sharding: whether the module is allowed to contain - non_replicated sharding annotations. - disabled_checks: the safety checks that are disabled. - """ - sharding_attr = ir.StringAttr.get("Sharding", mod.context) - shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context) - allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) - for dc in disabled_checks: - target = dc.is_custom_call() - if target is not None: - allowed_custom_call_targets.add(target) - - allowed_custom_call_targets_attrs = { - ir.StringAttr.get(target, mod.context) - for target in allowed_custom_call_targets} - disallowed_custom_call_ops: list[str] = [] - def check_sharding(op: ir.Operation, loc: ir.Location): - if not allow_non_replicated_sharding: - try: - sharding = op.attributes["mhlo.sharding"] - except KeyError: - pass - else: - if ir.StringAttr(sharding).value not in ["{replicated}", ""]: - raise ValueError( - "Lowered function does not have a top-level pjit but it has" - f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee" - " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning" - " for a discussion." - ) - - def check_op(op: ir.Operation): - op_name = op.operation.name - if op_name == "func.func": - check_sharding(op.operation, op.location) - - elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call": - call_target_name_attr = op.operation.attributes["call_target_name"] - if (call_target_name_attr not in allowed_custom_call_targets_attrs): - disallowed_custom_call_ops.append(f"{op} at {op.location}") - if call_target_name_attr == sharding_attr: - check_sharding(op, op.location) - elif call_target_name_attr == shape_assertion_attr: - assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks) - - def walk_operations(op): - check_op(op) - for region in op.operation.regions: - for block in region: - for op in block: - walk_operations(op) - - walk_operations(mod) - if disallowed_custom_call_ops: - disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) - msg = ("Cannot serialize code with custom calls whose targets have no " - "compatibility guarantees. Examples are:\n" - f"{disallowed_custom_call_ops_str}.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") - raise ValueError(msg) - - -def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: - # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp - - # Since jax.vjp does not handle kwargs, it is easier to do all the work - # here with flattened functions. - def fun_vjp_jax(*args_and_out_cts_flat_jax): - # Takes a flat list of primals and output cotangents - def flattened_primal_fun_jax(*args_flat): - args, kwargs = primal.in_tree.unflatten(args_flat) - res = primal_fun_jax(*args, **kwargs) - res_flat, res_tree = tree_util.tree_flatten(res) - assert res_tree == primal.out_tree - return res_flat - - args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, - [len(primal.in_avals)]) - _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) - return pullback_jax(out_cts_flat_jax) - - vjp_in_avals = list( - itertools.chain(primal.in_avals, - map(lambda a: a.at_least_vspace(), primal.out_avals))) - - # Expand in_shardings to all in_avals even not kept ones. - all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals) - for idx, in_s in zip(sorted(primal.module_kept_var_idx), - primal.in_shardings): # type: ignore - all_in_shardings[idx] = in_s # type: ignore - all_shardings = all_in_shardings + list(primal.out_shardings) # type: ignore - # Cannot mix unspecified and specified shardings. Make the unspecified - # ones replicated. - specified_shardings = [ - s for s in all_shardings if not sharding_impls.is_unspecified(s)] - - vjp_in_shardings: Any # The primal inputs followed by output cotangents - vjp_out_shardings: Any # The primal output cotangents - if 0 == len(specified_shardings): - vjp_in_shardings = sharding_impls.UNSPECIFIED - vjp_out_shardings = sharding_impls.UNSPECIFIED - else: - if len(specified_shardings) < len(all_shardings): - # There are some specified, but not all; pjit front-end does not liwk - in_s = specified_shardings[0] # pjit will enforce that all have same devices - assert isinstance(in_s, sharding.XLACompatibleSharding) - replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) - all_shardings = [ - s if not sharding_impls.is_unspecified(s) else replicated_s - for s in all_shardings] - - vjp_in_shardings = tuple(all_shardings) - vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)]) - if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings): - vjp_out_shardings = sharding_impls.UNSPECIFIED - - fun_vjp_jax = pjit.pjit(fun_vjp_jax, - in_shardings=vjp_in_shardings, - out_shardings=vjp_out_shardings) - - return export(fun_vjp_jax, - lowering_platform=primal.lowering_platform, - disabled_checks=primal.disabled_checks)(*vjp_in_avals) - -### Importing - -def call_exported(exported: Exported) -> Callable[..., jax.Array]: - - @jax.custom_vjp - def f_flat(*args_flat): - return call_exported_p.bind(*args_flat, exported=exported) - - def f_flat_vjp_fwd(*args_flat): - # Return the primal arguments as the residual - # TODO: keep as residuals only the arguments that are needed - return f_flat(*args_flat), args_flat - - def f_flat_vjp_bwd(residual, ct_res_flat): - args_flat = residual # residual is the primal argument flat tuple - exp_vjp = exported.vjp() - in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_flat) - return in_ct_flat - - f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd) - - def f_imported(*args, **kwargs): - # since custom_vjp does not support kwargs, flatten the function first. - args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) - if in_tree != exported.in_tree: - # Give errors with the precise tree difference; use fake leaves so we can - # use tree_util.equality_errors. - in_args = in_tree.unflatten([0] * in_tree.num_leaves) - exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves) - - msg = ( - "The invocation args and kwargs must have the same pytree structure " - f"as when the function '{exported.fun_name}' was exported, but they " - "have the following structural differences:\n" + - ("\n".join( - f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " - f"{thing2} when exported, so {explanation}.\n" - for path, thing1, thing2, explanation - in tree_util.equality_errors(in_args, exp_in_args)))) - raise ValueError(msg) - - res_flat = f_flat(*args_flat) - return exported.out_tree.unflatten(res_flat) - return f_imported - - -# A JAX primitive for invoking a serialized JAX function. -call_exported_p = core.Primitive("call_exported") -call_exported_p.multiple_results = True - -@util.cache() -def _call_exported_abstract_eval(*in_avals: core.AbstractValue, - exported: Exported) -> tuple[core.AbstractValue, ...]: - exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals) - assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure - # Check that the expected shapes match the actual ones - for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): - def pp_arg_dim(dim_idx: Optional[int]) -> str: - return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, - arg_idx, dim_idx) - if len(exp_aval.shape) != len(actual_aval.shape): - raise ValueError( - f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} " - f"and called with {actual_aval.shape}") - if exp_aval.dtype != actual_aval.dtype: - raise ValueError( - f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} " - f"and called with {actual_aval.dtype}") - for dim_idx, aval_d in enumerate(exp_aval.shape): - # If the exp_aval has a constant dimension then the actual argument must have - # a matching constant dimension. - if core.is_constant_dim(aval_d): - if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or - aval_d != actual_aval.shape[dim_idx]): - raise ValueError( - f"Shape mismatch for {pp_arg_dim(dim_idx)} " - "(expected same constant): " - f"expected {exp_aval.shape} and called with {actual_aval.shape}") - - # Must express the exported_dim_vars in terms of the shapes in in_avals. - solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( - exported.in_avals, args_kwargs_tree=exported.in_tree) - synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) - # We discharge all the constraints statically. This results in much simpler - # composability (because we do not have to worry about the constraints of the - # Exported called recursively; we only need to worry about entry-point - # constraints). This also makes sense from a composibility point of view, - # because we get the same errors if we invoke the exported module, or if we - # trace the exported function. Consider for example, an exported module with - # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument - # `f32[c, d]` it is better to fail because `c == d` is inconclusive, than - # succeed and add a compile-time check that `c == d`. In the latter case, - # it would be ambiguous whether we should continue tracing with a result - # a type `f32[c]` or `f32[d]`. - shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) - for var in exported_dim_vars] - return tuple( - core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, - *exported_dim_values), - dtype=out_aval.dtype, weak_type=out_aval.weak_type, - named_shape=out_aval.named_shape) - for out_aval in exported.out_avals) - - -call_exported_p.def_abstract_eval(_call_exported_abstract_eval) - -def _call_exported_impl(*args, exported: Exported): - return dispatch.apply_primitive(call_exported_p, *args, exported=exported) - -call_exported_p.def_impl(_call_exported_impl) - -def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, - platform: str, - exported: Exported): - # TODO: implement true multi-platform lowering for call_exported - if (platform not in exported.lowering_platforms and - DisabledSafetyCheck.platform() not in exported.disabled_checks): - raise ValueError( - f"The exported function '{exported.fun_name}' was lowered for " - f"platforms '{exported.lowering_platforms}' but it is used " - f"on '{platform}'.") - - if exported.uses_shape_polymorphism: - ctx.module_context.shape_poly_state.uses_dim_vars = True - - submodule = ir.Module.parse(exported.mlir_module()) - symtab = ir.SymbolTable(submodule.operation) - # The called function may have been exported with polymorphic shapes and called - # now with more refined shapes. We insert hlo.ConvertOp to ensure the module - # is valid. - def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value: - new_ir_type = mlir.aval_to_ir_type(new_aval) - if x.type != new_ir_type: - return mlir.convert_hlo(ctx, x, x_aval, new_aval) - else: - return x - - callee_type = symtab["main"].type - # TODO: maybe cache multiple calls - fn = mlir.merge_mlir_modules(ctx.module_context.module, - f"call_exported_{exported.fun_name}", - submodule) - kept_args = [ - convert_shape(a, a_aval, exported_in_aval) - for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals)) - if i in exported.module_kept_var_idx] - if len(exported.lowering_platforms) > 1: - # The exported module takes a platform index argument - # TODO: implement proper handling of the platform_index when we are - # in a multi-platform lowering context. - platform_index = exported.lowering_platforms.index(platform) - arg_width = callee_type.inputs[0].element_type.width - assert arg_width in [32, 64] - platform_index = np.int32(platform_index) if arg_width == 32 else np.int64(platform_index) # type: ignore - kept_args = [mlir.ir_constant(platform_index)] + kept_args - call = func_dialect.CallOp(callee_type.results, - ir.FlatSymbolRefAttr.get(fn), - kept_args) - # The ctx.avals_out already contain the abstract values refined by - # _call_exported_abstract_eval. - return tuple( - convert_shape(out, out_aval, refined_out_aval) - for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out)) - - -for _p in ("cpu", "tpu", "cuda", "rocm"): - mlir.register_lowering(call_exported_p, - functools.partial(_call_exported_lowering, platform=_p), - platform=_p) +# TODO(necula): Remove these stubs +from jax.experimental.export.export import ( + default_lowering_platform, +) diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index a66497b74..90f35a744 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -11,1583 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Shape polymorphism support. +"""This is just a backwards-compatibility stub. -We introduce a set of dimension variables at the top-level of a `jit` function. -They are introduced implicitly by way of specifying for each dimension of each -argument a symbolic dimension expression in terms of some dimension variables. -All dimension variables are assumed to range over integers greater or equal to 1. - -Symbolic dimensions overload some integer operations, such as -add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been -touched up to be sensitive to handling shapes that contain symbolic dimensions. -This enables many JAX programs to be traced with symbolic dimensions -in some dimensions. A priority has been to enable the batch -dimension in neural network examples to be polymorphic. - -This was built initially for jax2tf, but it is now customizeable to be -independent of TF. The best documentation at the moment is in the -jax2tf.convert docstring, and the -[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +The main functionality has been moved to jax.experimental.export.shape_poly """ -import collections -from collections.abc import Iterable, Sequence -import dataclasses -from enum import Enum -import functools -import itertools -import io -import math -import operator as op -import threading -import tokenize -from typing import Any, Optional, Union - -import numpy as np -import opt_einsum - -import jax -from jax import config -from jax.interpreters import xla - -from jax._src import core -from jax._src import dtypes -from jax._src import effects -from jax._src.lax import lax -from jax._src.lib import version as jaxlib_version -from jax._src.interpreters import mlir -from jax._src.numpy import lax_numpy -from jax._src import tree_util -from jax._src import util -from jax._src.typing import DimSize, Shape - - -TfVal = Any -DimVarEnv = dict[str, jax.Array] -DType = Any - -class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): - """Raised when we cannot conclusively compute with symbolic dimensions.""" - - _help_msg = """ -This error arises for comparison operations with shapes that -are non-constant, and the result of the operation cannot be represented as -a boolean value for all values of the symbolic dimensions involved. - -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables -for more details. -""" - - def __init__(self, message: str): - error_msg = f"{message}\n{InconclusiveDimensionOperation._help_msg}" - # https://github.com/python/mypy/issues/5887 - super().__init__(error_msg) # type: ignore - -class _ShapePolyThreadLocalState(threading.local): - - def __init__(self): - # TODO(necula): this does not play well with some lowering caches, because - # this state is not part of the cache key. - self.enable_shape_assertions = True - -thread_local_state = _ShapePolyThreadLocalState() - -class _DimAtom: - """Represents an atom in a symbolic dimension expression. - - Atoms are either variables, or expressions of the form floordiv(E1, E2) or - mod(E1, E2). Atoms are multiplied to form monomials (see _DimMon), and - monomials are added to form symbolic expressions (see _DimExpr). - - Args: - * var: if specified then the atom is a dimension variable. `operation` - must be `None`. - * operation: if specified then the atom is an operation applied to - `operands`. One of `FLOORDIR` or `MOD` or `NON_NEGATIVE`. `var` must be `None` - * operands: the operands to which the operation is applied. - """ - # The supported operations - FLOORDIV = "floordiv" - MOD = "mod" - NON_NEGATIVE = "non_negative" # The max of the operand and 0 - - def __init__(self, *operands: '_DimExpr', - var: Optional[str] = None, - operation: Optional[str] = None): - if var is not None: - assert operation is None - assert not operands - else: - assert operation is not None - self.var = var - self.operation = operation - self.operands = operands - - @classmethod - def from_var(cls, v: str) -> '_DimAtom': - return _DimAtom(var=v) - - def to_var(self) -> Optional[str]: - return self.var - - def get_vars(self) -> set[str]: - # All the vars that appear - if self.var is not None: - return {self.var} - else: - acc = set() - for opnd in self.operands: - acc.update(opnd.get_vars()) - return acc - - @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimAtom': - return _DimAtom(*operands, operation=operation) - - def __str__(self): - if self.var is not None: - return self.var - opnd_str = ", ".join([str(opnd) for opnd in self.operands]) - return f"{self.operation}({opnd_str})" - __repr__ = __str__ - - def __hash__(self): - return hash((self.var, self.operation, *self.operands)) - - def __eq__(self, other: Any): - # Used only for hashing - if not isinstance(other, _DimAtom): return False - if (self.var is None) != (other.var is None): return False - if self.var is not None: - return self.var == other.var - else: - def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool: - try: - return e1 == e2 - except InconclusiveDimensionOperation: - return False - return (self.operation == other.operation and - all(symbolic_equal(self_o, other_o) - for self_o, other_o in zip(self.operands, other.operands))) - - def __lt__(self, other: '_DimAtom'): - """ - Comparison to another atom in graded reverse lexicographic order. - Used only for determining a sorting order, does not relate to the - comparison of the values of the atom. - """ - if self.var is not None and other.var is not None: - return self.var < other.var - elif self.var is not None: - return True - elif other.var is not None: - return True - elif self.operation != other.operation: - return self.operation < other.operation # type: ignore - else: - return id(self) < id(other) - - def bounds(self) -> tuple[float, float]: - """Returns the lower and upper bounds, or -+ inf.""" - if self.var is not None: - return (1, np.inf) # variables are assumed to be >= 1 - opnd_bounds = [opnd.bounds() for opnd in self.operands] - if self.operation == _DimAtom.FLOORDIV: # a // b - (a_l, a_u), (b_l, b_u) = opnd_bounds - def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf - assert b != 0 - if not np.isinf(b): # divisor is finite - return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf - elif not np.isinf(a): # dividend is finite and divisor is infinite - return -1 if (a >= 0) != (b >= 0) else 0 - else: # both dividend and divisor are infinite - return -np.inf if (a >= 0) != (b >= 0) else np.inf - - # Same reasoning as for multiplication: the bounds are among the cross-product - # of the bounds. - bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u), - math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)] - return (min(*bound_candidates), max(*bound_candidates)) - - elif self.operation == _DimAtom.MOD: - _, (b_l, b_u) = opnd_bounds - if b_l > 0: # positive divisor - return (0, b_u - 1) - elif b_u < 0: # negative divisor - return (b_l + 1, 0) - else: - return (-np.inf, np.inf) - - elif self.operation == _DimAtom.NON_NEGATIVE: - (b_l, b_h), = opnd_bounds - return (max(0, b_l), max(0, b_h)) - - else: - assert False - - def evaluate(self, env: DimVarEnv): - if self.var is not None: - try: - return env[self.var] - except KeyError: - err_msg = ( - f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") - raise KeyError(err_msg) - else: - operand_values = [opnd.evaluate(env) for opnd in self.operands] - if self.operation == _DimAtom.FLOORDIV: - return divmod(*operand_values)[0] # type: ignore - elif self.operation == _DimAtom.MOD: - return divmod(*operand_values)[1] # type: ignore - elif self.operation == _DimAtom.NON_NEGATIVE: - return lax.max(operand_values[0], 0) - else: - assert False, self.operation - -class _DimMon(dict): - """Represents a multiplication of atoms. - - The representation is a dictionary mapping _DimAtom to exponent. - The exponents are integers >= 1. - """ - def __hash__(self): - return hash(frozenset(self.items())) - - def __str__(self): - return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key) - for key, exponent in sorted(self.items())) - - @classmethod - def from_var(cls, v: str) -> '_DimMon': - return _DimMon({_DimAtom.from_var(v): 1}) - - @classmethod - def from_atom(clscls, a: _DimAtom, aexp: int): - return _DimMon({a: aexp}) - - def to_var(self) -> Optional[str]: - """Extract the variable name "x", from a monomial "x". - Return None, if the monomial is not a single variable.""" - items = self.items() - if len(items) != 1: - return None - (a, aexp), = items - if aexp != 1: - return None - return a.to_var() - - def get_vars(self) -> set[str]: - # All the vars that appear in the monomial - acc = set() - for a in self.keys(): - acc.update(a.get_vars()) - return acc - - @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimMon': - return _DimMon({_DimAtom.from_operation(operation, *operands): 1}) - - @property - def degree(self): - return sum(self.values()) - - def __lt__(self, other: '_DimMon'): - """ - Comparison to another monomial in graded reverse lexicographic order. - Used only for determining a sorting order, does not relate to the - comparison of the values of the monomial. - """ - self_key = -self.degree, tuple(sorted(self)) - other_key = -other.degree, tuple(sorted(other)) - return self_key > other_key - - def mul(self, other: '_DimMon') -> '_DimMon': - """ - Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m. - """ - return _DimMon(collections.Counter(self) + collections.Counter(other)) - - def divide(self, divisor: '_DimMon') -> '_DimMon': - """ - Divides by another monomial. Raises a InconclusiveDimensionOperation - if the result is not a monomial. - For example, (n^3 * m) // n == n^2*m, but n // m fails. - """ - d = collections.Counter(self) - for key, exponent in divisor.items(): - diff = self.get(key, 0) - exponent - if diff < 0: - raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") - elif diff == 0: del d[key] - elif diff > 0: d[key] = diff - return _DimMon(d) - - def bounds(self) -> tuple[float, float]: - """Returns the lower and upper bounds, or -+inf.""" - # The bounds of a product are among the product of bounds. - bounds = [] - for a, exp in self.items(): - a_l, a_u = a.bounds() - assert a_l <= a_u - bounds.append((a_l ** exp, a_u ** exp)) - - candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)] - return (min(*candidates), max(*candidates)) # type: ignore - - - def evaluate(self, env: DimVarEnv): - prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) - def pow_opt(v, p: int): - return v if p == 1 else prod([v] * p) - return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()]) - - -class _DimExpr(): - """Symbolic expression in terms of dimension variables. - - A dimension expression is an addition of products (_DimMon) - of atoms (_DimAtom). - - We overload integer operations, but we do that soundly, raising - :class:`InconclusiveDimensionOperation` when the result is not - representable as a _DimExpr. - - The representation of a _DimExpr is as a dictionary mapping _DimMon to - integer coefficients. The special monomial `_DimMon()` is mapped to the - free integer coefficient of the expression. - """ - - __array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray - def __init__(self, coeffs: dict[_DimMon, int]): - # Do not construct _DimExpr directly, unless you are sure that coeffs is - # normalized; Use _DimExpr.normalize. - # Takes ownership of coeffs - self._coeffs = coeffs or {_DimMon(): 0} - - def monomials(self) -> Iterable[tuple[_DimMon, int]]: - return self._coeffs.items() - - @classmethod - def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int): - """Do `coeffs[mon] += coeff` but remove 0 coefficients.""" - old_c = coeffs.get(mon) - if old_c is None: - if coeff != 0: coeffs[mon] = coeff - else: - new_c = old_c + coeff - if new_c == 0: - del coeffs[mon] - else: - coeffs[mon] = new_c - - @classmethod - def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize: - """The main constructor for _DimExpr. - - Ensures that the symbolic dimension is normalized, e.g., - it is represented as a Python int if it is known to be a constant. - """ - # TODO(necula): profile and optimize this - has_non_zero_degree = False - free_const = 0 - new_coeffs: dict[_DimMon, int] = {} - for mon, coeff in coeffs.items(): - if coeff == 0: continue - if mon.degree == 0: # A constant, there can be a single one - free_const = coeff - else: - has_non_zero_degree = True - - new_coeffs[mon] = new_coeffs.get(mon, 0) + coeff - - if has_non_zero_degree: - return _DimExpr(new_coeffs) - else: - return int(free_const) - - @classmethod - def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize: - # Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes - # up when handling strided convolution. - for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV): - # e = factor * floordiv(operands)^exp * rest_monomial + rest_expr - if dec.exp != 1: - continue - if dec.rest_monomial == 1 and dec.factor == 1: - continue - m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1]) - if m_remainder == 0: - return m_trimmed * (dec.operands[0] - _DimExpr.from_operation(_DimAtom.MOD, *dec.operands)) + dec.rest_expr - return _DimExpr.normalize(coeffs) - - @classmethod - def from_monomial(cls, mon: _DimMon, exp: int): - return _DimExpr.normalize({mon: exp}) - - @classmethod - def from_var(cls, v: str) -> '_DimExpr': - return _DimExpr({_DimMon.from_var(v): 1}) - - @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimExpr': - return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1) - - def to_var(self) -> Optional[str]: - """Extract the variable name "x", from a symbolic expression.""" - items = self.monomials() - if len(items) != 1: # type: ignore - return None - (mon, mon_count), = items - if mon_count != 1: - return None - return mon.to_var() - - def get_vars(self) -> set[str]: - """The variables that appear in a symbolic dimension.""" - acc = set() - for mon, _ in self.monomials(): - acc.update(mon.get_vars()) - return acc - - def eq(self, other: DimSize) -> bool: - lb, ub = _ensure_poly(self - other, "eq").bounds() - if lb == ub == 0: - return True - if lb > 0 or ub < 0: - return False - # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported - return False - - def inconclusive_comparison(self, operation: str, op: Any) -> Exception: - return InconclusiveDimensionOperation( - f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.") - - def ge(self, other: DimSize) -> bool: - lb, ub = _ensure_poly(self - other, "ge").bounds() - if lb >= 0: - return True - if ub < 0: - return False - raise self.inconclusive_comparison(">=", other) - - def __hash__(self): - return hash(tuple(sorted(self.monomials()))) - - def __str__(self): - def _one_monomial(mon, c): - if mon.degree == 0: - return str(c) - if c == 1: - return str(mon) - return f"{c}*{mon}" - return " + ".join(_one_monomial(mon, c) - for mon, c in sorted(self.monomials(), reverse=True)) - - def __repr__(self): - return str(self) - - # We overload +, -, *, because they are fully defined for _DimExpr. - def __add__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__add__(other) - - other = _ensure_poly(other, "add") - coeffs = self._coeffs.copy() - for mon, coeff in other.monomials(): - _DimExpr._add_coeffs(coeffs, mon, coeff) - return _DimExpr.normalize_floordiv_times_divisor(coeffs) - - def __radd__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__radd__(other) - return _ensure_poly(other, "add").__add__(self) - - def __sub__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__sub__(other) - return self + -_ensure_poly(other, "sub") - - def __rsub__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__rsub__(other) - return _ensure_poly(other, "sub").__sub__(self) - - def __neg__(self) -> '_DimExpr': - return _DimExpr({mon: -coeff for mon, coeff in self.monomials()}) - - def __mul__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__mul__(other) - other = _ensure_poly(other, "mul") - coeffs: dict[_DimMon, int] = {} - for mon1, coeff1 in self.monomials(): - for mon2, coeff2 in other.monomials(): - mon = mon1.mul(mon2) - _DimExpr._add_coeffs(coeffs, mon, coeff1 * coeff2) - return _DimExpr.normalize_floordiv_times_divisor(coeffs) - - def __rmul__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__rmul__(other) - return _ensure_poly(other, "mul").__mul__(self) - - def __pow__(self, power, modulo=None): - assert modulo is None - try: - power = int(power) - except: - raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") - return functools.reduce(op.mul, [self] * power) - - def __floordiv__(self, divisor): - if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): - return self.__jax_array__().__floordiv__(divisor) - return self.divmod(_ensure_poly(divisor, "floordiv"))[0] - - def __rfloordiv__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__rfloordiv__(other) - return _ensure_poly(other, "floordiv").__floordiv__(self) - - def __truediv__(self, divisor): - # Used for "/", which always returns a float - return self.__jax_array__().__truediv__(divisor) - - def __rtruediv__(self, dividend): - # Used for "/", when dividend is not a _DimExpr - return self.__jax_array__().__rtruediv__(dividend) - - def __mod__(self, divisor): - if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): - return self.__jax_array__().__mod__(divisor) - return self.divmod(_ensure_poly(divisor, "mod"))[1] - - def __rmod__(self, dividend): - if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): - return self.__jax_array__().__rmod__(dividend) - return _ensure_poly(dividend, "mod").__mod__(self) - - def __divmod__(self, divisor): - if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): - return self.__jax_array__().__divmod__(divisor) - return self.divmod(_ensure_poly(divisor, "divmod")) - - def __rdivmod__(self, dividend): - if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): - return self.__jax_array__().__rdivmod__(dividend) - return _ensure_poly(dividend, "divmod").__divmod__(self) - - def __int__(self): - if self.is_constant: - return op.index(next(iter(self._coeffs.values()))) - else: - raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant") - - # We must overload __eq__ and __ne__, or else we get unsound defaults. - __eq__ = eq - def __ne__(self, other: DimSize) -> bool: - return not self.eq(other) - - __ge__ = ge - - def __le__(self, other: DimSize): - try: - return _ensure_poly(other, "le").__ge__(self) - except InconclusiveDimensionOperation as e: - raise self.inconclusive_comparison("<=", other) from e - - def __gt__(self, other: DimSize): - try: - return not _ensure_poly(other, "gt").__ge__(self) - except InconclusiveDimensionOperation as e: - raise self.inconclusive_comparison(">", other) from e - - def __lt__(self, other: DimSize): - try: - return not self.__ge__(other) - except InconclusiveDimensionOperation as e: - raise self.inconclusive_comparison("<", other) from e - - def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]: - """ - Floor division with remainder (divmod) generalized to polynomials. - If the `divisor` is not a constant, the remainder must be 0. - If the `divisor` is a constant, the remainder may be non 0, for consistency - with integer divmod. - - :return: Quotient resulting from polynomial division and integer remainder. - """ - assert isinstance(divisor, _DimExpr) - try: - dmon, dcount = divisor.leading_term - dividend, quotient = self, 0 - # invariant: self = dividend + divisor * quotient - # quotient and dividend are changed in the loop; the leading term of - # dividend decreases at each iteration. - while is_poly_dim(dividend) and not dividend.is_constant: - mon, count = dividend.leading_term - try: - qmon = mon.divide(dmon) - except InconclusiveDimensionOperation: - raise InconclusiveDimensionOperation("") - qcount, rcount = divmod(count, dcount) - if rcount != 0: - raise InconclusiveDimensionOperation("") - - q = _DimExpr.from_monomial(qmon, qcount) - quotient += q - dividend -= q * divisor # type: ignore[assignment] - - dividend = int(dividend) # type: ignore[assignment] - if divisor.is_constant: - q, r = divmod(dividend, int(divisor)) # type: ignore - quotient += q - remainder = r - else: - if dividend != 0: - raise InconclusiveDimensionOperation("") - remainder = 0 - - if config.jax_enable_checks: - assert self == divisor * quotient + remainder - return quotient, remainder - except InconclusiveDimensionOperation: - return (_DimExpr.from_operation(_DimAtom.FLOORDIV, self, divisor), # type: ignore - _DimExpr.from_operation(_DimAtom.MOD, self, divisor)) - - def bounds(self) -> tuple[float, float]: - """Returns the lower and upper bounds, or -+inf.""" - lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient - for mon, coeff in self.monomials(): - if mon.degree == 0: continue # We already included the free coefficient - m_l, m_u = mon.bounds() - assert m_l <= m_u and coeff != 0 - item_l, item_u = coeff * m_l, coeff * m_u - lb = lb + min(item_l, item_u) # type: ignore - ub = ub + max(item_l, item_u) # type: ignore - - if lb != -np.inf or ub != np.inf: - return lb, ub - # Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0 - # TODO(necula): add more principled support for floordiv and mod - # For example, this will miss "1 + a - mod(b, a)" - for dec in _decompose_expr(self, _DimAtom.MOD): - # E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr - if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]: - try: - if dec.operands[1] <= 0: - continue - except InconclusiveDimensionOperation: - continue - if dec.factor > 0: - return (-np.inf, -1) - else: - return (1, np.inf) - - return lb, ub - - @property - def is_constant(self): - return len(self._coeffs) == 1 and next(iter(self._coeffs)).degree == 0 - - @property - def leading_term(self) -> tuple[_DimMon, int]: - """Returns the highest degree term that comes first lexicographically.""" - return max(self.monomials()) - - def evaluate(self, env: DimVarEnv): - # Evaluates as a value of dtype=core.dim_value_dtype() - terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) - for mon, coeff in self.monomials()] - return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] - - def non_negative(self) -> "_DimExpr": - return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self) - - @staticmethod - def get_aval(dim: "_DimExpr"): - return core.dim_value_aval() - - def dimension_as_value(self): - """Turns a dimension size into a Jax value that we can compute with.""" - return _dim_as_value(self) - - def __jax_array__(self): - # Used for implicit coercions of polynomials as JAX arrays - return _dim_as_value(self) - -@dataclasses.dataclass -class _Decomposition: - """Decomposition of an expression around an operation atom. - - E = factor * mod(*operands)^exp * rest_monomial + rest_expr - """ - factor: int - operands: Sequence[_DimExpr] - exp: int - rest_monomial: _DimExpr - rest_expr: _DimExpr - - -def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]: - for m, m_factor in e.monomials(): - atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation] - if atoms: - e_minus_m_coeffs = e._coeffs.copy() - del e_minus_m_coeffs[m] - for a, aexp in atoms: - yield _Decomposition( - factor=m_factor, - operands=a.operands, - exp=aexp, - rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}), - rest_expr=_DimExpr(e_minus_m_coeffs)) - -core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval -xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval -dtypes._weak_types.append(_DimExpr) - -def _convertible_to_int(p: DimSize) -> bool: - try: - op.index(p) - return True - except: - return False - -def _ensure_poly(p: DimSize, - operation_name: str) -> _DimExpr: - if isinstance(p, _DimExpr): return p - if _convertible_to_int(p): - return _DimExpr({_DimMon(): op.index(p)}) - raise TypeError(f"Symnbolic dimension {operation_name} not supported for {p}.") - -def _convertible_to_poly(p: DimSize) -> bool: - return isinstance(p, _DimExpr) or _convertible_to_int(p) - -def is_poly_dim(p: DimSize) -> bool: - return isinstance(p, _DimExpr) - -dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] - -def _einsum_contract_path(*operands, **kwargs): - """Like opt_einsum.contract_path, with support for DimExpr shapes. - - We use opt_einsum.contract_path to compute the schedule, using a fixed - constant for all dimension variables. This is safe because we throw an - error if there are more than 1 contractions. Essentially, we just use - opt_einsum.contract_path to parse the specification. - """ - - # Replace the polymorphic shapes with some concrete shapes for calling - # into opt_einsum.contract_path, because the latter wants to compute the - # sizes of operands and intermediate results. - fake_ops = [] - for operand in operands: - # We replace only array operands - if not hasattr(operand, "dtype"): - fake_ops.append(operand) - else: - shape = np.shape(operand) - def fake_dim(d): - if core.is_constant_dim(d): - return d - else: - if not isinstance(d, _DimExpr): - raise TypeError(f"Encountered unexpected shape dimension {d}") - # It is Ok to replace all polynomials with the same value. We may miss - # here some errors due to non-equal dimensions, but we catch them - # later. - return 8 - fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), - operand.dtype)) - - contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, - **kwargs) - contract_operands = [] - for operand in contract_fake_ops: - idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op) - assert len(idx) == 1 - contract_operands.append(operands[idx[0]]) - return contract_operands, contractions - -lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path - -# To implement shape-constraint checking we use a shape assertion primitive. -# shape_assertion_p.bind(assert_what: bool, *error_message_inputs, -# error_message="...{0}...{1}") -# where "{0}" refers to error_message_inputs[0], etc. -shape_assertion_p = core.Primitive("shape_assertion") -shape_assertion_p.multiple_results = True -shape_assertion_p.def_effectful_abstract_eval( - lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore - -def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext, - assert_what: mlir.ir.Value, - *error_message_inputs: mlir.ir.Value, - error_message: str): - op = mlir.custom_call( - "shape_assertion", - result_types=[], # No results - operands=[assert_what, *error_message_inputs], - has_side_effect=True, - extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message)) - ) - return op.results - -mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule) - -class ShapeAssertionEffect(effects.Effect): - __str__ = lambda _: "ShapeAssertionEffect" - -shape_assertion_effect = ShapeAssertionEffect() - -effects.lowerable_effects.add_type(ShapeAssertionEffect) -effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect) -effects.remat_allowed_effects.add_type(ShapeAssertionEffect) -effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) - -def shape_assertion(assert_what: jax.Array, - *error_message_inputs: jax.Array, - error_message: str) -> None: - """Adds a shape assertion in the code. - - Args: - assert_what: a boolean asserted to be true. Must be computed based only - on dimension expressions, so that it can be evaluated after shape - refinement. - error_message_inputs: integers expressions whose values can be referenced - in the `error_message`. Must be computed based only - on dimension expressions, so that they can be evaluated after shape - refinement. - error_message: an error message, possibly containing format specifiers - {0}, {1}, ..., referencing the values of the `error_message_inputs`. - The format specifiers are sometimes processed with Python's - `string::format` method, and sometimes with `llvm::formatv`. - """ - if thread_local_state.enable_shape_assertions: - shape_assertion_p.bind(assert_what, *error_message_inputs, - error_message=error_message) - -# A JAX primitive with no array arguments but with a dimension parameter -# that is a DimExpr. The value of the primitive is the value of the dimension, -# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype()) -dim_as_value_p = core.Primitive("dim_as_value") -dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval()) - -def dim_as_value_impl(dim: DimSize): - raise NotImplementedError( - "Evaluation rule for 'dim_as_value' is not implemented. " - "It seems that you are using shape polymorphism outside jax2tf.") - -dim_as_value_p.def_impl(dim_as_value_impl) -def _dim_as_value(dim: DimSize): - return dim_as_value_p.bind(dim=dim) - -def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, - dim): - res, = mlir.eval_dynamic_shape(ctx, (dim,)) - out_type = mlir.aval_to_ir_type(ctx.avals_out[0]) - if out_type != res.type: # type: ignore - return mlir.hlo.ConvertOp(out_type, res).results - else: - return [res] - -mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering) - - -class PolyShape(tuple): - """Tuple of polymorphic dimension specifications. - - See docstring of :func:`jax2tf.convert`. - """ - - def __init__(self, *dim_specs): - tuple.__init__(dim_specs) - - def __new__(cls, *dim_specs): - for ds in dim_specs: - if not isinstance(ds, (int, str)) and ds != ...: - msg = (f"Invalid polymorphic shape element: {repr(ds)}; must be a string " - "representing a dimension variable, or an integer, or ...") - raise ValueError(msg) - return tuple.__new__(PolyShape, dim_specs) - - def __str__(self): - return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" - - -def _parse_spec(shape_spec: Union[str, PolyShape, None], - arg_shape: Sequence[Optional[int]]) -> Sequence[DimSize]: - """Parses the shape polymorphic specification for one array argument. - - We have to be able to parse all strings produced by str(_DimExpr) because - sometimes the output polymorphic shapes of one function become the input - polymorphic shapes of another. - - Args: - shape_spec: a shape polymorphic specification. None stands for "...". - arg_shape: an actual shape, possibly containing unknown dimensions (None). - We use `arg_shape` to fill-in the placeholders `_` and `...` in - the `shape_spec`. The dimensions of `arg_shape` that are used for filling - must be known (not `None`). If a dimension in `arg_shape` is known and - the corresponding dimension in `shape_spec` is a constant then they - must be equal. - - See the README.md for usage. - """ - shape_spec_repr = repr(shape_spec) - if shape_spec is None: - shape_spec = "..." - elif isinstance(shape_spec, PolyShape): - shape_spec = str(shape_spec) - elif not isinstance(shape_spec, str): - raise ValueError("polymorphic shape spec should be None or a string. " - f"Found {shape_spec_repr}.") - return _Parser(shape_spec, arg_shape, shape_spec_repr).parse() - -class _Parser: - def __init__(self, - shape_spec: str, - arg_shape: Sequence[Optional[int]], - shape_spec_repr: str): - self.shape_spec = shape_spec - self.shape_spec_repr = shape_spec_repr # For error messages - self.arg_shape = arg_shape - self.dimensions: list[DimSize] = [] # dimensions we have parsed - - def parse(self) -> Sequence[DimSize]: - self.tokstream = tokenize.tokenize( - io.BytesIO(self.shape_spec.encode("utf-8")).readline) - tok = self.consume_token(self.next_tok(), tokenize.ENCODING) # Always 1st - sh, tok = self.shape(tok) - self.expect_token(tok, [tokenize.ENDMARKER]) - return sh - - def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo): - if expr is None: - raise self.parse_err(tok, - ("unexpected placeholder for unknown dimension " - f"for argument shape {self.arg_shape}")) - arg_shape_dim = self.arg_shape[len(self.dimensions)] - if core.is_constant_dim(expr) and arg_shape_dim is not None: - if expr != arg_shape_dim: - raise self.parse_err(tok, - (f"different size {expr} for known dimension " - f"for argument shape {self.arg_shape}")) - self.dimensions.append(expr) - - def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception: - msg = ( - f"syntax error in polymorphic shape {self.shape_spec_repr} " - f"in dimension {len(self.dimensions)}: {detail}. ") - if tok is not None: - msg += f"Parsed '{tok.line[:tok.start[1]]}', remaining '{tok.line[tok.start[1]:]}'." - return ValueError(msg) - - def next_tok(self) -> tokenize.TokenInfo: - while True: - try: - t = next(self.tokstream) - except StopIteration: - raise self.parse_err(None, "unexpected end of string") - if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]: - return t - - def expect_token(self, tok: tokenize.TokenInfo, expected: Sequence[int]) -> None: - if tok.exact_type not in expected: - msg = ("expecting one of {" + - ", ".join(tokenize.tok_name[t] for t in expected) + "} but found " + - tokenize.tok_name[tok.exact_type]) - raise self.parse_err(tok, msg) - - def consume_token(self, tok: tokenize.TokenInfo, expected: int) -> tokenize.TokenInfo: - self.expect_token(tok, [expected]) - return self.next_tok() - - def integer(self, tok: tokenize.TokenInfo) -> tuple[int, tokenize.TokenInfo]: - self.expect_token(tok, [tokenize.NUMBER]) - try: - val = int(tok.string) - except Exception: - raise self.parse_err(tok, f"expecting integer, found {tok.string}") - return val, self.next_tok() - - # What can follow a shape? - FOLLOW_SHAPE = [tokenize.ENDMARKER, tokenize.RPAR] - def shape(self, tok: tokenize.TokenInfo) -> tuple[Sequence[DimSize], tokenize.TokenInfo]: - # A comma-separated list of _DimExpr, or "_", possibly ended with ... - if tok.exact_type == tokenize.LPAR: - res, tok = self.shape(self.next_tok()) - tok = self.consume_token(tok, tokenize.RPAR) - return res, tok - - while True: - if tok.exact_type in self.FOLLOW_SHAPE: - break - if tok.exact_type == tokenize.ELLIPSIS: - to_add = self.arg_shape[len(self.dimensions):] - for ad in to_add: - self.add_dim(ad, tok) - tok = self.next_tok() - break - if len(self.dimensions) >= len(self.arg_shape): - raise self.parse_err(tok, - f"too many dimensions, arg_shape has {len(self.arg_shape)}") - if tok.exact_type == tokenize.NAME and tok.string == "_": - e = self.arg_shape[len(self.dimensions)] - tok = self.next_tok() - else: - e, tok = self.expr(tok) - self.add_dim(e, tok) - if tok.exact_type in self.FOLLOW_SHAPE: - break - tok = self.consume_token(tok, tokenize.COMMA) - - return tuple(self.dimensions), tok - - # What token can follow a _DimExpr - FOLLOW_EXPR = FOLLOW_SHAPE + [tokenize.COMMA] - - def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: - # A sum of monomials - next_m_negated = False - acc = 0 - while True: - m, tok = self.mon(tok) - acc = acc + (- m if next_m_negated else m) - if tok.exact_type in self.FOLLOW_EXPR: - return acc, tok - next_m_negated = (tok.exact_type == tokenize.MINUS) - self.expect_token(tok, [tokenize.PLUS, tokenize.MINUS]) - tok = self.next_tok() - - FOLLOW_MON = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS] - def mon(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: - # A monomial is product of atoms. Each atom may be raised to an integer power. - acc = 1 - while True: - a, tok = self.atom(tok) - if tok.exact_type == tokenize.CIRCUMFLEX: - tok = self.next_tok() - self.expect_token(tok, [tokenize.NUMBER]) - power, tok = self.integer(tok) - a = a ** power - - acc = acc * a - if tok.exact_type in self.FOLLOW_MON: - return acc, tok - tok = self.consume_token(tok, tokenize.STAR) - - def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: - if tok.exact_type == tokenize.NAME: - if tok.string == _DimAtom.MOD: - return self.binary_op(_DimAtom.MOD, self.next_tok()) - if tok.string == _DimAtom.FLOORDIV: - return self.binary_op(_DimAtom.FLOORDIV, self.next_tok()) - if tok.string == _DimAtom.NON_NEGATIVE: - return self.unary_op(_DimAtom.NON_NEGATIVE, self.next_tok()) - return _DimExpr.from_var(tok.string), self.next_tok() - number_sign = 1 - if tok.exact_type == tokenize.MINUS: # -k are negative constants - number_sign = -1 - tok = self.next_tok() - self.expect_token(tok, [tokenize.NUMBER]) - if tok.exact_type == tokenize.NUMBER: - v, tok = self.integer(tok) - return v * number_sign, tok - self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER]) - assert False - - def unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: - tok = self.consume_token(tok, tokenize.LPAR) - e1, tok = self.expr(tok) - tok = self.consume_token(tok, tokenize.RPAR) - return _DimExpr.from_operation(op, e1), tok # type: ignore - - def binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: - tok = self.consume_token(tok, tokenize.LPAR) - e1, tok = self.expr(tok) - tok = self.consume_token(tok, tokenize.COMMA) - e2, tok = self.expr(tok) - tok = self.consume_token(tok, tokenize.RPAR) - return _DimExpr.from_operation(op, e1, e2), tok # type: ignore - - -def _evaluate_add(v1, v2): - try: - if op.index(v1) == 0: - return v2 - except: - pass - try: - if op.index(v2) == 0: - return v1 - except: - pass - return v1 + v2 - -def _evaluate_multiply(v1, v2): - try: - if op.index(v1) == 1: - return v2 - except: - pass - try: - if op.index(v2) == 1: - return v1 - except: - pass - return v1 * v2 - -# dimension_size(operand, dimension=i) get the operand.shape[i] as a -# value of type shape_poly.dim_as_value_dtype(). -dimension_size_p = core.Primitive("dimension_size") -def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue: - return core.dim_value_aval() - -dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval) - -def _dimension_size_impl(arg, *, dimension): - return core.dim_constant(arg.shape[dimension]) -dimension_size_p.def_impl(_dimension_size_impl) - -def _dimension_size_lowering_rule(ctx, arg, *, dimension): - dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension) - dim_type = mlir.aval_to_ir_type(core.dim_value_aval()) - if dim_size.result.type != dim_type: - dim_size = mlir.hlo.ConvertOp(dim_type, dim_size) - return dim_size.results - -mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) - - -def arg_aval( - arg_shape: Sequence[Optional[int]], - arg_jax_dtype: DType, - polymorphic_shape: Optional[Union[str, PolyShape]]) -> core.ShapedArray: - """Computes abstract values. - - Args: - arg_shape: the shape for the argument, possibly having None dimensions. - arg_dtype: the inferred JAX dtype for the arg. - polymorphic_shape: the polymorphic specification for the argument. - Returns: the JAX abstract value for the argument. - """ - aval_shape = _parse_spec(polymorphic_shape, arg_shape) - return core.ShapedArray(aval_shape, arg_jax_dtype) - -def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]: - dim_vars: set[str] = set() - for a in args_avals: - for d in a.shape: - if is_poly_dim(d): - dim_vars = dim_vars.union(d.get_vars()) - return sorted(tuple(dim_vars)) - - -class CachingShapeEvaluator: - def __init__(self, **env): - self.env = env - - @functools.lru_cache(128) - def evaluate(self, e: DimSize): - if core.is_constant_dim(e): - res = op.index(e) - else: - res = e.evaluate(self.env) # type: ignore - return res - - -@dataclasses.dataclass(frozen=True) -class ShapeConstraint: - class Comparator(Enum): - EQ = 1 - GEQ = 2 - - comp: Comparator - left: DimSize - right: DimSize - # `error_message_pieces` is a list of strings and DimSize. The error message - # is formed by evaluating the DimSize and concatenating the sequence. - error_message_pieces: Sequence[Union[str, DimSize]] - - def check_statically(self, eval: CachingShapeEvaluator) -> None: - """Evaluates a constraint statically.""" - left, right = eval.evaluate(self.left), eval.evaluate(self.right) - try: - if self.comp == ShapeConstraint.Comparator.EQ: - ok = (left == right) - elif self.comp == ShapeConstraint.Comparator.GEQ: - ok = (left >= right) - else: - assert False # We are in a context where we know we can evaluate - # all symbolic expressions to constants. - except InconclusiveDimensionOperation as e: - raise self.make_error(eval) from e - if not ok: - raise self.make_error(eval) - - def compute(self, eval: CachingShapeEvaluator) -> Optional[jax.Array]: - """Computes if the constraint is satisfied. - - If the constraint can be resolved statically returns None - or raises ValueError otherwise. If the constraint cannot be - resolved statically, returns a value representing if the - constraint is satisfied. - """ - left, right = eval.evaluate(self.left), eval.evaluate(self.right) - # Try to evaluate the constraint statically. - if core.is_constant_shape((left, right)): - left_int, right_int = op.index(left), op.index(right) - if self.comp == ShapeConstraint.Comparator.EQ: - if not (left_int == right_int): - raise self.make_error(eval) - elif self.comp == ShapeConstraint.Comparator.GEQ: - if not (left_int >= right_int): - raise self.make_error(eval) - else: assert False - return None - - if self.comp == ShapeConstraint.Comparator.EQ: - is_ok = lax.eq(left, right) - elif self.comp == ShapeConstraint.Comparator.GEQ: - is_ok = lax.ge(left, right) - else: assert False - return is_ok - - def __str__(self): - return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}" - f" ({self.error_message_pieces})") - __repr__ = __str__ - - def error_message_and_inputs( - self, - eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]: - """Forms the error_message and error message_inputs. - See shape_assertion. - """ - # There is currenly a limitation in the shape assertion checker that - # it supports at most 32 error_message_inputs. We try to stay within the - # limit, reusing a format specifier if possible. - if jaxlib_version <= (0, 4, 14): - max_error_message_inputs = 4 - else: - max_error_message_inputs = 32 - format_specifiers: dict[DimSize, str] = {} - error_message_inputs: list[Any] = [] - error_message_strings: list[str] = [] - for e in self.error_message_pieces: - if isinstance(e, str): - error_message_strings.append(e) - continue - cached_spec = format_specifiers.get(e) - if cached_spec is not None: - error_message_strings.append(cached_spec) - continue - if len(error_message_inputs) >= max_error_message_inputs: - error_message_strings.append("N/A") - continue - spec = "{" + str(len(error_message_inputs)) + "}" - format_specifiers[e] = spec - error_message_strings.append(spec) - error_message_inputs.append(eval.evaluate(e)) - return ("".join(error_message_strings), - error_message_inputs) - - def make_error(self, eval: CachingShapeEvaluator) -> Exception: - error_message, error_message_inputs = self.error_message_and_inputs(eval) - return ValueError(error_message.format(*error_message_inputs)) - - -class ShapeConstraints: - def __init__(self): - self.constraints: list[ShapeConstraint] = [] - - def add_constraint(self, - comp: ShapeConstraint.Comparator, - left: DimSize, right: DimSize, - error_message_pieces: Sequence[Union[str, DimSize]]): - c = ShapeConstraint(comp, left, right, error_message_pieces) - self.constraints.append(c) - - def check_statically(self, eval: CachingShapeEvaluator) -> None: - """Evaluates all the constraints statically. - - If the static checking of any constraint fails, raises ValueError. - """ - for constraint in self.constraints: - constraint.check_statically(eval) - - def shape_assertions(self, eval: CachingShapeEvaluator) -> None: - """Computes the shape assertions for the set of constraints. - - See jax_export._wrap_main_func docstring. - """ - # We want to report the errors in the same order as `check_statically`. - # So, we process them in order, in case some fail statically, and we - # generate the shape assertions in the same order. - for constraint in self.constraints: - is_ok = constraint.compute(eval) - if is_ok is None: continue # Was resolved statically - error_message, error_message_inputs = constraint.error_message_and_inputs(eval) - shape_assertion( - is_ok, *error_message_inputs, - error_message=error_message) - -@dataclasses.dataclass -class _DimEquation: - # Encodes that `aval_dim_expr`, which is a symbolic expressions containing - # unknown dimension variables from the abstract values, is the specification - # for dimension named `dim_name` (e.g., "args[0].field.shape[2]"). - aval_dim_expr: _DimExpr - dim_name: str - - def __str__(self): - return f"Dimension size of {self.dim_name} with specification '{self.aval_dim_expr}'" - __repr__ = __str__ - - -def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str: - # String description of `args` or `kwargs`, assuming the path for a tree for - # the tuple `(args, kwargs)`. - if path[0] == tree_util.SequenceKey(0): - return f"args{tree_util.keystr(path[1:])}" - elif path[0] == tree_util.SequenceKey(1): - return f"kwargs{tree_util.keystr(path[1:])}" - else: - assert False - -@functools.lru_cache(128) -def _cached_pretty_print_dimension_descriptor( - args_kwargs_tree: tree_util.PyTreeDef, - flat_arg_idx: int) -> str: - args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path( - args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves)) - arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0]) - return arg_str - -def pretty_print_dimension_descriptor( - args_kwargs_tree: tree_util.PyTreeDef, - flat_arg_idx: int, dim_idx: Optional[int]) -> str: - arg_str = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx) - if dim_idx is not None: - arg_str += f".shape[{dim_idx}]" - return arg_str - -@util.cache() -def solve_dim_vars( - args_avals: Sequence[core.AbstractValue], - args_kwargs_tree: tree_util.PyTreeDef, - ) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]: - """Solves dimension variables in a called function's avals in terms of actual argument shapes. - - For example, given: - - args_avals = [ShapedArray((3, a, a + b), f32)] - - we introduce fresh "synthetic" dimension variables to represent the actual - dimension size of actual arguments for each non-constant dimension. - Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.: - - synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)] - - and then we express the solution for the unknown dimension variables {a, b} - as symbolic expressions in terms of the synthetic variables: - - dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1]) - - Not all equations are solvable. For now, we solve first the linear - uni-variate equations, then the solved variables are used to simplify the - remaining equations to linear uni-variate equations, and the process - continues until all dimension variables are solved. - - Args: - args_avals: the abstract values of the `args`, with shapes that may - include unknown dimension variables. - args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` - from which the flat sequence `args_avals` is extracted. Used for - describing args and kwargs in synthetic variable names and in - error messages. - - Returns: a 3-tuple with: (a) the solution for the unknown dimension variables - (b) a list of constraints that must be satisfied for the solution to be a - valid one, and (c) and the list of synthetic variables that may appear in - the solution and the constraints. - - Raises ValueError if it cannot solve some dimension variable. - """ - dim_equations: list[_DimEquation] = [] - synth_dimension_vars: list[tuple[str, int, int]] = [] - # tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b')) - polymorphic_shape_specs: list[tuple[str, str]] = [] - for arg_idx, aval in enumerate(args_avals): - if all(not is_poly_dim(d) for d in aval.shape): - continue - polymorphic_shape_specs.append( - (pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None), - str(aval.shape))) - for dim_idx, aval_d in enumerate(aval.shape): - if is_poly_dim(aval_d): - synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree, - arg_idx, dim_idx) - synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx)) - dim_equations.append( - _DimEquation(aval_dim_expr=_ensure_poly(aval_d, "solve_dim_vars"), - dim_name=synth_dim_var)) - - solution, shape_constraints = _solve_dim_equations(dim_equations, - polymorphic_shape_specs) - return solution, shape_constraints, synth_dimension_vars - - -def compute_dim_vars_from_arg_shapes( - args_avals: Sequence[core.AbstractValue], - *actual_args: jax.Array, - args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: - """Computes values of dimension variables to unify args_avals with actual arguments. - - Like `solve_dim_vars` except that here we express the solution as - JAX arrays that reference the `actual_args`. This function can be used to - generate the code for computing the dimension variables. It also generates - the shape assertions. - - Returns: the values of the dimension variables, in the order determined by - `all_dim_vars(args_avals)`. - """ - dim_vars = all_dim_vars(args_avals) - solution, shape_constraints, synth_dim_vars = solve_dim_vars( - tuple(args_avals), args_kwargs_tree=args_kwargs_tree) - - # Replace the synthetic vars with the dynamic shape of the actual arg - synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], - dimension=dim_idx) - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = CachingShapeEvaluator(**synthetic_env) - shape_constraints.shape_assertions(synthetic_eval) - dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] - return tuple(dim_values) - -def _solve_dim_equations( - eqns: list[_DimEquation], - polymorphic_shape_specs: Sequence[tuple[str, str]] -) -> tuple[DimVarEnv, ShapeConstraints]: - # Returns a shape environment and the shape constraints if it can solve all - # dimension variables. Raises an exception if it cannot. - shapeenv: DimVarEnv = {} - solution_error_message_pieces: list[Union[str, _DimExpr]] = [ - " Obtained dimension variables: " - ] # Error message describing the solution - # Prepare error message piece describing the polymorphic shape specs - poly_specs_err_msg = ( - " Using the following polymorphic shapes specifications: " + - ",".join(f"{arg_name}.shape = {arg_spec}" - for arg_name, arg_spec in polymorphic_shape_specs)) + "." - solution_err_msg_trailer_errors = ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." - - shape_constraints = ShapeConstraints() # accumulate shape constraints - - def process_one_eqn(eqn: _DimEquation) -> bool: - # We start with a DimEquation of the form `dim_expr = dim_value` - # Try to rewrite the equation as `var * factor_var = dim_value_2` (a linear - # uni-variate equation). Returns `False` if this rewrite fails. - # Otherwise, compute the `var` value as `dim_value_2 // factor`, add it to - # `shapeenv` and return `True`. - # - # Invariant: - # var * factor_var + remaining_monomials_from_dim_expr = dim_value - var, factor_var = None, None - dim_value = _DimExpr.from_var(eqn.dim_name) - - for mon, factor in eqn.aval_dim_expr.monomials(): - # Perhaps we can already evaluate this monomial (all vars solved) - try: - mon_value = mon.evaluate(shapeenv) - except KeyError: - # `mon` still uses some variables not yet solved. We handle only the - # case when `mon` is a single variable. - v = mon.to_var() - if v is not None and var is None: - var, factor_var = v, factor - continue - else: - dim_value = dim_value + core.dim_constant(-1) * _evaluate_multiply(mon_value, core.dim_constant(factor)) - continue - return False # This equation cannot yet be used to solve a variable - - if var is not None: - if factor_var == 1: - var_value = dim_value - else: - var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore - shape_constraints.add_constraint( - ShapeConstraint.Comparator.EQ, var_remainder, 0, - error_message_pieces=([ - "Input shapes do not match the polymorphic shapes specification. " - "Division had remainder ", var_remainder, - f" when computing the value of '{var}'." + poly_specs_err_msg - ] + solution_error_message_pieces + [ - solution_err_msg_trailer_errors])) - - if not isinstance(var_value, _DimExpr): - assert var_value.dtype == core.dim_value_dtype() - shapeenv[var] = var_value # type: ignore - solution_error_message_pieces.extend([ - f"'{var}' = ", var_value, - f" from specification '{eqn.aval_dim_expr}' " - f"for dimension {eqn.dim_name} (= ", _DimExpr.from_var(eqn.dim_name), - "), "]) - - shape_constraints.add_constraint( - ShapeConstraint.Comparator.GEQ, var_value, 1, - error_message_pieces=[ - "Input shapes do not match the polymorphic shapes specification. " - f"Expected value >= 1 for dimension variable '{var}'." + - poly_specs_err_msg - ] + solution_error_message_pieces + [ - solution_err_msg_trailer_errors]) - - return True - else: - # All variables are resolved for this equation, we emit an assertion - shape_constraints.add_constraint( - ShapeConstraint.Comparator.EQ, - _DimExpr.from_var(eqn.dim_name), - eqn.aval_dim_expr.evaluate(shapeenv), - error_message_pieces=([ - "Input shapes do not match the polymorphic shapes specification. " - f"Found inconsistency between dimension size {eqn.dim_name} (= ", - _DimExpr.from_var(eqn.dim_name), - f") and the specification '{eqn.aval_dim_expr}' (= ", - eqn.aval_dim_expr.evaluate(shapeenv), - ")." + poly_specs_err_msg] + solution_error_message_pieces + - [solution_err_msg_trailer_errors]) - ) - return True - - while True: - nr_eqns = len(eqns) - eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)] - if not eqns: - return shapeenv, shape_constraints # SUCCESS - elif len(eqns) >= nr_eqns: - break - - # We have some equations that we cannot solve further - unsolved_vars: set[str] = set() - unsolved_polys: list[_DimExpr] = [] - for eqn in eqns: - unsolved_vars = unsolved_vars.union(eqn.aval_dim_expr.get_vars()) - unsolved_polys.append(eqn.aval_dim_expr) - unsolved_vars = unsolved_vars.difference(shapeenv.keys()) - err_msg = ( - f"Cannot solve for values of dimension variables {unsolved_vars}. " - "We can only solve linear uni-variate constraints." + poly_specs_err_msg + - " Unprocessed specifications: " + - ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" - for eqn in eqns) + - ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." - ) - raise ValueError(err_msg) +# TODO(necula): Remove these stubs +from jax.experimental.export.shape_poly import ( + InconclusiveDimensionOperation, + + # PolyShape used in tensorflowjs/converters/jax_conversion.py + PolyShape, + # is_poly_dim is used by maths/qec. + is_poly_dim, +) diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index f0f59a8f7..50733a050 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -28,7 +28,7 @@ import numpy as np import jax from jax import config from jax import lax -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax.experimental.jax2tf.tests import back_compat_test_util as bctu from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft @@ -96,7 +96,7 @@ class CompatTest(bctu.CompatTestBase): def test_custom_call_coverage(self): """Tests that the back compat tests cover all the targets declared stable.""" - targets_to_cover = set(jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ diff --git a/jax/experimental/jax2tf/tests/back_compat_test_util.py b/jax/experimental/jax2tf/tests/back_compat_test_util.py index 6b317fbe9..23361ceab 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test_util.py +++ b/jax/experimental/jax2tf/tests/back_compat_test_util.py @@ -21,7 +21,7 @@ The tests in this file refer to the test data in ./back_compat_testdata. There is one test for each version of a custom call target, e.g., `test_ducc_fft` tests the FFT custom calls on CPU. Only custom call targets tested here should be listed in -jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom +export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom call targets will result in an error when encountered during serialization. Once we stop using a custom call target in JAX, you can remove it from the @@ -78,7 +78,7 @@ from numpy import array, float32 import jax from jax import tree_util -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax.experimental import pjit @@ -281,12 +281,12 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( a string (for debugging), and (c) the module serialization version. """ # Use the native exporter, to make sure we get the proper serialization. - args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes) - exported = jax_export.export( + args_specs = export.poly_specs(data.inputs, polymorphic_shapes) + exported = export.export( jax.jit(func), lowering_platform=self.default_jax_backend(), disabled_checks=tuple( - jax_export.DisabledSafetyCheck.custom_call(target) + export.DisabledSafetyCheck.custom_call(target) for target in allow_unstable_custom_call_targets) )(*args_specs) @@ -297,13 +297,13 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( def run_serialized(self, data: CompatTestData, polymorphic_shapes: Optional[Sequence[str]] = None): - args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes) + args_specs = export.poly_specs(data.inputs, polymorphic_shapes) def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: return core.ShapedArray(a.shape, a.dtype) in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs) # TODO: we ought to ensure that out_avals are polymorphic if need be. We # could either save the in/out_avals (but we need to first implement that - # support in jax_export), or we can just re-use them from the current + # support in export), or we can just re-use them from the current # exported. out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs) # in_tree must be for (args, kwargs) @@ -312,7 +312,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( def _get_vjp(_): assert False # We do not have and do not need VJP - exported = jax_export.Exported( + exported = export.Exported( fun_name="run_serialized", in_tree=in_tree, in_avals=tuple(in_avals), @@ -331,4 +331,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. - return pjit.pjit(jax_export.call_exported(exported))(*data.inputs) + return pjit.pjit(export.call_exported(exported))(*data.inputs) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 0d6eca9f2..40834962a 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -39,7 +39,7 @@ from jax._src.lib.mlir.dialects import hlo import jax._src.xla_bridge from jax import config from jax.experimental import jax2tf -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map @@ -1522,7 +1522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): stack.enter_context(mesh) # Run the JAX native version, to check it works, and to fill caches. _ = func_to_convert(*args) - exported = jax_export.export( + exported = export.export( func_to_convert, lowering_platform='tpu' )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 7efa85a3b..ea9b0130d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -31,8 +31,8 @@ import re import jax from jax.experimental import jax2tf -from jax.experimental.jax2tf import shape_poly -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export +from jax.experimental.export import shape_poly from jax.experimental import pjit from jax import lax import jax.numpy as jnp @@ -72,7 +72,6 @@ expect_error_associative_scan = ( "associative scan over axis of non-constant size")) - class DimExprTest(tf_test_util.JaxToTfTestCase): def sampled_assert_equal(self, @@ -585,7 +584,7 @@ class PolyHarness(Harness): len(polymorphic_shapes), len(args), f"polymorphic_shapes {polymorphic_shapes} of length " f"{len(polymorphic_shapes)} must match number of arguments {len(args)}") - args_specs = jax_export.poly_specs(args, polymorphic_shapes) + args_specs = export.poly_specs(args, polymorphic_shapes) input_signature = [ tf.TensorSpec( [d if isinstance(d, int) else None for d in a.shape], diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 738620f16..82c40fca7 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -31,7 +31,7 @@ from jax import tree_util from jax import config from jax.experimental import jax2tf -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax._src import xla_bridge import numpy as np import tensorflow as tf # type: ignore[import] @@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, jax_legacy_prng_key="allow") class JaxToTfTestCase(jtu.JaxTestCase): # We want most tests to use the maximum available version, from the locally - # installed tfxla module and jax_export. + # installed tfxla module and export. use_max_serialization_version = True def setUp(self): @@ -181,16 +181,16 @@ class JaxToTfTestCase(jtu.JaxTestCase): self.addCleanup(functools.partial(config.update, "jax_serialization_version", version)) if self.use_max_serialization_version: - # Use the largest supported by both jax_export and tfxla.call_module - version = min(jax_export.maximum_supported_serialization_version, + # Use the largest supported by both export and tfxla.call_module + version = min(export.maximum_supported_serialization_version, tfxla.call_module_maximum_supported_version()) self.assertGreaterEqual(version, - jax_export.minimum_supported_serialization_version) + export.minimum_supported_serialization_version) config.update("jax_serialization_version", version) logging.info( - "Using JAX serialization version %s (jax_export.max_version %s, tf.XlaCallModule max version %s)", + "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, - jax_export.maximum_supported_serialization_version, + export.maximum_supported_serialization_version, tfxla.call_module_maximum_supported_version()) with contextlib.ExitStack() as stack: diff --git a/tests/BUILD b/tests/BUILD index 0f897e257..ef4a8b5d0 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1173,6 +1173,18 @@ py_test( ], ) +jax_test( + name = "export_test", + srcs = ["export_test.py"], + enable_configs = [ + "tpu_df_2x2", + ], + tags = [], + deps = [ + "//jax/experimental/export", + ], +) + exports_files( [ "api_test.py", diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/tests/export_test.py similarity index 84% rename from jax/experimental/jax2tf/tests/jax_export_test.py rename to tests/export_test.py index 940017a5c..79966b759 100644 --- a/jax/experimental/jax2tf/tests/jax_export_test.py +++ b/tests/export_test.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -import math import functools import logging +import math import re from typing import Optional import unittest @@ -24,7 +24,7 @@ import jax from jax import numpy as jnp from jax import tree_util from jax.config import config -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax._src import core from jax._src import test_util as jtu @@ -56,14 +56,14 @@ class JaxExportTest(jtu.JaxTestCase): super().setUp() # Run tests with the maximum supported version by default self.override_serialization_version( - jax_export.maximum_supported_serialization_version) + export.maximum_supported_serialization_version) def test_basic_export_only(self): def my_fun(x): return jnp.sin(x) - exp = jax_export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) + exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) self.assertEqual("my_fun", exp.fun_name) - self.assertEqual(jax_export.default_lowering_platform(), exp.lowering_platform) + self.assertEqual(export.default_lowering_platform(), exp.lowering_platform) self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals) @@ -74,7 +74,7 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = jax_export.export(f, lowering_platform="cpu")((a, b), a=a, b=b) + exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) self.assertEqual(exp.lowering_platform, "cpu") @@ -90,9 +90,9 @@ class JaxExportTest(jtu.JaxTestCase): def f(a, b): # a: f32[2w,h] b: f32[w,h] return jnp.concatenate([a, b], axis=0) - exp = jax_export.export(f)( - jax_export.poly_spec(a.shape, a.dtype, "(2*w, h)"), - jax_export.poly_spec(a.shape, a.dtype, "(w, h)")) + exp = export.export(f)( + export.poly_spec(a.shape, a.dtype, "(2*w, h)"), + export.poly_spec(a.shape, a.dtype, "(w, h)")) self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(w, h)", str(exp.in_avals[1].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) @@ -102,25 +102,25 @@ class JaxExportTest(jtu.JaxTestCase): def f(a0, a1, *, ak): return jnp.concatenate([a0, a1, ak], axis=0) - a_poly_spec = jax_export.poly_spec(a.shape, a.dtype, "(w, h)") - exp = jax_export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) + a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)") + exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) self.assertEqual("(w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) def test_basic(self): f = jnp.sin x = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(x) + exp_f = export.export(f)(x) - f1 = jax_export.call_exported(exp_f) + f1 = export.call_exported(exp_f) self.assertAllClose(f(x), f1(x)) def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name f = lambda x: jnp.sin(x) x = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(x) - f1 = jax_export.call_exported(exp_f) + exp_f = export.export(f)(x) + f1 = export.call_exported(exp_f) self.assertAllClose(f(x), f1(x)) def test_call_twice_exported(self): @@ -129,8 +129,8 @@ class JaxExportTest(jtu.JaxTestCase): @jax.jit def f1(x): - exp_f = jax_export.export(f)(x) - return jax_export.call_exported(exp_f)(x) + jax_export.call_exported(exp_f)(x) + exp_f = export.export(f)(x) + return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x) self.assertAllClose(2. * f(x), f1(x)) @@ -138,9 +138,9 @@ class JaxExportTest(jtu.JaxTestCase): f = lambda x, y: jnp.sin(x) x = np.arange(4, dtype=np.float32) y = np.arange(6, dtype=np.float32) - exp_f = jax_export.export(f)(x, y) + exp_f = export.export(f)(x, y) - f1 = jax_export.call_exported(exp_f) + f1 = export.call_exported(exp_f) self.assertAllClose(f(x, y), f1(x, y)) def test_pytree(self): @@ -149,8 +149,8 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp_f = jax_export.export(f)((a, b), a=a, b=b) - f1 = jax_export.call_exported(exp_f) + exp_f = export.export(f)((a, b), a=a, b=b) + f1 = export.call_exported(exp_f) self.assertAllClose(f((a, b), a=a, b=b), f1((a, b), a=a, b=b)) @@ -158,34 +158,34 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c a = b = c = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)((a, b), c=c) + exp_f = export.export(f)((a, b), c=c) with self.assertRaisesRegex( ValueError, "The invocation args and kwargs must have the same pytree structure"): - jax_export.call_exported(exp_f)(a, b, c=(a, b)) + export.call_exported(exp_f)(a, b, c=(a, b)) def test_error_wrong_avals(self): def f(a, *, b): # a: f32[4] and b: f32[4] return jnp.sin(a) + jnp.cos(b) f32_4 = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(f32_4, b=f32_4) + exp_f = export.export(f)(f32_4, b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for args\[0\].shape\[0\]"): - jax_export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) + export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for kwargs\['b'\].shape\[0\]"): - jax_export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) + export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) with self.assertRaisesRegex(ValueError, r"Rank mismatch for args\[0\]"): - jax_export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) + export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) with self.assertRaisesRegex(ValueError, r"Dtype mismatch for args\[0\]"): - jax_export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) + export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) @jtu.parameterized_filterable( testcase_name=lambda kw: kw["platform"], @@ -194,19 +194,19 @@ class JaxExportTest(jtu.JaxTestCase): def test_error_wrong_platform(self, platform): a = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(jnp.sin, lowering_platform=platform)(a) + exp_f = export.export(jnp.sin, lowering_platform=platform)(a) if xb.canonicalize_platform(jtu.device_under_test()) == platform: raise unittest.SkipTest("Uninteresting scenario") with self.assertRaisesRegex( ValueError, "The exported function .* was lowered for platform"): - jax_export.call_exported(exp_f)(a) + export.call_exported(exp_f)(a) # Now try with the platform check disabled - exp_f_no_platform_check = jax_export.export( + exp_f_no_platform_check = export.export( jnp.sin, lowering_platform=platform, - disabled_checks=[jax_export.DisabledSafetyCheck.platform()])(a) - res = jax_export.call_exported(exp_f_no_platform_check)(a) + disabled_checks=[export.DisabledSafetyCheck.platform()])(a) + res = export.call_exported(exp_f_no_platform_check)(a) self.assertAllClose(res, jnp.sin(a)) @jtu.parameterized_filterable( @@ -230,23 +230,23 @@ class JaxExportTest(jtu.JaxTestCase): a = np.arange(3, dtype=np.float32) with self.assertRaisesRegex(ValueError, "Cannot serialize code with custom calls whose targets .*"): - jax_export.export( + export.export( lambda a: a + test_primitive.bind(a) )(a) # Now try again with the safety check disabled - exp = jax_export.export( + exp = export.export( lambda a: a + test_primitive.bind(a), - disabled_checks=[jax_export.DisabledSafetyCheck.custom_call("disallowed_call_target")] + disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")] )(a) self.assertIn("disallowed_call_target", exp.mlir_module()) def test_grad(self): f = lambda x: jnp.sum(jnp.sin(x)) x = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(x) + exp_f = export.export(f)(x) - f1 = jax_export.call_exported(exp_f) + f1 = export.call_exported(exp_f) self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) def test_pytree_vjp(self): @@ -256,14 +256,14 @@ class JaxExportTest(jtu.JaxTestCase): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) - exp_f = jax_export.export(f)((a, b), a=a, b=b) + exp_f = export.export(f)((a, b), a=a, b=b) out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent def f1_jax(a, b): # For VJP, make a function without kwargs res = f((a, b), a=a, b=b) return res def f1_exp(a, b): # For VJP, make a function without kwargs - res = jax_export.call_exported(exp_f)((a, b), a=a, b=b) + res = export.call_exported(exp_f)((a, b), a=a, b=b) return res jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) @@ -273,33 +273,33 @@ class JaxExportTest(jtu.JaxTestCase): def f1(x): return jnp.sin(x) a = np.arange(4, dtype=np.float32) - exp_f1 = jax_export.export(f1)(a) + exp_f1 = export.export(f1)(a) def f2(x): - res1 = jax_export.call_exported(exp_f1)(x) - res2 = jax_export.call_exported(exp_f1)(res1) + res1 = export.call_exported(exp_f1)(x) + res2 = export.call_exported(exp_f1)(res1) return jnp.cos(res2) - exp_f2 = jax_export.export(f2)(a) + exp_f2 = export.export(f2)(a) self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), - jax_export.call_exported(exp_f2)(a)) + export.call_exported(exp_f2)(a)) @jtu.parameterized_filterable( #one_containing="", kwargs=[ dict(v=v) - for v in range(jax_export.minimum_supported_serialization_version - 1, - jax_export.maximum_supported_serialization_version + 2)]) + for v in range(export.minimum_supported_serialization_version - 1, + export.maximum_supported_serialization_version + 2)]) def test_shape_poly_basic_versions(self, v: int): self.override_serialization_version(v) with contextlib.ExitStack() as e: - if not (jax_export.minimum_supported_serialization_version <= v - <= jax_export.maximum_supported_serialization_version): + if not (export.minimum_supported_serialization_version <= v + <= export.maximum_supported_serialization_version): e.enter_context(self.assertRaisesRegex( ValueError, f"The requested jax_serialization version {v} is outside the range of supported versions")) - exp = jax_export.export(jnp.sin)( - jax_export.poly_spec((3, 4), np.float32, "w, h")) + exp = export.export(jnp.sin)( + export.poly_spec((3, 4), np.float32, "w, h")) # Peek at the module module_str = exp.mlir_module() self.assertEqual(config.jax_serialization_version >= 7, @@ -307,11 +307,11 @@ class JaxExportTest(jtu.JaxTestCase): self.assertIn("jax.uses_shape_polymorphism = true", module_str) x = np.arange(30, dtype=np.float32).reshape((5, 6)) - res = jax_export.call_exported(exp)(x) + res = export.call_exported(exp)(x) self.assertAllClose(res, np.sin(x)) # A function is exported with f32[poly_spec] and is called with different arg - # shapes. We use jax_export.call_exported and we also run the shape check + # shapes. We use export.call_exported and we also run the shape check # module. @jtu.parameterized_filterable( testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore @@ -348,8 +348,8 @@ class JaxExportTest(jtu.JaxTestCase): return jnp.reshape(x, (-1, x.shape[1])) disabled_checks = () - exp_f = jax_export.export(f, disabled_checks=disabled_checks)( - jax_export.poly_spec((3, 4, 12), np.float32, poly_spec)) + exp_f = export.export(f, disabled_checks=disabled_checks)( + export.poly_spec((3, 4, 12), np.float32, poly_spec)) self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12] @@ -359,7 +359,7 @@ class JaxExportTest(jtu.JaxTestCase): stack.push(self.assertRaisesRegex(Exception, expect_error)) assert core.is_constant_shape(arg.shape) - res = jax_export.call_exported(exp_f)(arg) + res = export.call_exported(exp_f)(arg) if not expect_error: self.assertAllClose(res, f(arg)) @@ -450,23 +450,23 @@ class JaxExportTest(jtu.JaxTestCase): arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = jax_export.export(inner)( - jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec)) + inner_exp = export.export(inner)( + export.poly_spec((3, 4, 12), np.float32, inner_poly_spec)) self.assertEqual(inner_exp.uses_shape_polymorphism, (inner_poly_spec != "3,4,12")) def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the # result of the call_exported. - return jax_export.call_exported(inner_exp)(x) + inner(x) + return export.call_exported(inner_exp)(x) + inner(x) with contextlib.ExitStack() as stack: if expect_error_outer_exp is not None: stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = jax_export.export(outer)( - jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec)) + outer_exp = export.export(outer)( + export.poly_spec(arg.shape, arg.dtype, outer_poly_spec)) if expect_error_outer_exp is not None: return @@ -478,7 +478,7 @@ class JaxExportTest(jtu.JaxTestCase): if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = jax_export.call_exported(outer_exp)(arg) + res = export.call_exported(outer_exp)(arg) if expect_error_run is not None: return @@ -543,9 +543,9 @@ class JaxExportTest(jtu.JaxTestCase): with contextlib.ExitStack() as stack: if expect_error is not None: stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) - exp = jax_export.export(f_jax)( - jax_export.poly_spec(x.shape, x.dtype, poly_spec)) - jax_export.call_exported(exp)(x) + exp = export.export(f_jax)( + export.poly_spec(x.shape, x.dtype, poly_spec)) + export.call_exported(exp)(x) def test_multi_platform(self): if jtu.device_under_test() == "gpu": @@ -553,10 +553,10 @@ class JaxExportTest(jtu.JaxTestCase): raise unittest.SkipTest("Not intended for running on GPU") x = np.arange(5, dtype=np.float32) # TODO: use a function with different behavior for different platforms - exp = jax_export.export(jnp.sin, + exp = export.export(jnp.sin, lowering_platforms=('cpu', 'tpu'))(x) self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu')) - res = jax_export.call_exported(exp)(x) + res = export.call_exported(exp)(x) self.assertAllClose(res, np.sin(x)) def test_multi_platform_nested(self): @@ -565,7 +565,7 @@ class JaxExportTest(jtu.JaxTestCase): raise unittest.SkipTest("Not intended for running on TPU") x = np.arange(5, dtype=np.float32) # TODO: use a function with different behavior for different platforms - exp = jax_export.export(jnp.sin, + exp = export.export(jnp.sin, lowering_platforms=('cpu', 'tpu', 'cuda'))(x) self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda')) @@ -573,9 +573,9 @@ class JaxExportTest(jtu.JaxTestCase): # lowering platforms, but included in the lowering platforms for the # nested exported. # TODO: improve this test once we implement true multi-platform lowering - exp2 = jax_export.export(jax_export.call_exported(exp), + exp2 = export.export(export.call_exported(exp), lowering_platforms=('cpu', 'cuda'))(x) - res2 = jax_export.call_exported(exp2)(x) + res2 = export.call_exported(exp2)(x) self.assertAllClose(res2, np.sin(x)) def test_multi_platform_and_poly(self): @@ -583,16 +583,16 @@ class JaxExportTest(jtu.JaxTestCase): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") # TODO: use a function with different behavior for different platforms - exp = jax_export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)), + exp = export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)), lowering_platforms=('cpu', 'tpu'))( - jax_export.poly_spec((5, 6), np.float32, "b1, b2") + export.poly_spec((5, 6), np.float32, "b1, b2") ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) - res = jax_export.call_exported(exp)(x) + res = export.call_exported(exp)(x) self.assertAllClose(res, np.sin(x).reshape((-1,))) # Now serialize the call to the exported - exp2 = jax_export.export(jax_export.call_exported(exp))(x) - res2 = jax_export.call_exported(exp2)(x) + exp2 = export.export(export.call_exported(exp))(x) + res2 = export.call_exported(exp2)(x) self.assertAllClose(res2, np.sin(x).reshape((-1,)))