diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD new file mode 100644 index 000000000..fcf9d20f0 --- /dev/null +++ b/jax/experimental/export/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX-export provides APIs for exporting StableHLO for serialization purposes. + +load( + "//jaxlib:jax.bzl", + "py_deps", +) +load("@rules_python//python:defs.bzl", "py_library") + +licenses(["notice"]) + +# Please add new users to :australis_users. +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +py_library( + name = "export", + srcs = [ + "export.py", + "shape_poly.py", + ], + srcs_version = "PY3", + # TODO: b/255503696: enable pytype + tags = ["pytype_unchecked_annotations"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + ] + py_deps("numpy"), +) diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py new file mode 100644 index 000000000..e9a5c3c22 --- /dev/null +++ b/jax/experimental/export/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py new file mode 100644 index 000000000..62d01d044 --- /dev/null +++ b/jax/experimental/export/export.py @@ -0,0 +1,1046 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX APIs for exporting JAX functions for interoperation. 
+ +""" + +from collections.abc import Sequence +import copy +import dataclasses +import functools +import itertools +import re +from typing import Any, Callable, Optional, Union + +from absl import logging + +import numpy as np + +import jax +from jax import config +from jax import sharding + +from jax._src import core +from jax._src import dispatch +from jax._src import pjit +from jax._src import sharding_impls +from jax._src import source_info_util +from jax._src.interpreters import mlir +from jax._src.interpreters import pxla +from jax._src.lib import xla_client +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src import tree_util +from jax._src import util +from jax._src import xla_bridge as xb + +from jax.experimental.export import shape_poly + +map = util.safe_map +zip = util.safe_zip + +DType = Any + +class DisabledSafetyCheck: + """A safety check should be skipped on (de)serialization. + + Most of these checks are performed on serialization, but some are deferred to + deserialization. The list of disabled checks is attached to the serialization, + e.g., as a sequence of string attributes to `jax_export.Exported` or of + `tf.XlaCallModuleOp`. + + You can disable more deserialization safety checks by passing + `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. + """ + _impl: str + + @classmethod + def platform(cls) -> "DisabledSafetyCheck": + """Allows the execution platform to differ from the serialization platform. + + Has effect only on deserialization. + """ + return DisabledSafetyCheck("platform") + + @classmethod + def custom_call(cls, target_name: str) -> "DisabledSafetyCheck": + """Allows the serialization of a call target not known to be stable. + + Has effect only on serialization. + Args: + target_name: the name of the custom call target to allow. + """ + return DisabledSafetyCheck(f"custom_call:{target_name}") + + @classmethod + def shape_assertions(cls) -> "DisabledSafetyCheck": + """Allows invocations with shapes that do not meet the constraints. + + Has effect on serialization (to suppress the generation of the assertions) + and also on deserialization (to suppress the checking of the assertions). + """ + return DisabledSafetyCheck("shape_assertions") + + def is_custom_call(self) -> Optional[str]: + """Returns the custom call target allowed by this directive.""" + m = re.match(r'custom_call:(.+)$', self._impl) + return m.group(1) if m else None + + def __init__(self, _impl:str): + # Do not use directly, use builders `platform`, `custom_call`. + self._impl = _impl + + def __str__(self): + return self._impl + __repr__ = __str__ + + def __eq__(self, other) -> bool: + return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl + + def __hash__(self) -> int: + return hash(self._impl) + + +minimum_supported_serialization_version = 6 +maximum_supported_serialization_version = 8 + +@dataclasses.dataclass(frozen=True) +class Exported: + """A JAX function lowered to StableHLO. + + Attributes: + fun_name: the name of the exported function, for error messages. + in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX + function. The actual lowering does not depend on the `in_tree`, but this + can be used to invoke the exported function using the same argument + structure. + in_avals: the flat tuple of input abstract values. May contain dimension + expressions in the shapes. 
+    out_tree: a PyTreeDef describing the result of the lowered JAX function.
+    out_avals: the flat tuple of output abstract values. May contain dimension
+        expressions in the shapes, with dimension variables among those in
+        `in_avals`.
+    in_shardings: the flattened input shardings. Only for the inputs that are
+        specified in `module_kept_var_idx`. If `None` then it is equivalent
+        to unspecified shardings.
+    out_shardings: the flattened output shardings, of the same length as
+        `out_avals`.
+    lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
+        'cuda', 'rocm'. See below for the calling convention for when
+        there are multiple lowering platforms.
+    mlir_module_serialized: the serialized lowered VHLO module.
+    serialization_version: a version number for the serialized module.
+        See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
+    module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
+        must be passed to the module. The other arguments have been dropped
+        because they are not used. Same length as `in_shardings`.
+    uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
+        polymorphism. This may be because `in_avals` contains dimension
+        variables, but also from inner calls of shape-polymorphic
+        Exported modules.
+    disabled_checks: a list of descriptors of safety checks that have been
+        disabled at export time. See docstring for `DisabledSafetyCheck`.
+    _get_vjp: an optional function that takes the current exported function and
+        returns the exported VJP function.
+        The VJP function takes a flat list of arguments,
+        starting with the primal arguments and followed by a cotangent argument
+        for each primal output. It returns a tuple with the cotangents
+        corresponding to the flattened primal inputs.
+
+  Calling convention for the exported module:
+
+  The `mlir_module` has a `main` function that takes an optional first
+  platform index argument if the module supports multiple platforms
+  (`len(lowering_platforms) > 1`), followed by the kept array arguments
+  (corresponding to `module_kept_var_idx` and `in_avals`).
+  The platform index is an i32 scalar encoding the index of the current
+  compilation platform into the `lowering_platforms` sequence.
+
+  Inner functions use a different calling convention: an optional
+  platform index argument, optional dimension variable arguments specified
+  using scalar tensors of type i32 or i64,
+  followed by optional token arguments (in presence of side effects),
+  followed by the regular array arguments.
+  The dimension arguments correspond to the dimension variables appearing in
+  the `args_avals`, in sorted order of their names.
+
+  Consider the lowering of a function with one array argument of type "f32[w,
+  2 * h]", where "w" and "h" are two dimension variables.
+  Assume that we use multi-platform lowering, and we have
+  ordered effects. The `main` function will be as follows:
+
+      func public main(platform_index: i32, arg: f32[?, ?]) {
+         arg_w = hlo.get_dimension_size(arg, 0)
+         dim1 = hlo.get_dimension_size(arg, 1)
+         arg_h = hlo.floordiv(dim1, 2)
+         call _check_shape_assertions(arg)  # See below
+         token_in = new_token()
+         token_out, res = call _wrapped_jax_export_main(platform_index, arg_h, arg_w, token_in, arg)
+         return res
+      }
+
+  The actual computation is in `_wrapped_jax_export_main`, which also takes
+  the values of `h` and `w`, and the token. Proper exporting of
+  functions with side-effects and tokens is still work-in-progress.
+
+  Note that `main` contains a call to `_check_shape_assertions`.
+  JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
+  have values >= 1. We must check these constraints when we invoke the
+  module. We use a special custom call `@shape_assertion` that takes
+  a boolean first operand, a string `error_message` attribute that may contain
+  format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
+  scalar operands corresponding to the format specifiers.
+
+      func private _check_shape_assertions(arg: f32[?, ?]) {
+        # Check that w is >= 1
+        arg_w = hlo.get_dimension_size(arg, 0)
+        custom_call @shape_assertion(arg_w >= 1, arg_w,
+           error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
+        # Check that dim1 is even
+        dim1 = hlo.get_dimension_size(arg, 1)
+        custom_call @shape_assertion(dim1 % 2 == 0, dim1,
+           error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
+        # Check that h >= 1
+        arg_h = hlo.floordiv(dim1, 2)
+        custom_call @shape_assertion(arg_h >= 1, arg_h,
+           error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
+      }
+
+  If we `call_exported` with this module we perform these checks
+  statically (in `call_exported_abstract_eval`).
+  """
+  fun_name: str
+  in_tree: tree_util.PyTreeDef
+  in_avals: tuple[core.AbstractValue, ...]
+  out_tree: tree_util.PyTreeDef
+  out_avals: tuple[core.AbstractValue, ...]
+
+  in_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
+  out_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
+  lowering_platform: str  # For backwards compatibility
+  lowering_platforms: tuple[str, ...]
+  disabled_checks: Sequence[DisabledSafetyCheck]
+
+  mlir_module_serialized: bytes
+  serialization_version: int
+  module_kept_var_idx: tuple[int, ...]
+  uses_shape_polymorphism: bool
+
+  _get_vjp: Optional[Callable[["Exported"], "Exported"]]
+
+  def mlir_module(self) -> ir.Module:
+    return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
+
+  def __str__(self):
+    # This is called to make an MLIR source location when we call an Exported, and we
+    # do not want the entire serialized module to end up in locations.
+    return f"Exported(fun_name={self.fun_name}, ...)"
+
+  def vjp(self) -> "Exported":
+    """Gets the exported VJP.
+
+    Raises ValueError if no VJP is available, which can happen if the Exported
+    has been loaded from an external format without a VJP."""
+    if self._get_vjp is None:
+      raise ValueError("No VJP is available")
+    return self._get_vjp(self)
+
+
+def default_lowering_platform() -> str:
+  # Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
+  return xb.canonicalize_platform(jax.default_backend())
+
+def poly_spec(
+    arg_shape: Sequence[Optional[int]],
+    arg_dtype: DType,
+    polymorphic_shape: Optional[str]) -> jax.ShapeDtypeStruct:
+  """Constructs a jax.ShapeDtypeStruct with polymorphic shapes.
+
+  Args:
+    arg_shape: the shape, with possibly some unspecified dimensions.
+    arg_dtype: the jax dtype.
+    polymorphic_shape: a string specifying the polymorphic shape.
+
+      .. warning:: The shape-polymorphic lowering is an experimental feature.
+        It is meant to be sound, but it is known to reject some JAX programs
+        that are shape polymorphic. The details of this feature can change.
+ + It should be either `None` (all dimensions are constant), or a string of + specification for one axis, and can be either a constant, `_` denoting + a constant dimension given by the `arg_shape`, or the name of a + dimension variable assumed to range over dimension greater than 0. For + convenience, zero or more trailing `_` can be abbreviated with `...`, and + the surrounding parentheses may be missing. + + Note that this function does not ensure that the provided `arg_shape` + is compatible with `polymorphic_shape`. The `arg_shape` is used only + to fill-in placeholders from `polymorphic_shape`. + + See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + for more details. + + Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic + expressions involving dimension variables. + """ + aval_shape = shape_poly._parse_spec(polymorphic_shape, arg_shape) + return jax.ShapeDtypeStruct(aval_shape, arg_dtype) + +def shape_and_dtype_jax_array(a) -> tuple[Sequence[Optional[int]], DType]: + """Returns the shape and dtype of a jax.Array.""" + aval = core.raise_to_shaped(core.get_aval(a)) + return aval.shape, aval.dtype + +def poly_specs( + args, # pytree of arguments + polymorphic_shapes, # prefix pytree of strings + get_shape_and_dtype=shape_and_dtype_jax_array, +): + """Constructs a pytree of jax.ShapeDtypeSpec. + + Args: + args: a pytree of arguments + polymorphic_shapes: should be `None` (all arguments are monomorphic), + a single string (applies to all arguments), or a pytree matching a prefix + of the `args`. + See [how optional parameters are matched to + arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + + Note that this function does not ensure that the provided `args` shapes + are compatible with `polymorphic_shapes`. The `args.shape` are used only + to fill-in placeholders from `polymorphic_shapes`. + + See docstring of `poly_spec` and + [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + for more details. + + Returns: a pytree of jax.ShapeDTypeStruct matching `args`. + """ + args_flat, args_tree = tree_util.tree_flatten(args) + + shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat)) + shapes, dtypes = util.unzip2(shapes_and_dtypes) + + if isinstance(args, tuple) and isinstance(polymorphic_shapes, list): + # TODO: Remove backward-compatibility workaround + polymorphic_shapes_ = tuple(polymorphic_shapes) + else: + polymorphic_shapes_ = polymorphic_shapes + + try: + polymorphic_shapes_flat = tree_util.broadcast_prefix( + polymorphic_shapes_, args, + is_leaf=lambda x: x is None) + except ValueError: + e, *_ = tree_util.prefix_errors( + polymorphic_shapes_, args, + is_leaf=lambda x: x is None) + raise e("jax_export polymorphic_shapes") from None + + # Now add in the polymorphic shapes + args_specs_flat = tuple( + map(poly_spec, shapes, dtypes, polymorphic_shapes_flat)) + + return args_tree.unflatten(args_specs_flat) + + +def export(fun_jax: Callable, + *, + lowering_platform: Optional[str] = None, + lowering_platforms: Optional[Sequence[str]] = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + ) -> Callable[..., Exported]: + """Exports native serialization for a JAX function. + + Args: + fun_jax: the function to lower and serialize. + lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use + the default JAX backend. 
+ lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL). + Optional sequence containing a subset of 'tpu', 'cpu', + 'cuda', 'rocm'. If more than one platform is specified, then + the lowered code takes an argument specifying the platform. + If None, then use the default JAX backend. + The calling convention for multiple platforms is explained in the + `jax_export.Exported` docstring. + disabled_checks: the safety checks to disable. See docstring + of `DisabledSafetyCheck`. + + Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. + + Usage: + + def f_jax(*args, **kwargs): ... + exported = jax_export.export(f_jax)(*args, **kwargs) + """ + fun_name = getattr(fun_jax, "__name__", "unknown") + version = config.jax_serialization_version + if (version < minimum_supported_serialization_version or + version > maximum_supported_serialization_version): + raise ValueError( + f"The requested jax_serialization version {version} is outside the " + f"range of supported versions [{minimum_supported_serialization_version}" + f"..{maximum_supported_serialization_version}]") + + def do_export(*args_specs, **kwargs_specs) -> Exported: + if not hasattr(fun_jax, "lower"): + # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also + # convert(f_jax), in which case a "jit" is implied. In that case we raise + # an error if the lowered function contains non-replicated sharding annotations. + wrapped_fun_jax = jax.jit(fun_jax) + allow_non_replicated_sharding = False + else: + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + wrapped_fun_jax = fun_jax # type: ignore + allow_non_replicated_sharding = True + + nonlocal lowering_platforms + if lowering_platforms is not None: + lowering_platforms = tuple(lowering_platforms) + else: + lowering_platforms = (lowering_platform or default_lowering_platform(),) + + # Do not include shape assertions if the version is < 7. + enable_shape_assertions = ( + DisabledSafetyCheck.shape_assertions() not in disabled_checks and + version >= 7) # type: ignore + try: + prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions + shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions + lowered = wrapped_fun_jax.lower( + *args_specs, **kwargs_specs, + _experimental_lowering_platform=lowering_platforms) + + lowering = lowered._lowering # type: ignore + _check_lowering(lowering) + mlir_module = lowering.stablehlo() + + args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) + if "kept_var_idx" in lowering.compile_args: + module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals_flat))) + shape_poly_state = lowering.compile_args["shape_poly_state"] + if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) + or lowering.compile_args.get("ordered_effects", [])): + # All arguments are kept if we have dimension variables. 
+ assert len(module_kept_var_idx) == len(args_avals_flat) + mlir_module = _wrap_main_func( + mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, + has_platform_index_argument=shape_poly_state.has_platform_index_argument + ) + finally: + shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions + + with mlir_module.context: + mlir_module_attrs = mlir_module.operation.attributes + mlir_module_attrs["jax.uses_shape_polymorphism"] = ( + mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) + + mlir_module_serialized = _serialize_module(mlir_module) + + # Figure out the result types and shapes + if "global_out_avals" in lowering.compile_args: + # This is currently the case for pjit + out_avals_flat = lowering.compile_args["global_out_avals"] + elif "shards" in lowering.compile_args: # for PmapComputation + out_avals_flat = lowering.compile_args["shards"].out_sharded_avals + else: + out_avals_flat = lowered.compile_args["out_avals"] + + # Log and then check the module. + if logging.vlog_is_on(3): + mlir_module_text = mlir.module_to_string(mlir_module) + logmsg = (f"version={version} " + f"lowering_platforms={lowering_platforms} " + f"disabled_checks={disabled_checks}") + logging.info("Lowered JAX module: %s\n", logmsg) + for l in mlir_module_text.splitlines(): + logging.info(l) + + _check_module(mlir_module, + allow_non_replicated_sharding=allow_non_replicated_sharding, + disabled_checks=disabled_checks) + + return Exported( + fun_name=fun_name, + in_tree=lowered.in_tree, + out_tree=lowered.out_tree, + in_avals=tuple(args_avals_flat), + out_avals=tuple(out_avals_flat), + in_shardings=lowering.compile_args["in_shardings"], + out_shardings=lowering.compile_args["out_shardings"], + lowering_platform=lowering_platforms[0], # TODO: remove + lowering_platforms=lowering_platforms, + disabled_checks=tuple(disabled_checks), + mlir_module_serialized=mlir_module_serialized, + module_kept_var_idx=module_kept_var_idx, + uses_shape_polymorphism=shape_poly_state.uses_dim_vars, + serialization_version=version, # type: ignore + _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported)) + + return do_export + + +def _serialize_module(module: ir.Module) -> bytes: + mlir_str = mlir.module_to_bytecode(module) + if hlo.get_api_version() < 4: + target_version = hlo.get_earliest_forward_compatible_version() + else: + # `target_version` is used to manage situations when a StableHLO producer + # (in this case, jax2tf) and a StableHLO consumer were built using + # different versions of StableHLO. + # + # Each StableHLO version `producer_version` has a compatibility window, + # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], + # where StableHLO portable artifacts serialized by `producer_version` + # can be deserialized by `consumer_version` within the window. + # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md + # for the exact extent of these compatibility guarantees. + # + # `hlo.get_minimum_version()` returns `consumer_version_min` + # for the current version of StableHLO. We are using it here to maximize + # forward compatibility, i.e. to maximize how far into the past we can go + # and still have the payloads produced by `serialize_portable_artifact` + # compatible with potential consumers from the past. 
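+    # Illustrative example (the version numbers are hypothetical, not taken
+    # from this change): if the current StableHLO release were 0.14.0 with a
+    # compatibility window reaching back to 0.9.0, then
+    # `hlo.get_minimum_version()` would return 0.9.0 and the artifact
+    # serialized below would remain readable by consumers at least that old.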
+ target_version = hlo.get_minimum_version() + module_serialized = xla_client._xla.mlir.serialize_portable_artifact( + mlir_str, target_version) + return module_serialized + + +def _wrap_main_func( + module: ir.Module, + args_avals_flat: Sequence[core.ShapedArray], + *, + args_kwargs_tree: tree_util.PyTreeDef, + has_platform_index_argument: bool, +) -> ir.Module: + """Wraps the lowered module with a new "main" handling dimension arguments. + + See calling convention documentation for `jax_export.Exported`. + + Args: + module: the HLO module as obtained from lowering. See the calling convention + for inner functions in `jax_export.Exported`. + args_avals_flat: the avals for all the arguments of the lowered function, + which correspond to the array arguments of the `module`. + args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error + messages. + + Returns the wrapped module, without dimension and token arguments. + """ + dim_vars = shape_poly.all_dim_vars(args_avals_flat) + context = mlir.make_ir_context() + with context, ir.Location.unknown(context): + # Make a copy, do not mutate because it may be cached + wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) + symbol_table = ir.SymbolTable(wrapped_module.operation) + orig_main = symbol_table["main"] + orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") + symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main") + orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value + + def is_token(attrs): + try: + return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value + except KeyError: + return False + + orig_input_types = orig_main.type.inputs + arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) + # The order of args: platform_index_arg, dim args, token args, array args. 
+ nr_platform_index_args = 1 if has_platform_index_argument else 0 + nr_dim_args = len(dim_vars) + nr_token_args = sum(1 for attrs in arg_attrs if is_token(attrs)) + nr_array_args = len(orig_input_types) - nr_platform_index_args - nr_dim_args - nr_token_args + assert nr_array_args >= 0 + assert not any(is_token(attrs) for attrs in arg_attrs[-nr_array_args:]) + (platform_input_types, dim_var_input_types, + token_input_types, array_input_types) = util.split_list( + orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) + new_main_input_types = platform_input_types + array_input_types + orig_output_types = orig_main.type.results + result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) + nr_token_results = sum(1 for attrs in result_attrs if is_token(attrs)) + nr_array_results = len(orig_output_types) - nr_token_results + assert nr_array_results >= 0 + assert not any( + is_token(attrs) for attrs in result_attrs[-nr_array_results:]) + new_main_output_types = orig_output_types[-nr_array_results:] + new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types) + new_main_op = func_dialect.FuncOp( + "main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body)) + new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public") + try: + new_main_op.arg_attrs = ir.ArrayAttr.get(arg_attrs[0:nr_platform_index_args] + arg_attrs[-nr_array_args:]) + except KeyError: + pass # TODO: better detection if orig_main.arg_attrs does not exist + try: + new_main_op.result_attrs = ir.ArrayAttr.get( + result_attrs[-nr_array_results:]) + except KeyError: + pass + symbol_table.insert(new_main_op) + entry_block = new_main_op.add_entry_block() + with ir.InsertionPoint(entry_block): + module_context = mlir.ModuleContext( + "cpu", "cpu", sharding_impls.ShardingContext([]), + source_info_util.new_name_stack(), + [], itertools.count(1), [], module=wrapped_module, context=context) + ctx = mlir.LoweringRuleContext( + module_context=module_context, primitive=None, + avals_in=args_avals_flat, avals_out=None, + tokens_in=mlir.TokenSet(), tokens_out=None) + new_main_op_array_args = new_main_op.arguments[nr_platform_index_args:] + dim_values = mlir.lower_fun( + functools.partial(shape_poly.compute_dim_vars_from_arg_shapes, + args_avals_flat, args_kwargs_tree=args_kwargs_tree), + multiple_results=True)(ctx, *new_main_op_array_args) + # The arguments to pass to the call to orig_main + orig_main_args: list[ir.Value] = [] + # The platform index and the dimension variables + for arg, arg_type in zip( + list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values), + platform_input_types + dim_var_input_types): + if arg.type != arg_type: + orig_main_args.append(hlo.ConvertOp(arg_type, arg).result) + else: + orig_main_args.append(arg) + # Then the token arguments + orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args) + # Then the array arguments. We insert a ConvertOp as the only use of + # an input argument. This helps the downstream shape refinement because + # it will set the type of input arguments to static shapes, and this + # can invalidate the module if the argument is used as the result of a + # function, or if it appears as the input to a custom_call with + # output_operand_alias attribute. See b/287386268. 
+ for a in new_main_op_array_args: + orig_main_args.append(hlo.ConvertOp(a.type, a).result) + call = func_dialect.CallOp(orig_output_types, + ir.FlatSymbolRefAttr.get(orig_main_name), + orig_main_args) + func_dialect.ReturnOp(call.results[-nr_array_results:]) + symbol_table.set_symbol_name(new_main_op, "main") + + return wrapped_module + +def _check_lowering(lowering) -> None: + if not isinstance(lowering, pxla.MeshComputation): + raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") + + if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: + raise NotImplementedError("serialization of host_callbacks is not yet implemented") + # Check that we do not see new compile_args. When we add a compile_args it is + # safe to add it to the allowed_compile_args if it does not change the semantics + # or the calling convention of the lowered module. + allowed_compile_args = [ + "backend", "mesh", "global_in_avals", + "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", + "spmd_lowering", "auto_spmd_lowering", + "tuple_args", "ordered_effects", "unordered_effects", + "keepalive", "host_callbacks", "pmap_nreps", "committed", + "device_assignment", "jaxpr_debug_info", "shape_poly_state"] + for compile_arg in lowering.compile_args.keys(): + if compile_arg not in allowed_compile_args: + raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") + + # We have not implemented support for some of the compile_args. Check here that + # the compile_args have the values that have been implemented. + not_implemented_msgs = [] + for compile_arg, check_value, err_msg in ( + ("spmd_lowering", lambda v: v, "True"), + ("auto_spmd_lowering", lambda v: not v, "False"), + # tuple_args is a compilation flag, does not affect lowering. + ("tuple_args", lambda v: True, "N/A"), + # unordered_effects do not change the calling convention. Those from + # jax.debug will also result in keepalive being non-empty and unsupported + # custom calls. The CallTfEffect is an exception, but we want to allow + # that one. + ("unordered_effects", lambda v: True, "N/A"), + # ordered_effects are allowed and we ensure that the calling convention is + # unmodified by passing dummy tokens in the main function wrapper. + ("ordered_effects", lambda v: True, "N/A"), + # used for TPU jax.debug, send/recv. Not supported yet. + ("host_callbacks", lambda v: not v, "empty"), + # used on all platforms for callbacks. Not supported yet. + ("keepalive", lambda v: not v, "empty"), + ("pmap_nreps", lambda v: v == 1, "1"), + ("shape_poly_state", lambda v: True, "N/A"), + ): + if compile_arg in lowering.compile_args: + if not check_value(lowering.compile_args[compile_arg]): + not_implemented_msgs.append( + f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}") + if not_implemented_msgs: + raise NotImplementedError( + "serialization error, unimplemented lowered.compile_args:\n" + + "\n".join(not_implemented_msgs)) + +# These are the JAX custom call target names that are guaranteed to be stable. +# Their backwards compatibility is tested by back_compat_test.py. 
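+# Note (illustrative sketch; "my_target" is a placeholder, not a real target):
+# a custom call target that is not in this set can still be serialized by
+# explicitly opting out of the stability check, e.g.
+#   export(f, disabled_checks=[DisabledSafetyCheck.custom_call("my_target")])(*args_specs)
+# as handled by `_check_module` below.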
+_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { + "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", + # cholesky on CPU + "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", + # eigh on CPU + "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", + # eigh on GPU + "cusolver_syevj", "cusolver_syevd", + # eigh on TPU + "Eigh", + # eig on CPU + "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", + # qr on CPU + "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", + # householder product on CPU + "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", + # svd on CPU + "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", + # qr on GPU + "cusolver_geqrf", "cublas_geqrf_batched", + "cusolver_geqrf", "cusolver_orgqr", + # qr and svd on TPU + "Qr", "ProductOfElementaryHouseholderReflectors", + # triangular_solve on CPU + "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", + # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU + # # lu on CPU + "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", + # schur on CPU + "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", + # # lu on GPU + # "cublas_getrf_batched", "cusolver_getrf", + # "hipblas_getrf_batched", "hipsolver_getrf", + # lu on TPU + "LuDecomposition", + # ApproxTopK on TPU + "ApproxTopK", + "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) + "tpu_custom_call", # Pallas/TPU kernels + # TODO(burmako): maintain backwards compatibility for these, until they + # are upstreamed to StableHLO. + # See https://github.com/openxla/stablehlo/issues/8. + "stablehlo.dynamic_reduce_window", + "stablehlo.dynamic_rng_bit_generator", + "stablehlo.dynamic_top_k", + "shape_assertion", # Used by shape_poly to evaluate assertions +} + + +def _check_module(mod: ir.Module, *, + allow_non_replicated_sharding: bool, + disabled_checks: Sequence[DisabledSafetyCheck]) -> None: + """Run a number of checks on the module. + + Args: + allow_non_replicated_sharding: whether the module is allowed to contain + non_replicated sharding annotations. + disabled_checks: the safety checks that are disabled. + """ + sharding_attr = ir.StringAttr.get("Sharding", mod.context) + shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context) + allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + for dc in disabled_checks: + target = dc.is_custom_call() + if target is not None: + allowed_custom_call_targets.add(target) + + allowed_custom_call_targets_attrs = { + ir.StringAttr.get(target, mod.context) + for target in allowed_custom_call_targets} + disallowed_custom_call_ops: list[str] = [] + def check_sharding(op: ir.Operation, loc: ir.Location): + if not allow_non_replicated_sharding: + try: + sharding = op.attributes["mhlo.sharding"] + except KeyError: + pass + else: + if ir.StringAttr(sharding).value not in ["{replicated}", ""]: + raise ValueError( + "Lowered function does not have a top-level pjit but it has" + f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee" + " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning" + " for a discussion." 
+ ) + + def check_op(op: ir.Operation): + op_name = op.operation.name + if op_name == "func.func": + check_sharding(op.operation, op.location) + + elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call": + call_target_name_attr = op.operation.attributes["call_target_name"] + if (call_target_name_attr not in allowed_custom_call_targets_attrs): + disallowed_custom_call_ops.append(f"{op} at {op.location}") + if call_target_name_attr == sharding_attr: + check_sharding(op, op.location) + elif call_target_name_attr == shape_assertion_attr: + assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks) + + def walk_operations(op): + check_op(op) + for region in op.operation.regions: + for block in region: + for op in block: + walk_operations(op) + + walk_operations(mod) + if disallowed_custom_call_ops: + disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) + msg = ("Cannot serialize code with custom calls whose targets have no " + "compatibility guarantees. Examples are:\n" + f"{disallowed_custom_call_ops_str}.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") + raise ValueError(msg) + + +def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: + # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp + + # Since jax.vjp does not handle kwargs, it is easier to do all the work + # here with flattened functions. + def fun_vjp_jax(*args_and_out_cts_flat_jax): + # Takes a flat list of primals and output cotangents + def flattened_primal_fun_jax(*args_flat): + args, kwargs = primal.in_tree.unflatten(args_flat) + res = primal_fun_jax(*args, **kwargs) + res_flat, res_tree = tree_util.tree_flatten(res) + assert res_tree == primal.out_tree + return res_flat + + args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, + [len(primal.in_avals)]) + _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) + return pullback_jax(out_cts_flat_jax) + + vjp_in_avals = list( + itertools.chain(primal.in_avals, + map(lambda a: a.at_least_vspace(), primal.out_avals))) + + # Expand in_shardings to all in_avals even not kept ones. + all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals) + for idx, in_s in zip(sorted(primal.module_kept_var_idx), + primal.in_shardings): # type: ignore + all_in_shardings[idx] = in_s # type: ignore + all_shardings = all_in_shardings + list(primal.out_shardings) # type: ignore + # Cannot mix unspecified and specified shardings. Make the unspecified + # ones replicated. 
+  specified_shardings = [
+      s for s in all_shardings if not sharding_impls.is_unspecified(s)]
+
+  vjp_in_shardings: Any  # The primal inputs followed by output cotangents
+  vjp_out_shardings: Any  # The cotangents corresponding to the primal inputs
+  if 0 == len(specified_shardings):
+    vjp_in_shardings = sharding_impls.UNSPECIFIED
+    vjp_out_shardings = sharding_impls.UNSPECIFIED
+  else:
+    if len(specified_shardings) < len(all_shardings):
+      # There are some specified, but not all; pjit front-end does not like it
+      in_s = specified_shardings[0]  # pjit will enforce that all have same devices
+      assert isinstance(in_s, sharding.XLACompatibleSharding)
+      replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
+      all_shardings = [
+          s if not sharding_impls.is_unspecified(s) else replicated_s
+          for s in all_shardings]
+
+    vjp_in_shardings = tuple(all_shardings)
+    vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)])
+    if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings):
+      vjp_out_shardings = sharding_impls.UNSPECIFIED
+
+  fun_vjp_jax = pjit.pjit(fun_vjp_jax,
+                          in_shardings=vjp_in_shardings,
+                          out_shardings=vjp_out_shardings)
+
+  return export(fun_vjp_jax,
+                lowering_platform=primal.lowering_platform,
+                disabled_checks=primal.disabled_checks)(*vjp_in_avals)
+
+### Importing
+
+def call_exported(exported: Exported) -> Callable[..., jax.Array]:
+
+  @jax.custom_vjp
+  def f_flat(*args_flat):
+    return call_exported_p.bind(*args_flat, exported=exported)
+
+  def f_flat_vjp_fwd(*args_flat):
+    # Return the primal arguments as the residual
+    # TODO: keep as residuals only the arguments that are needed
+    return f_flat(*args_flat), args_flat
+
+  def f_flat_vjp_bwd(residual, ct_res_flat):
+    args_flat = residual  # residual is the primal argument flat tuple
+    exp_vjp = exported.vjp()
+    in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_flat)
+    return in_ct_flat
+
+  f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd)
+
+  def f_imported(*args, **kwargs):
+    # Since custom_vjp does not support kwargs, flatten the function first.
+    args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
+    if in_tree != exported.in_tree:
+      # Give errors with the precise tree difference; use fake leaves so we can
+      # use tree_util.equality_errors.
+      in_args = in_tree.unflatten([0] * in_tree.num_leaves)
+      exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves)
+
+      msg = (
+          "The invocation args and kwargs must have the same pytree structure "
+          f"as when the function '{exported.fun_name}' was exported, but they "
+          "have the following structural differences:\n" +
+          ("\n".join(
+             f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a "
+             f"{thing2} when exported, so {explanation}.\n"
+             for path, thing1, thing2, explanation
+             in tree_util.equality_errors(in_args, exp_in_args))))
+      raise ValueError(msg)
+
+    res_flat = f_flat(*args_flat)
+    return exported.out_tree.unflatten(res_flat)
+  return f_imported
+
+
+# A JAX primitive for invoking a serialized JAX function.
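+# Illustrative round-trip sketch (f and its argument shapes are placeholders):
+#   exp = export(f)(jax.ShapeDtypeStruct((3, 4), np.float32))
+#   y = call_exported(exp)(np.ones((3, 4), np.float32))
+# call_exported re-traces the serialized module through this primitive, so the
+# result can itself be jit-compiled or differentiated (via the exported VJP).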
+call_exported_p = core.Primitive("call_exported") +call_exported_p.multiple_results = True + +@util.cache() +def _call_exported_abstract_eval(*in_avals: core.AbstractValue, + exported: Exported) -> tuple[core.AbstractValue, ...]: + exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals) + assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure + # Check that the expected shapes match the actual ones + for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): + def pp_arg_dim(dim_idx: Optional[int]) -> str: + return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, + arg_idx, dim_idx) + if len(exp_aval.shape) != len(actual_aval.shape): + raise ValueError( + f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} " + f"and called with {actual_aval.shape}") + if exp_aval.dtype != actual_aval.dtype: + raise ValueError( + f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} " + f"and called with {actual_aval.dtype}") + for dim_idx, aval_d in enumerate(exp_aval.shape): + # If the exp_aval has a constant dimension then the actual argument must have + # a matching constant dimension. + if core.is_constant_dim(aval_d): + if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or + aval_d != actual_aval.shape[dim_idx]): + raise ValueError( + f"Shape mismatch for {pp_arg_dim(dim_idx)} " + "(expected same constant): " + f"expected {exp_aval.shape} and called with {actual_aval.shape}") + + # Must express the exported_dim_vars in terms of the shapes in in_avals. + solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( + exported.in_avals, args_kwargs_tree=exported.in_tree) + synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] + for (vname, arg_idx, dim_idx) in synth_dim_vars} + synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) + # We discharge all the constraints statically. This results in much simpler + # composability (because we do not have to worry about the constraints of the + # Exported called recursively; we only need to worry about entry-point + # constraints). This also makes sense from a composibility point of view, + # because we get the same errors if we invoke the exported module, or if we + # trace the exported function. Consider for example, an exported module with + # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument + # `f32[c, d]` it is better to fail because `c == d` is inconclusive, than + # succeed and add a compile-time check that `c == d`. In the latter case, + # it would be ambiguous whether we should continue tracing with a result + # a type `f32[c]` or `f32[d]`. 
+ shape_constraints.check_statically(synthetic_eval) + exported_dim_values = [synthetic_eval.evaluate(solution[var]) + for var in exported_dim_vars] + return tuple( + core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, + *exported_dim_values), + dtype=out_aval.dtype, weak_type=out_aval.weak_type, + named_shape=out_aval.named_shape) + for out_aval in exported.out_avals) + + +call_exported_p.def_abstract_eval(_call_exported_abstract_eval) + +def _call_exported_impl(*args, exported: Exported): + return dispatch.apply_primitive(call_exported_p, *args, exported=exported) + +call_exported_p.def_impl(_call_exported_impl) + +def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, + platform: str, + exported: Exported): + # TODO: implement true multi-platform lowering for call_exported + if (platform not in exported.lowering_platforms and + DisabledSafetyCheck.platform() not in exported.disabled_checks): + raise ValueError( + f"The exported function '{exported.fun_name}' was lowered for " + f"platforms '{exported.lowering_platforms}' but it is used " + f"on '{platform}'.") + + if exported.uses_shape_polymorphism: + ctx.module_context.shape_poly_state.uses_dim_vars = True + + submodule = ir.Module.parse(exported.mlir_module()) + symtab = ir.SymbolTable(submodule.operation) + # The called function may have been exported with polymorphic shapes and called + # now with more refined shapes. We insert hlo.ConvertOp to ensure the module + # is valid. + def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value: + new_ir_type = mlir.aval_to_ir_type(new_aval) + if x.type != new_ir_type: + return mlir.convert_hlo(ctx, x, x_aval, new_aval) + else: + return x + + callee_type = symtab["main"].type + # TODO: maybe cache multiple calls + fn = mlir.merge_mlir_modules(ctx.module_context.module, + f"call_exported_{exported.fun_name}", + submodule) + kept_args = [ + convert_shape(a, a_aval, exported_in_aval) + for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals)) + if i in exported.module_kept_var_idx] + if len(exported.lowering_platforms) > 1: + # The exported module takes a platform index argument + # TODO: implement proper handling of the platform_index when we are + # in a multi-platform lowering context. + platform_index = exported.lowering_platforms.index(platform) + arg_width = callee_type.inputs[0].element_type.width + assert arg_width in [32, 64] + platform_index = np.int32(platform_index) if arg_width == 32 else np.int64(platform_index) # type: ignore + kept_args = [mlir.ir_constant(platform_index)] + kept_args + call = func_dialect.CallOp(callee_type.results, + ir.FlatSymbolRefAttr.get(fn), + kept_args) + # The ctx.avals_out already contain the abstract values refined by + # _call_exported_abstract_eval. + return tuple( + convert_shape(out, out_aval, refined_out_aval) + for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out)) + + +for _p in ("cpu", "tpu", "cuda", "rocm"): + mlir.register_lowering(call_exported_p, + functools.partial(_call_exported_lowering, platform=_p), + platform=_p) diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py new file mode 100644 index 000000000..a66497b74 --- /dev/null +++ b/jax/experimental/export/shape_poly.py @@ -0,0 +1,1593 @@ +# Copyright 2021 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shape polymorphism support. + +We introduce a set of dimension variables at the top-level of a `jit` function. +They are introduced implicitly by way of specifying for each dimension of each +argument a symbolic dimension expression in terms of some dimension variables. +All dimension variables are assumed to range over integers greater or equal to 1. + +Symbolic dimensions overload some integer operations, such as +add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been +touched up to be sensitive to handling shapes that contain symbolic dimensions. +This enables many JAX programs to be traced with symbolic dimensions +in some dimensions. A priority has been to enable the batch +dimension in neural network examples to be polymorphic. + +This was built initially for jax2tf, but it is now customizeable to be +independent of TF. The best documentation at the moment is in the +jax2tf.convert docstring, and the +[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +""" + +import collections +from collections.abc import Iterable, Sequence +import dataclasses +from enum import Enum +import functools +import itertools +import io +import math +import operator as op +import threading +import tokenize +from typing import Any, Optional, Union + +import numpy as np +import opt_einsum + +import jax +from jax import config +from jax.interpreters import xla + +from jax._src import core +from jax._src import dtypes +from jax._src import effects +from jax._src.lax import lax +from jax._src.lib import version as jaxlib_version +from jax._src.interpreters import mlir +from jax._src.numpy import lax_numpy +from jax._src import tree_util +from jax._src import util +from jax._src.typing import DimSize, Shape + + +TfVal = Any +DimVarEnv = dict[str, jax.Array] +DType = Any + +class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): + """Raised when we cannot conclusively compute with symbolic dimensions.""" + + _help_msg = """ +This error arises for comparison operations with shapes that +are non-constant, and the result of the operation cannot be represented as +a boolean value for all values of the symbolic dimensions involved. + +Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +for more details. +""" + + def __init__(self, message: str): + error_msg = f"{message}\n{InconclusiveDimensionOperation._help_msg}" + # https://github.com/python/mypy/issues/5887 + super().__init__(error_msg) # type: ignore + +class _ShapePolyThreadLocalState(threading.local): + + def __init__(self): + # TODO(necula): this does not play well with some lowering caches, because + # this state is not part of the cache key. + self.enable_shape_assertions = True + +thread_local_state = _ShapePolyThreadLocalState() + +class _DimAtom: + """Represents an atom in a symbolic dimension expression. 
+
+  Atoms are either variables, or expressions of the form floordiv(E1, E2) or
+  mod(E1, E2). Atoms are multiplied to form monomials (see _DimMon), and
+  monomials are added to form symbolic expressions (see _DimExpr).
+
+  Args:
+    * var: if specified then the atom is a dimension variable. `operation`
+      must be `None`.
+    * operation: if specified then the atom is an operation applied to
+      `operands`. One of `FLOORDIV` or `MOD` or `NON_NEGATIVE`. `var` must be `None`.
+    * operands: the operands to which the operation is applied.
+  """
+  # The supported operations
+  FLOORDIV = "floordiv"
+  MOD = "mod"
+  NON_NEGATIVE = "non_negative"  # The max of the operand and 0
+
+  def __init__(self, *operands: '_DimExpr',
+               var: Optional[str] = None,
+               operation: Optional[str] = None):
+    if var is not None:
+      assert operation is None
+      assert not operands
+    else:
+      assert operation is not None
+    self.var = var
+    self.operation = operation
+    self.operands = operands
+
+  @classmethod
+  def from_var(cls, v: str) -> '_DimAtom':
+    return _DimAtom(var=v)
+
+  def to_var(self) -> Optional[str]:
+    return self.var
+
+  def get_vars(self) -> set[str]:
+    # All the vars that appear
+    if self.var is not None:
+      return {self.var}
+    else:
+      acc = set()
+      for opnd in self.operands:
+        acc.update(opnd.get_vars())
+      return acc
+
+  @classmethod
+  def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimAtom':
+    return _DimAtom(*operands, operation=operation)
+
+  def __str__(self):
+    if self.var is not None:
+      return self.var
+    opnd_str = ", ".join([str(opnd) for opnd in self.operands])
+    return f"{self.operation}({opnd_str})"
+  __repr__ = __str__
+
+  def __hash__(self):
+    return hash((self.var, self.operation, *self.operands))
+
+  def __eq__(self, other: Any):
+    # Used only for hashing
+    if not isinstance(other, _DimAtom): return False
+    if (self.var is None) != (other.var is None): return False
+    if self.var is not None:
+      return self.var == other.var
+    else:
+      def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool:
+        try:
+          return e1 == e2
+        except InconclusiveDimensionOperation:
+          return False
+      return (self.operation == other.operation and
+              all(symbolic_equal(self_o, other_o)
+                  for self_o, other_o in zip(self.operands, other.operands)))
+
+  def __lt__(self, other: '_DimAtom'):
+    """
+    Comparison to another atom in graded reverse lexicographic order.
+    Used only for determining a sorting order, does not relate to the
+    comparison of the values of the atom.
+ """ + if self.var is not None and other.var is not None: + return self.var < other.var + elif self.var is not None: + return True + elif other.var is not None: + return True + elif self.operation != other.operation: + return self.operation < other.operation # type: ignore + else: + return id(self) < id(other) + + def bounds(self) -> tuple[float, float]: + """Returns the lower and upper bounds, or -+ inf.""" + if self.var is not None: + return (1, np.inf) # variables are assumed to be >= 1 + opnd_bounds = [opnd.bounds() for opnd in self.operands] + if self.operation == _DimAtom.FLOORDIV: # a // b + (a_l, a_u), (b_l, b_u) = opnd_bounds + def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf + assert b != 0 + if not np.isinf(b): # divisor is finite + return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf + elif not np.isinf(a): # dividend is finite and divisor is infinite + return -1 if (a >= 0) != (b >= 0) else 0 + else: # both dividend and divisor are infinite + return -np.inf if (a >= 0) != (b >= 0) else np.inf + + # Same reasoning as for multiplication: the bounds are among the cross-product + # of the bounds. + bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u), + math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)] + return (min(*bound_candidates), max(*bound_candidates)) + + elif self.operation == _DimAtom.MOD: + _, (b_l, b_u) = opnd_bounds + if b_l > 0: # positive divisor + return (0, b_u - 1) + elif b_u < 0: # negative divisor + return (b_l + 1, 0) + else: + return (-np.inf, np.inf) + + elif self.operation == _DimAtom.NON_NEGATIVE: + (b_l, b_h), = opnd_bounds + return (max(0, b_l), max(0, b_h)) + + else: + assert False + + def evaluate(self, env: DimVarEnv): + if self.var is not None: + try: + return env[self.var] + except KeyError: + err_msg = ( + f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" + "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + raise KeyError(err_msg) + else: + operand_values = [opnd.evaluate(env) for opnd in self.operands] + if self.operation == _DimAtom.FLOORDIV: + return divmod(*operand_values)[0] # type: ignore + elif self.operation == _DimAtom.MOD: + return divmod(*operand_values)[1] # type: ignore + elif self.operation == _DimAtom.NON_NEGATIVE: + return lax.max(operand_values[0], 0) + else: + assert False, self.operation + +class _DimMon(dict): + """Represents a multiplication of atoms. + + The representation is a dictionary mapping _DimAtom to exponent. + The exponents are integers >= 1. + """ + def __hash__(self): + return hash(frozenset(self.items())) + + def __str__(self): + return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key) + for key, exponent in sorted(self.items())) + + @classmethod + def from_var(cls, v: str) -> '_DimMon': + return _DimMon({_DimAtom.from_var(v): 1}) + + @classmethod + def from_atom(clscls, a: _DimAtom, aexp: int): + return _DimMon({a: aexp}) + + def to_var(self) -> Optional[str]: + """Extract the variable name "x", from a monomial "x". 
+ Return None, if the monomial is not a single variable.""" + items = self.items() + if len(items) != 1: + return None + (a, aexp), = items + if aexp != 1: + return None + return a.to_var() + + def get_vars(self) -> set[str]: + # All the vars that appear in the monomial + acc = set() + for a in self.keys(): + acc.update(a.get_vars()) + return acc + + @classmethod + def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimMon': + return _DimMon({_DimAtom.from_operation(operation, *operands): 1}) + + @property + def degree(self): + return sum(self.values()) + + def __lt__(self, other: '_DimMon'): + """ + Comparison to another monomial in graded reverse lexicographic order. + Used only for determining a sorting order, does not relate to the + comparison of the values of the monomial. + """ + self_key = -self.degree, tuple(sorted(self)) + other_key = -other.degree, tuple(sorted(other)) + return self_key > other_key + + def mul(self, other: '_DimMon') -> '_DimMon': + """ + Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m. + """ + return _DimMon(collections.Counter(self) + collections.Counter(other)) + + def divide(self, divisor: '_DimMon') -> '_DimMon': + """ + Divides by another monomial. Raises a InconclusiveDimensionOperation + if the result is not a monomial. + For example, (n^3 * m) // n == n^2*m, but n // m fails. + """ + d = collections.Counter(self) + for key, exponent in divisor.items(): + diff = self.get(key, 0) - exponent + if diff < 0: + raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") + elif diff == 0: del d[key] + elif diff > 0: d[key] = diff + return _DimMon(d) + + def bounds(self) -> tuple[float, float]: + """Returns the lower and upper bounds, or -+inf.""" + # The bounds of a product are among the product of bounds. + bounds = [] + for a, exp in self.items(): + a_l, a_u = a.bounds() + assert a_l <= a_u + bounds.append((a_l ** exp, a_u ** exp)) + + candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)] + return (min(*candidates), max(*candidates)) # type: ignore + + + def evaluate(self, env: DimVarEnv): + prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) + def pow_opt(v, p: int): + return v if p == 1 else prod([v] * p) + return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()]) + + +class _DimExpr(): + """Symbolic expression in terms of dimension variables. + + A dimension expression is an addition of products (_DimMon) + of atoms (_DimAtom). + + We overload integer operations, but we do that soundly, raising + :class:`InconclusiveDimensionOperation` when the result is not + representable as a _DimExpr. + + The representation of a _DimExpr is as a dictionary mapping _DimMon to + integer coefficients. The special monomial `_DimMon()` is mapped to the + free integer coefficient of the expression. + """ + + __array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray + def __init__(self, coeffs: dict[_DimMon, int]): + # Do not construct _DimExpr directly, unless you are sure that coeffs is + # normalized; Use _DimExpr.normalize. 
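+    # For example (illustrative), the expression 3*a*b + 2*a + 1 is represented
+    # as {_DimMon({a: 1, b: 1}): 3, _DimMon({a: 1}): 2, _DimMon(): 1}, where
+    # `a` and `b` stand for the _DimAtom of the corresponding variables.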
+ # Takes ownership of coeffs + self._coeffs = coeffs or {_DimMon(): 0} + + def monomials(self) -> Iterable[tuple[_DimMon, int]]: + return self._coeffs.items() + + @classmethod + def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int): + """Do `coeffs[mon] += coeff` but remove 0 coefficients.""" + old_c = coeffs.get(mon) + if old_c is None: + if coeff != 0: coeffs[mon] = coeff + else: + new_c = old_c + coeff + if new_c == 0: + del coeffs[mon] + else: + coeffs[mon] = new_c + + @classmethod + def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize: + """The main constructor for _DimExpr. + + Ensures that the symbolic dimension is normalized, e.g., + it is represented as a Python int if it is known to be a constant. + """ + # TODO(necula): profile and optimize this + has_non_zero_degree = False + free_const = 0 + new_coeffs: dict[_DimMon, int] = {} + for mon, coeff in coeffs.items(): + if coeff == 0: continue + if mon.degree == 0: # A constant, there can be a single one + free_const = coeff + else: + has_non_zero_degree = True + + new_coeffs[mon] = new_coeffs.get(mon, 0) + coeff + + if has_non_zero_degree: + return _DimExpr(new_coeffs) + else: + return int(free_const) + + @classmethod + def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize: + # Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes + # up when handling strided convolution. + for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV): + # e = factor * floordiv(operands)^exp * rest_monomial + rest_expr + if dec.exp != 1: + continue + if dec.rest_monomial == 1 and dec.factor == 1: + continue + m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1]) + if m_remainder == 0: + return m_trimmed * (dec.operands[0] - _DimExpr.from_operation(_DimAtom.MOD, *dec.operands)) + dec.rest_expr + return _DimExpr.normalize(coeffs) + + @classmethod + def from_monomial(cls, mon: _DimMon, exp: int): + return _DimExpr.normalize({mon: exp}) + + @classmethod + def from_var(cls, v: str) -> '_DimExpr': + return _DimExpr({_DimMon.from_var(v): 1}) + + @classmethod + def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimExpr': + return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1) + + def to_var(self) -> Optional[str]: + """Extract the variable name "x", from a symbolic expression.""" + items = self.monomials() + if len(items) != 1: # type: ignore + return None + (mon, mon_count), = items + if mon_count != 1: + return None + return mon.to_var() + + def get_vars(self) -> set[str]: + """The variables that appear in a symbolic dimension.""" + acc = set() + for mon, _ in self.monomials(): + acc.update(mon.get_vars()) + return acc + + def eq(self, other: DimSize) -> bool: + lb, ub = _ensure_poly(self - other, "eq").bounds() + if lb == ub == 0: + return True + if lb > 0 or ub < 0: + return False + # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported + return False + + def inconclusive_comparison(self, operation: str, op: Any) -> Exception: + return InconclusiveDimensionOperation( + f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.") + + def ge(self, other: DimSize) -> bool: + lb, ub = _ensure_poly(self - other, "ge").bounds() + if lb >= 0: + return True + if ub < 0: + return False + 
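+    # Neither bound is decisive (e.g. for 'a >= b' the difference 'a - b'
+    # has bounds (-inf, inf)), so the comparison is inconclusive.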
raise self.inconclusive_comparison(">=", other) + + def __hash__(self): + return hash(tuple(sorted(self.monomials()))) + + def __str__(self): + def _one_monomial(mon, c): + if mon.degree == 0: + return str(c) + if c == 1: + return str(mon) + return f"{c}*{mon}" + return " + ".join(_one_monomial(mon, c) + for mon, c in sorted(self.monomials(), reverse=True)) + + def __repr__(self): + return str(self) + + # We overload +, -, *, because they are fully defined for _DimExpr. + def __add__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__add__(other) + + other = _ensure_poly(other, "add") + coeffs = self._coeffs.copy() + for mon, coeff in other.monomials(): + _DimExpr._add_coeffs(coeffs, mon, coeff) + return _DimExpr.normalize_floordiv_times_divisor(coeffs) + + def __radd__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__radd__(other) + return _ensure_poly(other, "add").__add__(self) + + def __sub__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__sub__(other) + return self + -_ensure_poly(other, "sub") + + def __rsub__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__rsub__(other) + return _ensure_poly(other, "sub").__sub__(self) + + def __neg__(self) -> '_DimExpr': + return _DimExpr({mon: -coeff for mon, coeff in self.monomials()}) + + def __mul__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__mul__(other) + other = _ensure_poly(other, "mul") + coeffs: dict[_DimMon, int] = {} + for mon1, coeff1 in self.monomials(): + for mon2, coeff2 in other.monomials(): + mon = mon1.mul(mon2) + _DimExpr._add_coeffs(coeffs, mon, coeff1 * coeff2) + return _DimExpr.normalize_floordiv_times_divisor(coeffs) + + def __rmul__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__rmul__(other) + return _ensure_poly(other, "mul").__mul__(self) + + def __pow__(self, power, modulo=None): + assert modulo is None + try: + power = int(power) + except: + raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") + return functools.reduce(op.mul, [self] * power) + + def __floordiv__(self, divisor): + if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): + return self.__jax_array__().__floordiv__(divisor) + return self.divmod(_ensure_poly(divisor, "floordiv"))[0] + + def __rfloordiv__(self, other): + if isinstance(other, core.Tracer) or not _convertible_to_poly(other): + return self.__jax_array__().__rfloordiv__(other) + return _ensure_poly(other, "floordiv").__floordiv__(self) + + def __truediv__(self, divisor): + # Used for "/", which always returns a float + return self.__jax_array__().__truediv__(divisor) + + def __rtruediv__(self, dividend): + # Used for "/", when dividend is not a _DimExpr + return self.__jax_array__().__rtruediv__(dividend) + + def __mod__(self, divisor): + if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): + return self.__jax_array__().__mod__(divisor) + return self.divmod(_ensure_poly(divisor, "mod"))[1] + + def __rmod__(self, dividend): + if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): + return self.__jax_array__().__rmod__(dividend) + return _ensure_poly(dividend, 
"mod").__mod__(self) + + def __divmod__(self, divisor): + if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): + return self.__jax_array__().__divmod__(divisor) + return self.divmod(_ensure_poly(divisor, "divmod")) + + def __rdivmod__(self, dividend): + if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): + return self.__jax_array__().__rdivmod__(dividend) + return _ensure_poly(dividend, "divmod").__divmod__(self) + + def __int__(self): + if self.is_constant: + return op.index(next(iter(self._coeffs.values()))) + else: + raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant") + + # We must overload __eq__ and __ne__, or else we get unsound defaults. + __eq__ = eq + def __ne__(self, other: DimSize) -> bool: + return not self.eq(other) + + __ge__ = ge + + def __le__(self, other: DimSize): + try: + return _ensure_poly(other, "le").__ge__(self) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison("<=", other) from e + + def __gt__(self, other: DimSize): + try: + return not _ensure_poly(other, "gt").__ge__(self) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison(">", other) from e + + def __lt__(self, other: DimSize): + try: + return not self.__ge__(other) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison("<", other) from e + + def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]: + """ + Floor division with remainder (divmod) generalized to polynomials. + If the `divisor` is not a constant, the remainder must be 0. + If the `divisor` is a constant, the remainder may be non 0, for consistency + with integer divmod. + + :return: Quotient resulting from polynomial division and integer remainder. + """ + assert isinstance(divisor, _DimExpr) + try: + dmon, dcount = divisor.leading_term + dividend, quotient = self, 0 + # invariant: self = dividend + divisor * quotient + # quotient and dividend are changed in the loop; the leading term of + # dividend decreases at each iteration. 
+ while is_poly_dim(dividend) and not dividend.is_constant: + mon, count = dividend.leading_term + try: + qmon = mon.divide(dmon) + except InconclusiveDimensionOperation: + raise InconclusiveDimensionOperation("") + qcount, rcount = divmod(count, dcount) + if rcount != 0: + raise InconclusiveDimensionOperation("") + + q = _DimExpr.from_monomial(qmon, qcount) + quotient += q + dividend -= q * divisor # type: ignore[assignment] + + dividend = int(dividend) # type: ignore[assignment] + if divisor.is_constant: + q, r = divmod(dividend, int(divisor)) # type: ignore + quotient += q + remainder = r + else: + if dividend != 0: + raise InconclusiveDimensionOperation("") + remainder = 0 + + if config.jax_enable_checks: + assert self == divisor * quotient + remainder + return quotient, remainder + except InconclusiveDimensionOperation: + return (_DimExpr.from_operation(_DimAtom.FLOORDIV, self, divisor), # type: ignore + _DimExpr.from_operation(_DimAtom.MOD, self, divisor)) + + def bounds(self) -> tuple[float, float]: + """Returns the lower and upper bounds, or -+inf.""" + lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient + for mon, coeff in self.monomials(): + if mon.degree == 0: continue # We already included the free coefficient + m_l, m_u = mon.bounds() + assert m_l <= m_u and coeff != 0 + item_l, item_u = coeff * m_l, coeff * m_u + lb = lb + min(item_l, item_u) # type: ignore + ub = ub + max(item_l, item_u) # type: ignore + + if lb != -np.inf or ub != np.inf: + return lb, ub + # Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0 + # TODO(necula): add more principled support for floordiv and mod + # For example, this will miss "1 + a - mod(b, a)" + for dec in _decompose_expr(self, _DimAtom.MOD): + # E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr + if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]: + try: + if dec.operands[1] <= 0: + continue + except InconclusiveDimensionOperation: + continue + if dec.factor > 0: + return (-np.inf, -1) + else: + return (1, np.inf) + + return lb, ub + + @property + def is_constant(self): + return len(self._coeffs) == 1 and next(iter(self._coeffs)).degree == 0 + + @property + def leading_term(self) -> tuple[_DimMon, int]: + """Returns the highest degree term that comes first lexicographically.""" + return max(self.monomials()) + + def evaluate(self, env: DimVarEnv): + # Evaluates as a value of dtype=core.dim_value_dtype() + terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) + for mon, coeff in self.monomials()] + return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] + + def non_negative(self) -> "_DimExpr": + return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self) + + @staticmethod + def get_aval(dim: "_DimExpr"): + return core.dim_value_aval() + + def dimension_as_value(self): + """Turns a dimension size into a Jax value that we can compute with.""" + return _dim_as_value(self) + + def __jax_array__(self): + # Used for implicit coercions of polynomials as JAX arrays + return _dim_as_value(self) + +@dataclasses.dataclass +class _Decomposition: + """Decomposition of an expression around an operation atom. 
+ + E = factor * mod(*operands)^exp * rest_monomial + rest_expr + """ + factor: int + operands: Sequence[_DimExpr] + exp: int + rest_monomial: _DimExpr + rest_expr: _DimExpr + + +def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]: + for m, m_factor in e.monomials(): + atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation] + if atoms: + e_minus_m_coeffs = e._coeffs.copy() + del e_minus_m_coeffs[m] + for a, aexp in atoms: + yield _Decomposition( + factor=m_factor, + operands=a.operands, + exp=aexp, + rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}), + rest_expr=_DimExpr(e_minus_m_coeffs)) + +core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval +xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval +dtypes._weak_types.append(_DimExpr) + +def _convertible_to_int(p: DimSize) -> bool: + try: + op.index(p) + return True + except: + return False + +def _ensure_poly(p: DimSize, + operation_name: str) -> _DimExpr: + if isinstance(p, _DimExpr): return p + if _convertible_to_int(p): + return _DimExpr({_DimMon(): op.index(p)}) + raise TypeError(f"Symnbolic dimension {operation_name} not supported for {p}.") + +def _convertible_to_poly(p: DimSize) -> bool: + return isinstance(p, _DimExpr) or _convertible_to_int(p) + +def is_poly_dim(p: DimSize) -> bool: + return isinstance(p, _DimExpr) + +dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] + +def _einsum_contract_path(*operands, **kwargs): + """Like opt_einsum.contract_path, with support for DimExpr shapes. + + We use opt_einsum.contract_path to compute the schedule, using a fixed + constant for all dimension variables. This is safe because we throw an + error if there are more than 1 contractions. Essentially, we just use + opt_einsum.contract_path to parse the specification. + """ + + # Replace the polymorphic shapes with some concrete shapes for calling + # into opt_einsum.contract_path, because the latter wants to compute the + # sizes of operands and intermediate results. + fake_ops = [] + for operand in operands: + # We replace only array operands + if not hasattr(operand, "dtype"): + fake_ops.append(operand) + else: + shape = np.shape(operand) + def fake_dim(d): + if core.is_constant_dim(d): + return d + else: + if not isinstance(d, _DimExpr): + raise TypeError(f"Encountered unexpected shape dimension {d}") + # It is Ok to replace all polynomials with the same value. We may miss + # here some errors due to non-equal dimensions, but we catch them + # later. + return 8 + fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), + operand.dtype)) + + contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, + **kwargs) + contract_operands = [] + for operand in contract_fake_ops: + idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op) + assert len(idx) == 1 + contract_operands.append(operands[idx[0]]) + return contract_operands, contractions + +lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path + +# To implement shape-constraint checking we use a shape assertion primitive. +# shape_assertion_p.bind(assert_what: bool, *error_message_inputs, +# error_message="...{0}...{1}") +# where "{0}" refers to error_message_inputs[0], etc. 
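+# For instance (illustrative values only), a check that a dimension value `v`
+# is >= 1 could be emitted as:
+#   shape_assertion_p.bind(lax.ge(v, 1), v,
+#                          error_message="Expected a value >= 1, found {0}")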
+shape_assertion_p = core.Primitive("shape_assertion") +shape_assertion_p.multiple_results = True +shape_assertion_p.def_effectful_abstract_eval( + lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore + +def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext, + assert_what: mlir.ir.Value, + *error_message_inputs: mlir.ir.Value, + error_message: str): + op = mlir.custom_call( + "shape_assertion", + result_types=[], # No results + operands=[assert_what, *error_message_inputs], + has_side_effect=True, + extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message)) + ) + return op.results + +mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule) + +class ShapeAssertionEffect(effects.Effect): + __str__ = lambda _: "ShapeAssertionEffect" + +shape_assertion_effect = ShapeAssertionEffect() + +effects.lowerable_effects.add_type(ShapeAssertionEffect) +effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect) +effects.remat_allowed_effects.add_type(ShapeAssertionEffect) +effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) + +def shape_assertion(assert_what: jax.Array, + *error_message_inputs: jax.Array, + error_message: str) -> None: + """Adds a shape assertion in the code. + + Args: + assert_what: a boolean asserted to be true. Must be computed based only + on dimension expressions, so that it can be evaluated after shape + refinement. + error_message_inputs: integers expressions whose values can be referenced + in the `error_message`. Must be computed based only + on dimension expressions, so that they can be evaluated after shape + refinement. + error_message: an error message, possibly containing format specifiers + {0}, {1}, ..., referencing the values of the `error_message_inputs`. + The format specifiers are sometimes processed with Python's + `string::format` method, and sometimes with `llvm::formatv`. + """ + if thread_local_state.enable_shape_assertions: + shape_assertion_p.bind(assert_what, *error_message_inputs, + error_message=error_message) + +# A JAX primitive with no array arguments but with a dimension parameter +# that is a DimExpr. The value of the primitive is the value of the dimension, +# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype()) +dim_as_value_p = core.Primitive("dim_as_value") +dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval()) + +def dim_as_value_impl(dim: DimSize): + raise NotImplementedError( + "Evaluation rule for 'dim_as_value' is not implemented. " + "It seems that you are using shape polymorphism outside jax2tf.") + +dim_as_value_p.def_impl(dim_as_value_impl) +def _dim_as_value(dim: DimSize): + return dim_as_value_p.bind(dim=dim) + +def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, + dim): + res, = mlir.eval_dynamic_shape(ctx, (dim,)) + out_type = mlir.aval_to_ir_type(ctx.avals_out[0]) + if out_type != res.type: # type: ignore + return mlir.hlo.ConvertOp(out_type, res).results + else: + return [res] + +mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering) + + +class PolyShape(tuple): + """Tuple of polymorphic dimension specifications. + + See docstring of :func:`jax2tf.convert`. 
+ """ + + def __init__(self, *dim_specs): + tuple.__init__(dim_specs) + + def __new__(cls, *dim_specs): + for ds in dim_specs: + if not isinstance(ds, (int, str)) and ds != ...: + msg = (f"Invalid polymorphic shape element: {repr(ds)}; must be a string " + "representing a dimension variable, or an integer, or ...") + raise ValueError(msg) + return tuple.__new__(PolyShape, dim_specs) + + def __str__(self): + return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" + + +def _parse_spec(shape_spec: Union[str, PolyShape, None], + arg_shape: Sequence[Optional[int]]) -> Sequence[DimSize]: + """Parses the shape polymorphic specification for one array argument. + + We have to be able to parse all strings produced by str(_DimExpr) because + sometimes the output polymorphic shapes of one function become the input + polymorphic shapes of another. + + Args: + shape_spec: a shape polymorphic specification. None stands for "...". + arg_shape: an actual shape, possibly containing unknown dimensions (None). + We use `arg_shape` to fill-in the placeholders `_` and `...` in + the `shape_spec`. The dimensions of `arg_shape` that are used for filling + must be known (not `None`). If a dimension in `arg_shape` is known and + the corresponding dimension in `shape_spec` is a constant then they + must be equal. + + See the README.md for usage. + """ + shape_spec_repr = repr(shape_spec) + if shape_spec is None: + shape_spec = "..." + elif isinstance(shape_spec, PolyShape): + shape_spec = str(shape_spec) + elif not isinstance(shape_spec, str): + raise ValueError("polymorphic shape spec should be None or a string. " + f"Found {shape_spec_repr}.") + return _Parser(shape_spec, arg_shape, shape_spec_repr).parse() + +class _Parser: + def __init__(self, + shape_spec: str, + arg_shape: Sequence[Optional[int]], + shape_spec_repr: str): + self.shape_spec = shape_spec + self.shape_spec_repr = shape_spec_repr # For error messages + self.arg_shape = arg_shape + self.dimensions: list[DimSize] = [] # dimensions we have parsed + + def parse(self) -> Sequence[DimSize]: + self.tokstream = tokenize.tokenize( + io.BytesIO(self.shape_spec.encode("utf-8")).readline) + tok = self.consume_token(self.next_tok(), tokenize.ENCODING) # Always 1st + sh, tok = self.shape(tok) + self.expect_token(tok, [tokenize.ENDMARKER]) + return sh + + def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo): + if expr is None: + raise self.parse_err(tok, + ("unexpected placeholder for unknown dimension " + f"for argument shape {self.arg_shape}")) + arg_shape_dim = self.arg_shape[len(self.dimensions)] + if core.is_constant_dim(expr) and arg_shape_dim is not None: + if expr != arg_shape_dim: + raise self.parse_err(tok, + (f"different size {expr} for known dimension " + f"for argument shape {self.arg_shape}")) + self.dimensions.append(expr) + + def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception: + msg = ( + f"syntax error in polymorphic shape {self.shape_spec_repr} " + f"in dimension {len(self.dimensions)}: {detail}. ") + if tok is not None: + msg += f"Parsed '{tok.line[:tok.start[1]]}', remaining '{tok.line[tok.start[1]:]}'." 
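+    # The error is returned rather than raised, so that callers can write
+    # `raise self.parse_err(tok, ...)` at the point of failure.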
+ return ValueError(msg) + + def next_tok(self) -> tokenize.TokenInfo: + while True: + try: + t = next(self.tokstream) + except StopIteration: + raise self.parse_err(None, "unexpected end of string") + if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]: + return t + + def expect_token(self, tok: tokenize.TokenInfo, expected: Sequence[int]) -> None: + if tok.exact_type not in expected: + msg = ("expecting one of {" + + ", ".join(tokenize.tok_name[t] for t in expected) + "} but found " + + tokenize.tok_name[tok.exact_type]) + raise self.parse_err(tok, msg) + + def consume_token(self, tok: tokenize.TokenInfo, expected: int) -> tokenize.TokenInfo: + self.expect_token(tok, [expected]) + return self.next_tok() + + def integer(self, tok: tokenize.TokenInfo) -> tuple[int, tokenize.TokenInfo]: + self.expect_token(tok, [tokenize.NUMBER]) + try: + val = int(tok.string) + except Exception: + raise self.parse_err(tok, f"expecting integer, found {tok.string}") + return val, self.next_tok() + + # What can follow a shape? + FOLLOW_SHAPE = [tokenize.ENDMARKER, tokenize.RPAR] + def shape(self, tok: tokenize.TokenInfo) -> tuple[Sequence[DimSize], tokenize.TokenInfo]: + # A comma-separated list of _DimExpr, or "_", possibly ended with ... + if tok.exact_type == tokenize.LPAR: + res, tok = self.shape(self.next_tok()) + tok = self.consume_token(tok, tokenize.RPAR) + return res, tok + + while True: + if tok.exact_type in self.FOLLOW_SHAPE: + break + if tok.exact_type == tokenize.ELLIPSIS: + to_add = self.arg_shape[len(self.dimensions):] + for ad in to_add: + self.add_dim(ad, tok) + tok = self.next_tok() + break + if len(self.dimensions) >= len(self.arg_shape): + raise self.parse_err(tok, + f"too many dimensions, arg_shape has {len(self.arg_shape)}") + if tok.exact_type == tokenize.NAME and tok.string == "_": + e = self.arg_shape[len(self.dimensions)] + tok = self.next_tok() + else: + e, tok = self.expr(tok) + self.add_dim(e, tok) + if tok.exact_type in self.FOLLOW_SHAPE: + break + tok = self.consume_token(tok, tokenize.COMMA) + + return tuple(self.dimensions), tok + + # What token can follow a _DimExpr + FOLLOW_EXPR = FOLLOW_SHAPE + [tokenize.COMMA] + + def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: + # A sum of monomials + next_m_negated = False + acc = 0 + while True: + m, tok = self.mon(tok) + acc = acc + (- m if next_m_negated else m) + if tok.exact_type in self.FOLLOW_EXPR: + return acc, tok + next_m_negated = (tok.exact_type == tokenize.MINUS) + self.expect_token(tok, [tokenize.PLUS, tokenize.MINUS]) + tok = self.next_tok() + + FOLLOW_MON = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS] + def mon(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: + # A monomial is product of atoms. Each atom may be raised to an integer power. 
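+    # e.g. "3*a*b^2" and "a*mod(b, 2)" each parse as a single monomial; the
+    # accumulator below multiplies the parsed atoms (and their powers).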
+ acc = 1 + while True: + a, tok = self.atom(tok) + if tok.exact_type == tokenize.CIRCUMFLEX: + tok = self.next_tok() + self.expect_token(tok, [tokenize.NUMBER]) + power, tok = self.integer(tok) + a = a ** power + + acc = acc * a + if tok.exact_type in self.FOLLOW_MON: + return acc, tok + tok = self.consume_token(tok, tokenize.STAR) + + def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: + if tok.exact_type == tokenize.NAME: + if tok.string == _DimAtom.MOD: + return self.binary_op(_DimAtom.MOD, self.next_tok()) + if tok.string == _DimAtom.FLOORDIV: + return self.binary_op(_DimAtom.FLOORDIV, self.next_tok()) + if tok.string == _DimAtom.NON_NEGATIVE: + return self.unary_op(_DimAtom.NON_NEGATIVE, self.next_tok()) + return _DimExpr.from_var(tok.string), self.next_tok() + number_sign = 1 + if tok.exact_type == tokenize.MINUS: # -k are negative constants + number_sign = -1 + tok = self.next_tok() + self.expect_token(tok, [tokenize.NUMBER]) + if tok.exact_type == tokenize.NUMBER: + v, tok = self.integer(tok) + return v * number_sign, tok + self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER]) + assert False + + def unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: + tok = self.consume_token(tok, tokenize.LPAR) + e1, tok = self.expr(tok) + tok = self.consume_token(tok, tokenize.RPAR) + return _DimExpr.from_operation(op, e1), tok # type: ignore + + def binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: + tok = self.consume_token(tok, tokenize.LPAR) + e1, tok = self.expr(tok) + tok = self.consume_token(tok, tokenize.COMMA) + e2, tok = self.expr(tok) + tok = self.consume_token(tok, tokenize.RPAR) + return _DimExpr.from_operation(op, e1, e2), tok # type: ignore + + +def _evaluate_add(v1, v2): + try: + if op.index(v1) == 0: + return v2 + except: + pass + try: + if op.index(v2) == 0: + return v1 + except: + pass + return v1 + v2 + +def _evaluate_multiply(v1, v2): + try: + if op.index(v1) == 1: + return v2 + except: + pass + try: + if op.index(v2) == 1: + return v1 + except: + pass + return v1 * v2 + +# dimension_size(operand, dimension=i) get the operand.shape[i] as a +# value of type shape_poly.dim_as_value_dtype(). +dimension_size_p = core.Primitive("dimension_size") +def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue: + return core.dim_value_aval() + +dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval) + +def _dimension_size_impl(arg, *, dimension): + return core.dim_constant(arg.shape[dimension]) +dimension_size_p.def_impl(_dimension_size_impl) + +def _dimension_size_lowering_rule(ctx, arg, *, dimension): + dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension) + dim_type = mlir.aval_to_ir_type(core.dim_value_aval()) + if dim_size.result.type != dim_type: + dim_size = mlir.hlo.ConvertOp(dim_type, dim_size) + return dim_size.results + +mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) + + +def arg_aval( + arg_shape: Sequence[Optional[int]], + arg_jax_dtype: DType, + polymorphic_shape: Optional[Union[str, PolyShape]]) -> core.ShapedArray: + """Computes abstract values. + + Args: + arg_shape: the shape for the argument, possibly having None dimensions. + arg_dtype: the inferred JAX dtype for the arg. + polymorphic_shape: the polymorphic specification for the argument. + Returns: the JAX abstract value for the argument. 
+ """ + aval_shape = _parse_spec(polymorphic_shape, arg_shape) + return core.ShapedArray(aval_shape, arg_jax_dtype) + +def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]: + dim_vars: set[str] = set() + for a in args_avals: + for d in a.shape: + if is_poly_dim(d): + dim_vars = dim_vars.union(d.get_vars()) + return sorted(tuple(dim_vars)) + + +class CachingShapeEvaluator: + def __init__(self, **env): + self.env = env + + @functools.lru_cache(128) + def evaluate(self, e: DimSize): + if core.is_constant_dim(e): + res = op.index(e) + else: + res = e.evaluate(self.env) # type: ignore + return res + + +@dataclasses.dataclass(frozen=True) +class ShapeConstraint: + class Comparator(Enum): + EQ = 1 + GEQ = 2 + + comp: Comparator + left: DimSize + right: DimSize + # `error_message_pieces` is a list of strings and DimSize. The error message + # is formed by evaluating the DimSize and concatenating the sequence. + error_message_pieces: Sequence[Union[str, DimSize]] + + def check_statically(self, eval: CachingShapeEvaluator) -> None: + """Evaluates a constraint statically.""" + left, right = eval.evaluate(self.left), eval.evaluate(self.right) + try: + if self.comp == ShapeConstraint.Comparator.EQ: + ok = (left == right) + elif self.comp == ShapeConstraint.Comparator.GEQ: + ok = (left >= right) + else: + assert False # We are in a context where we know we can evaluate + # all symbolic expressions to constants. + except InconclusiveDimensionOperation as e: + raise self.make_error(eval) from e + if not ok: + raise self.make_error(eval) + + def compute(self, eval: CachingShapeEvaluator) -> Optional[jax.Array]: + """Computes if the constraint is satisfied. + + If the constraint can be resolved statically returns None + or raises ValueError otherwise. If the constraint cannot be + resolved statically, returns a value representing if the + constraint is satisfied. + """ + left, right = eval.evaluate(self.left), eval.evaluate(self.right) + # Try to evaluate the constraint statically. + if core.is_constant_shape((left, right)): + left_int, right_int = op.index(left), op.index(right) + if self.comp == ShapeConstraint.Comparator.EQ: + if not (left_int == right_int): + raise self.make_error(eval) + elif self.comp == ShapeConstraint.Comparator.GEQ: + if not (left_int >= right_int): + raise self.make_error(eval) + else: assert False + return None + + if self.comp == ShapeConstraint.Comparator.EQ: + is_ok = lax.eq(left, right) + elif self.comp == ShapeConstraint.Comparator.GEQ: + is_ok = lax.ge(left, right) + else: assert False + return is_ok + + def __str__(self): + return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}" + f" ({self.error_message_pieces})") + __repr__ = __str__ + + def error_message_and_inputs( + self, + eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]: + """Forms the error_message and error message_inputs. + See shape_assertion. + """ + # There is currenly a limitation in the shape assertion checker that + # it supports at most 32 error_message_inputs. We try to stay within the + # limit, reusing a format specifier if possible. 
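+    # If the same dimension expression occurs several times among
+    # `error_message_pieces`, all of its occurrences share one "{k}" specifier.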
+ if jaxlib_version <= (0, 4, 14): + max_error_message_inputs = 4 + else: + max_error_message_inputs = 32 + format_specifiers: dict[DimSize, str] = {} + error_message_inputs: list[Any] = [] + error_message_strings: list[str] = [] + for e in self.error_message_pieces: + if isinstance(e, str): + error_message_strings.append(e) + continue + cached_spec = format_specifiers.get(e) + if cached_spec is not None: + error_message_strings.append(cached_spec) + continue + if len(error_message_inputs) >= max_error_message_inputs: + error_message_strings.append("N/A") + continue + spec = "{" + str(len(error_message_inputs)) + "}" + format_specifiers[e] = spec + error_message_strings.append(spec) + error_message_inputs.append(eval.evaluate(e)) + return ("".join(error_message_strings), + error_message_inputs) + + def make_error(self, eval: CachingShapeEvaluator) -> Exception: + error_message, error_message_inputs = self.error_message_and_inputs(eval) + return ValueError(error_message.format(*error_message_inputs)) + + +class ShapeConstraints: + def __init__(self): + self.constraints: list[ShapeConstraint] = [] + + def add_constraint(self, + comp: ShapeConstraint.Comparator, + left: DimSize, right: DimSize, + error_message_pieces: Sequence[Union[str, DimSize]]): + c = ShapeConstraint(comp, left, right, error_message_pieces) + self.constraints.append(c) + + def check_statically(self, eval: CachingShapeEvaluator) -> None: + """Evaluates all the constraints statically. + + If the static checking of any constraint fails, raises ValueError. + """ + for constraint in self.constraints: + constraint.check_statically(eval) + + def shape_assertions(self, eval: CachingShapeEvaluator) -> None: + """Computes the shape assertions for the set of constraints. + + See jax_export._wrap_main_func docstring. + """ + # We want to report the errors in the same order as `check_statically`. + # So, we process them in order, in case some fail statically, and we + # generate the shape assertions in the same order. + for constraint in self.constraints: + is_ok = constraint.compute(eval) + if is_ok is None: continue # Was resolved statically + error_message, error_message_inputs = constraint.error_message_and_inputs(eval) + shape_assertion( + is_ok, *error_message_inputs, + error_message=error_message) + +@dataclasses.dataclass +class _DimEquation: + # Encodes that `aval_dim_expr`, which is a symbolic expressions containing + # unknown dimension variables from the abstract values, is the specification + # for dimension named `dim_name` (e.g., "args[0].field.shape[2]"). + aval_dim_expr: _DimExpr + dim_name: str + + def __str__(self): + return f"Dimension size of {self.dim_name} with specification '{self.aval_dim_expr}'" + __repr__ = __str__ + + +def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str: + # String description of `args` or `kwargs`, assuming the path for a tree for + # the tuple `(args, kwargs)`. 
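+  # e.g. (SequenceKey(0), SequenceKey(2)) maps to "args[2]" and
+  # (SequenceKey(1), DictKey('x')) maps to "kwargs['x']".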
+ if path[0] == tree_util.SequenceKey(0): + return f"args{tree_util.keystr(path[1:])}" + elif path[0] == tree_util.SequenceKey(1): + return f"kwargs{tree_util.keystr(path[1:])}" + else: + assert False + +@functools.lru_cache(128) +def _cached_pretty_print_dimension_descriptor( + args_kwargs_tree: tree_util.PyTreeDef, + flat_arg_idx: int) -> str: + args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path( + args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves)) + arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0]) + return arg_str + +def pretty_print_dimension_descriptor( + args_kwargs_tree: tree_util.PyTreeDef, + flat_arg_idx: int, dim_idx: Optional[int]) -> str: + arg_str = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx) + if dim_idx is not None: + arg_str += f".shape[{dim_idx}]" + return arg_str + +@util.cache() +def solve_dim_vars( + args_avals: Sequence[core.AbstractValue], + args_kwargs_tree: tree_util.PyTreeDef, + ) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]: + """Solves dimension variables in a called function's avals in terms of actual argument shapes. + + For example, given: + + args_avals = [ShapedArray((3, a, a + b), f32)] + + we introduce fresh "synthetic" dimension variables to represent the actual + dimension size of actual arguments for each non-constant dimension. + Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.: + + synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)] + + and then we express the solution for the unknown dimension variables {a, b} + as symbolic expressions in terms of the synthetic variables: + + dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1]) + + Not all equations are solvable. For now, we solve first the linear + uni-variate equations, then the solved variables are used to simplify the + remaining equations to linear uni-variate equations, and the process + continues until all dimension variables are solved. + + Args: + args_avals: the abstract values of the `args`, with shapes that may + include unknown dimension variables. + args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` + from which the flat sequence `args_avals` is extracted. Used for + describing args and kwargs in synthetic variable names and in + error messages. + + Returns: a 3-tuple with: (a) the solution for the unknown dimension variables + (b) a list of constraints that must be satisfied for the solution to be a + valid one, and (c) and the list of synthetic variables that may appear in + the solution and the constraints. + + Raises ValueError if it cannot solve some dimension variable. 
+ """ + dim_equations: list[_DimEquation] = [] + synth_dimension_vars: list[tuple[str, int, int]] = [] + # tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b')) + polymorphic_shape_specs: list[tuple[str, str]] = [] + for arg_idx, aval in enumerate(args_avals): + if all(not is_poly_dim(d) for d in aval.shape): + continue + polymorphic_shape_specs.append( + (pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None), + str(aval.shape))) + for dim_idx, aval_d in enumerate(aval.shape): + if is_poly_dim(aval_d): + synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree, + arg_idx, dim_idx) + synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx)) + dim_equations.append( + _DimEquation(aval_dim_expr=_ensure_poly(aval_d, "solve_dim_vars"), + dim_name=synth_dim_var)) + + solution, shape_constraints = _solve_dim_equations(dim_equations, + polymorphic_shape_specs) + return solution, shape_constraints, synth_dimension_vars + + +def compute_dim_vars_from_arg_shapes( + args_avals: Sequence[core.AbstractValue], + *actual_args: jax.Array, + args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: + """Computes values of dimension variables to unify args_avals with actual arguments. + + Like `solve_dim_vars` except that here we express the solution as + JAX arrays that reference the `actual_args`. This function can be used to + generate the code for computing the dimension variables. It also generates + the shape assertions. + + Returns: the values of the dimension variables, in the order determined by + `all_dim_vars(args_avals)`. + """ + dim_vars = all_dim_vars(args_avals) + solution, shape_constraints, synth_dim_vars = solve_dim_vars( + tuple(args_avals), args_kwargs_tree=args_kwargs_tree) + + # Replace the synthetic vars with the dynamic shape of the actual arg + synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], + dimension=dim_idx) + for (vname, arg_idx, dim_idx) in synth_dim_vars} + synthetic_eval = CachingShapeEvaluator(**synthetic_env) + shape_constraints.shape_assertions(synthetic_eval) + dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] + return tuple(dim_values) + +def _solve_dim_equations( + eqns: list[_DimEquation], + polymorphic_shape_specs: Sequence[tuple[str, str]] +) -> tuple[DimVarEnv, ShapeConstraints]: + # Returns a shape environment and the shape constraints if it can solve all + # dimension variables. Raises an exception if it cannot. + shapeenv: DimVarEnv = {} + solution_error_message_pieces: list[Union[str, _DimExpr]] = [ + " Obtained dimension variables: " + ] # Error message describing the solution + # Prepare error message piece describing the polymorphic shape specs + poly_specs_err_msg = ( + " Using the following polymorphic shapes specifications: " + + ",".join(f"{arg_name}.shape = {arg_spec}" + for arg_name, arg_spec in polymorphic_shape_specs)) + "." + solution_err_msg_trailer_errors = ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + + shape_constraints = ShapeConstraints() # accumulate shape constraints + + def process_one_eqn(eqn: _DimEquation) -> bool: + # We start with a DimEquation of the form `dim_expr = dim_value` + # Try to rewrite the equation as `var * factor_var = dim_value_2` (a linear + # uni-variate equation). Returns `False` if this rewrite fails. + # Otherwise, compute the `var` value as `dim_value_2 // factor`, add it to + # `shapeenv` and return `True`. 
+ # + # Invariant: + # var * factor_var + remaining_monomials_from_dim_expr = dim_value + var, factor_var = None, None + dim_value = _DimExpr.from_var(eqn.dim_name) + + for mon, factor in eqn.aval_dim_expr.monomials(): + # Perhaps we can already evaluate this monomial (all vars solved) + try: + mon_value = mon.evaluate(shapeenv) + except KeyError: + # `mon` still uses some variables not yet solved. We handle only the + # case when `mon` is a single variable. + v = mon.to_var() + if v is not None and var is None: + var, factor_var = v, factor + continue + else: + dim_value = dim_value + core.dim_constant(-1) * _evaluate_multiply(mon_value, core.dim_constant(factor)) + continue + return False # This equation cannot yet be used to solve a variable + + if var is not None: + if factor_var == 1: + var_value = dim_value + else: + var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore + shape_constraints.add_constraint( + ShapeConstraint.Comparator.EQ, var_remainder, 0, + error_message_pieces=([ + "Input shapes do not match the polymorphic shapes specification. " + "Division had remainder ", var_remainder, + f" when computing the value of '{var}'." + poly_specs_err_msg + ] + solution_error_message_pieces + [ + solution_err_msg_trailer_errors])) + + if not isinstance(var_value, _DimExpr): + assert var_value.dtype == core.dim_value_dtype() + shapeenv[var] = var_value # type: ignore + solution_error_message_pieces.extend([ + f"'{var}' = ", var_value, + f" from specification '{eqn.aval_dim_expr}' " + f"for dimension {eqn.dim_name} (= ", _DimExpr.from_var(eqn.dim_name), + "), "]) + + shape_constraints.add_constraint( + ShapeConstraint.Comparator.GEQ, var_value, 1, + error_message_pieces=[ + "Input shapes do not match the polymorphic shapes specification. " + f"Expected value >= 1 for dimension variable '{var}'." + + poly_specs_err_msg + ] + solution_error_message_pieces + [ + solution_err_msg_trailer_errors]) + + return True + else: + # All variables are resolved for this equation, we emit an assertion + shape_constraints.add_constraint( + ShapeConstraint.Comparator.EQ, + _DimExpr.from_var(eqn.dim_name), + eqn.aval_dim_expr.evaluate(shapeenv), + error_message_pieces=([ + "Input shapes do not match the polymorphic shapes specification. " + f"Found inconsistency between dimension size {eqn.dim_name} (= ", + _DimExpr.from_var(eqn.dim_name), + f") and the specification '{eqn.aval_dim_expr}' (= ", + eqn.aval_dim_expr.evaluate(shapeenv), + ")." + poly_specs_err_msg] + solution_error_message_pieces + + [solution_err_msg_trailer_errors]) + ) + return True + + while True: + nr_eqns = len(eqns) + eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)] + if not eqns: + return shapeenv, shape_constraints # SUCCESS + elif len(eqns) >= nr_eqns: + break + + # We have some equations that we cannot solve further + unsolved_vars: set[str] = set() + unsolved_polys: list[_DimExpr] = [] + for eqn in eqns: + unsolved_vars = unsolved_vars.union(eqn.aval_dim_expr.get_vars()) + unsolved_polys.append(eqn.aval_dim_expr) + unsolved_vars = unsolved_vars.difference(shapeenv.keys()) + err_msg = ( + f"Cannot solve for values of dimension variables {unsolved_vars}. " + "We can only solve linear uni-variate constraints." + poly_specs_err_msg + + " Unprocessed specifications: " + + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" + for eqn in eqns) + + ". 
Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ) + raise ValueError(err_msg) diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index 5120d6b46..6fb7b8eb3 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -35,34 +35,21 @@ py_library( deps = [":jax2tf_internal"], ) -py_library( - name = "jax_export", - srcs = [ - "jax_export.py", - "shape_poly.py", - ], - srcs_version = "PY3", - # TODO: b/255503696: enable pytype - tags = ["pytype_unchecked_annotations"], - visibility = ["//visibility:public"], - deps = [ - "//jax", - ] + py_deps("numpy"), -) - py_library( name = "jax2tf_internal", srcs = [ "call_tf.py", "impl_no_xla.py", "jax2tf.py", + "jax_export.py", # TODO(necula): remove stub + "shape_poly.py", # TODO(necula): remove stub ], srcs_version = "PY3", # TODO: b/255503696: enable pytype tags = ["pytype_unchecked_annotations"], visibility = jax_visibility("jax2tf_internal"), deps = [ - ":jax_export", "//jax", + "//jax/experimental/export", ] + py_deps("numpy") + py_deps("tensorflow_core") + jax2tf_deps, ) diff --git a/jax/experimental/jax2tf/__init__.py b/jax/experimental/jax2tf/__init__.py index c087bc4e5..c82715ac3 100644 --- a/jax/experimental/jax2tf/__init__.py +++ b/jax/experimental/jax2tf/__init__.py @@ -21,3 +21,7 @@ from jax.experimental.jax2tf.jax2tf import ( PolyShape as PolyShape ) from jax.experimental.jax2tf.call_tf import call_tf as call_tf +# TODO(necula): remove stub. Needed by SAX +from jax.experimental.jax2tf import jax_export +# Needed by maths.qec. +from jax.experimental.jax2tf import shape_poly diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 0b5427605..234b0b069 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -36,9 +36,9 @@ from jax import numpy as jnp from jax import tree_util from jax import sharding from jax.experimental import maps -from jax.experimental.jax2tf import shape_poly +from jax.experimental.export import shape_poly +from jax.experimental.export import export from jax.experimental.jax2tf import impl_no_xla -from jax.experimental.jax2tf import jax_export from jax.interpreters import xla from jax._src import ad_checkpoint @@ -86,7 +86,7 @@ NameStack = source_info_util.NameStack PolyShape = shape_poly.PolyShape DType = Any -DisabledSafetyCheck = jax_export.DisabledSafetyCheck +DisabledSafetyCheck = export.DisabledSafetyCheck # A temporary internal flag, to enable the wrapping of jax.jit functions # with tf.function(jit_compile=True). See #7389. This change has triggered a @@ -370,14 +370,14 @@ def convert(fun_jax: Callable, _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a) return tf_arg_shape, a_jax_dtype - args_specs = jax_export.poly_specs(args_tf, - polymorphic_shapes=polymorphic_shapes, - get_shape_and_dtype=shape_and_dtype_tf) + args_specs = export.poly_specs(args_tf, + polymorphic_shapes=polymorphic_shapes, + get_shape_and_dtype=shape_and_dtype_tf) # The polymorphic_shapes argument refers to positional arguments only. # We assume None for the kwargs. 
- kwargs_specs = jax_export.poly_specs(kwargs_tf, - polymorphic_shapes=None, - get_shape_and_dtype=shape_and_dtype_tf) + kwargs_specs = export.poly_specs(kwargs_tf, + polymorphic_shapes=None, + get_shape_and_dtype=shape_and_dtype_tf) combined_args_tf = (args_tf, kwargs_tf) args_flat_tf: Sequence[TfVal] args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf) @@ -503,7 +503,7 @@ class NativeSerializationImpl(SerializationImpl): _thread_local_state.call_tf_concrete_function_list = _prev_func_list self._restore_context = _restore_context - self.exported = jax_export.export( + self.exported = export.export( self.fun_jax, lowering_platform=self.lowering_platform, disabled_checks=self.native_serialization_disabled_checks @@ -669,7 +669,7 @@ def eval_polymorphic_shape(fun_jax: Callable, (c, a) """ def do_eval_polymorphic_shape(*args_specs) -> Any: - args_poly_specs = jax_export.poly_specs( + args_poly_specs = export.poly_specs( args_specs, polymorphic_shapes=polymorphic_shapes) res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs) # TODO(necula): For now we export the polymorphic shapes using `str`. @@ -803,7 +803,7 @@ def _interpret_fun_jax( def _run_exported_as_tf(args_flat_tf: Sequence[TfVal], - exported: jax_export.Exported, + exported: export.Exported, ) -> Sequence[TfVal]: """Runs the `exported` as an XlaCallModule TF op. diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 1755cf960..aaccf196d 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -11,1037 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""JAX APIs for exporting JAX functions for interoperation. +"""This is just a backwards-compatibility stub. -This module is used with jax2tf, but has no TensorFlow dependencies. +The main functionality has been moved to jax.experimental.export.export """ -from collections.abc import Sequence -import copy -import dataclasses -import functools -import itertools -import re -from typing import Any, Callable, Optional, Union - -from absl import logging - -import numpy as np - -import jax -from jax import config -from jax import sharding - -from jax._src import core -from jax._src import dispatch -from jax._src import pjit -from jax._src import sharding_impls -from jax._src import source_info_util -from jax._src.interpreters import mlir -from jax._src.interpreters import pxla -from jax._src.lib import xla_client -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo -from jax._src.lib.mlir.dialects import func as func_dialect -from jax._src import tree_util -from jax._src import util -from jax._src import xla_bridge as xb - -from jax.experimental.jax2tf import shape_poly - -map = util.safe_map -zip = util.safe_zip - -DType = Any - -class DisabledSafetyCheck: - """A safety check should be skipped on (de)serialization. - - Most of these checks are performed on serialization, but some are deferred to - deserialization. The list of disabled checks is attached to the serialization, - e.g., as a sequence of string attributes to `jax_export.Exported` or of - `tf.XlaCallModuleOp`. - - You can disable more deserialization safety checks by passing - `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. 
- """ - _impl: str - - @classmethod - def platform(cls) -> "DisabledSafetyCheck": - """Allows the execution platform to differ from the serialization platform. - - Has effect only on deserialization. - """ - return DisabledSafetyCheck("platform") - - @classmethod - def custom_call(cls, target_name: str) -> "DisabledSafetyCheck": - """Allows the serialization of a call target not known to be stable. - - Has effect only on serialization. - Args: - target_name: the name of the custom call target to allow. - """ - return DisabledSafetyCheck(f"custom_call:{target_name}") - - @classmethod - def shape_assertions(cls) -> "DisabledSafetyCheck": - """Allows invocations with shapes that do not meet the constraints. - - Has effect on serialization (to suppress the generation of the assertions) - and also on deserialization (to suppress the checking of the assertions). - """ - return DisabledSafetyCheck("shape_assertions") - - def is_custom_call(self) -> Optional[str]: - """Returns the custom call target allowed by this directive.""" - m = re.match(r'custom_call:(.+)$', self._impl) - return m.group(1) if m else None - - def __init__(self, _impl:str): - # Do not use directly, use builders `platform`, `custom_call`. - self._impl = _impl - - def __str__(self): - return self._impl - __repr__ = __str__ - - def __eq__(self, other) -> bool: - return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl - - def __hash__(self) -> int: - return hash(self._impl) - - -minimum_supported_serialization_version = 6 -maximum_supported_serialization_version = 8 - -@dataclasses.dataclass(frozen=True) -class Exported: - """A JAX function lowered to StableHLO. - - Attributes: - fun_name: the name of the exported function, for error messages. - in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX - function. The actual lowering does not depend on the `in_tree`, but this - can be used to invoke the exported function using the same argument - structure. - in_avals: the flat tuple of input abstract values. May contain dimension - expressions in the shapes. - out_tree: a PyTreeDef describing the result of the lowered JAX function. - out_avals: the flat tuple of output abstract values. May contain dimension - expressions in the shapes, with dimension variables among those in - `in_avals. - in_shardings: the flattened input shardings. Only for the inputs that are - specified in `module_kept_var_idx`. If `None` then it is equivalent - to unspecified shardings. - out_shardings: the flattened output shardings, as long as `in_avals`. - lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', - 'cuda', 'rocm'. See below for the calling convention for when - there are multiple lowering platforms. - mlir_module_serialized: the serialized lowered VHLO module. - serialization_version: a version number for the serialized module. - See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. - module_kept_var_idx: the sorted indices of the arguments among `in_avals` that - must be passed to the module. The other arguments have been dropped - because they are not used. Same length as `in_shardings`. - uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape - polymorphism. This may be because `in_avals` contains dimension - variables, but also from inner calls of shape-polymorphic - Exported modules. - disabled_checks: a list of descriptors of safety checks that have been - disabled at export time. 
See docstring for `DisabledSafetyCheck`. - _get_vjp: an optional function that takes the current exported function and - returns the exported VJP function. - The VJP function takes a flat list of arguments, - starting with the primal arguments and followed by a cotangent argument - for each primal output. It returns a tuple with the cotangents - corresponding to the flattened primal inputs. - - Calling convention for the exported module: - - The `mlir_module` has a `main` function that takes an optional first - platform index argument if the module supports multiple platforms - (`len(lowering_platforms) > 1`), followed by the kept array arguments - (corresponding to `module_kept_var_idx` and `in_avals`). - The platform index is a i32 scalar encoding the index of the current - compilation platform into the `lowering_platforms` sequence. - - Inner functions use a different calling convention: an optional - platform index argument, optional dimension variable arguments specified - using scalar tensors of type i32 or i64, - followed by optional token arguments (in presence of side effects), - followed by the regular array arguments. - The dimension arguments correspond to the dimension variables appearing in - the `args_avals`, in sorted order of their names. - - Consider the lowering of a function with one array argument of type "f32[w, - 2 * h]", where "w" and "h" are two dimension variables. - Assume that we use multi-platform lowering, and we have - ordered effects. The `main` function will be as follows: - - func public main(platform_index: i32, arg: f32[?, ?]) { - arg_w = hlo.get_dimension_size(arg, 0) - dim1 = hlo.get_dimension_size(arg, 1) - arg_h = hlo.floordiv(dim1, 2) - call _check_shape_assertions(arg) # See below - token = new_token() - token_out, res = call _wrapped_jax_export_main(platform_index, arg_h, arg_w, token_in, arg) - return res - } - - The actual computation is in `_wrapped_jax_export_main`, taking also - the values of `h` and `w` and the token. Proper exporting of - functions with side-effects and tokens is still work-in-progress. - - Note that `main` contains a call to `_check_shape_assertions. - JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` - have values >= 1. We must check these constraints when we invoke the - module. We use a special custom call `@shape_assertion` that takes - a boolean first operand, a string `error_message` attribute that may contain - format specifiers `{0}`, `{1}`, ..., and a variadic number of integer - scalar operands corresponding to the format specifiers. - - func private _check_shape_assertions(arg: f32[?, ?]) { - # Check that w is >= 1 - arg_w = hlo.get_dimension_size(arg, 0) - custom_call @shape_assertion(arg_w >= 1, arg_w, - error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") - # Check that dim1 is even - dim1 = hlo.get_dimension_size(arg, 1) - custom_call @shape_assertion(dim1 % 2 == 0, dim1, - error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}") - # Check that h >= 1 - arg_h = hlo.floordiv(dim1, 2) - custom_call @shape_assertion(arg_h >= 1, arg_h, - error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") - - If we `call_exported` with this module we perform these checks - statically (in `call_exported_abstract_eval`). - """ - fun_name: str - in_tree: tree_util.PyTreeDef - in_avals: tuple[core.AbstractValue, ...] - out_tree: tree_util.PyTreeDef - out_avals: tuple[core.AbstractValue, ...] 
- - in_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]] - out_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]] - lowering_platform: str # For backwards compatibility - lowering_platforms: tuple[str, ...] - disabled_checks: Sequence[DisabledSafetyCheck] - - mlir_module_serialized: bytes - serialization_version: int - module_kept_var_idx: tuple[int, ...] - uses_shape_polymorphism: bool - - _get_vjp: Optional[Callable[["Exported"], "Exported"]] - - def mlir_module(self) -> ir.Module: - return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) - - def __str__(self): - # This is called to make a MLIR source location when we call an Exported, and we - # do not want the entire serialized module to end up in locations. - return f"Exported(fun_name={self.fun_name}, ...)" - - def vjp(self) -> "Exported": - """Gets the exported VJP. - - Returns None if not available, which can happen if the Exported has been - loaded from an external format, without a VJP.""" - if self._get_vjp is None: - raise ValueError("No VJP is available") - return self._get_vjp(self) - - -def default_lowering_platform() -> str: - # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' - return xb.canonicalize_platform(jax.default_backend()) - -def poly_spec( - arg_shape: Sequence[Optional[int]], - arg_dtype: DType, - polymorphic_shape: Optional[str]) -> jax.ShapeDtypeStruct: - """Constructs a jax.ShapeDtypeStruct with polymorphic shapes. - - Args: - arg_shape: the shape, with possibly some unspecified dimensions. - arg_dtype: the jax dtype. - polymorphic_shape: a string specifying the polymorphic shape. - - .. warning:: The shape-polymorphic lowering is an experimental feature. - It is meant to be sound, but it is known to reject some JAX programs - that are shape polymorphic. The details of this feature can change. - - It should be either `None` (all dimensions are constant), or a string of - specification for one axis, and can be either a constant, `_` denoting - a constant dimension given by the `arg_shape`, or the name of a - dimension variable assumed to range over dimension greater than 0. For - convenience, zero or more trailing `_` can be abbreviated with `...`, and - the surrounding parentheses may be missing. - - Note that this function does not ensure that the provided `arg_shape` - is compatible with `polymorphic_shape`. The `arg_shape` is used only - to fill-in placeholders from `polymorphic_shape`. - - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) - for more details. - - Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic - expressions involving dimension variables. - """ - aval_shape = shape_poly._parse_spec(polymorphic_shape, arg_shape) - return jax.ShapeDtypeStruct(aval_shape, arg_dtype) - -def shape_and_dtype_jax_array(a) -> tuple[Sequence[Optional[int]], DType]: - """Returns the shape and dtype of a jax.Array.""" - aval = core.raise_to_shaped(core.get_aval(a)) - return aval.shape, aval.dtype - -def poly_specs( - args, # pytree of arguments - polymorphic_shapes, # prefix pytree of strings - get_shape_and_dtype=shape_and_dtype_jax_array, -): - """Constructs a pytree of jax.ShapeDtypeSpec. - - Args: - args: a pytree of arguments - polymorphic_shapes: should be `None` (all arguments are monomorphic), - a single string (applies to all arguments), or a pytree matching a prefix - of the `args`. 
- See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). - - Note that this function does not ensure that the provided `args` shapes - are compatible with `polymorphic_shapes`. The `args.shape` are used only - to fill-in placeholders from `polymorphic_shapes`. - - See docstring of `poly_spec` and - [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) - for more details. - - Returns: a pytree of jax.ShapeDTypeStruct matching `args`. - """ - args_flat, args_tree = tree_util.tree_flatten(args) - - shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat)) - shapes, dtypes = util.unzip2(shapes_and_dtypes) - - if isinstance(args, tuple) and isinstance(polymorphic_shapes, list): - # TODO: Remove backward-compatibility workaround - polymorphic_shapes_ = tuple(polymorphic_shapes) - else: - polymorphic_shapes_ = polymorphic_shapes - - try: - polymorphic_shapes_flat = tree_util.broadcast_prefix( - polymorphic_shapes_, args, - is_leaf=lambda x: x is None) - except ValueError: - e, *_ = tree_util.prefix_errors( - polymorphic_shapes_, args, - is_leaf=lambda x: x is None) - raise e("jax_export polymorphic_shapes") from None - - # Now add in the polymorphic shapes - args_specs_flat = tuple( - map(poly_spec, shapes, dtypes, polymorphic_shapes_flat)) - - return args_tree.unflatten(args_specs_flat) - - -def export(fun_jax: Callable, - *, - lowering_platform: Optional[str] = None, - lowering_platforms: Optional[Sequence[str]] = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - ) -> Callable[..., Exported]: - """Exports native serialization for a JAX function. - - Args: - fun_jax: the function to lower and serialize. - lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use - the default JAX backend. - lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL). - Optional sequence containing a subset of 'tpu', 'cpu', - 'cuda', 'rocm'. If more than one platform is specified, then - the lowered code takes an argument specifying the platform. - If None, then use the default JAX backend. - The calling convention for multiple platforms is explained in the - `jax_export.Exported` docstring. - disabled_checks: the safety checks to disable. See docstring - of `DisabledSafetyCheck`. - - Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. - - Usage: - - def f_jax(*args, **kwargs): ... - exported = jax_export.export(f_jax)(*args, **kwargs) - """ - fun_name = getattr(fun_jax, "__name__", "unknown") - version = config.jax_serialization_version - if (version < minimum_supported_serialization_version or - version > maximum_supported_serialization_version): - raise ValueError( - f"The requested jax_serialization version {version} is outside the " - f"range of supported versions [{minimum_supported_serialization_version}" - f"..{maximum_supported_serialization_version}]") - - def do_export(*args_specs, **kwargs_specs) -> Exported: - if not hasattr(fun_jax, "lower"): - # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also - # convert(f_jax), in which case a "jit" is implied. In that case we raise - # an error if the lowered function contains non-replicated sharding annotations. 
- wrapped_fun_jax = jax.jit(fun_jax) - allow_non_replicated_sharding = False - else: - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. - wrapped_fun_jax = fun_jax # type: ignore - allow_non_replicated_sharding = True - - nonlocal lowering_platforms - if lowering_platforms is not None: - lowering_platforms = tuple(lowering_platforms) - else: - lowering_platforms = (lowering_platform or default_lowering_platform(),) - - # Do not include shape assertions if the version is < 7. - enable_shape_assertions = ( - DisabledSafetyCheck.shape_assertions() not in disabled_checks and - version >= 7) # type: ignore - try: - prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions - shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions - lowered = wrapped_fun_jax.lower( - *args_specs, **kwargs_specs, - _experimental_lowering_platform=lowering_platforms) - - lowering = lowered._lowering # type: ignore - _check_lowering(lowering) - mlir_module = lowering.stablehlo() - - args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) - if "kept_var_idx" in lowering.compile_args: - module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals_flat))) - shape_poly_state = lowering.compile_args["shape_poly_state"] - if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) - or lowering.compile_args.get("ordered_effects", [])): - # All arguments are kept if we have dimension variables. - assert len(module_kept_var_idx) == len(args_avals_flat) - mlir_module = _wrap_main_func( - mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, - has_platform_index_argument=shape_poly_state.has_platform_index_argument - ) - finally: - shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions - - with mlir_module.context: - mlir_module_attrs = mlir_module.operation.attributes - mlir_module_attrs["jax.uses_shape_polymorphism"] = ( - mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) - - mlir_module_serialized = _serialize_module(mlir_module) - - # Figure out the result types and shapes - if "global_out_avals" in lowering.compile_args: - # This is currently the case for pjit - out_avals_flat = lowering.compile_args["global_out_avals"] - elif "shards" in lowering.compile_args: # for PmapComputation - out_avals_flat = lowering.compile_args["shards"].out_sharded_avals - else: - out_avals_flat = lowered.compile_args["out_avals"] - - # Log and then check the module. 
- if logging.vlog_is_on(3): - mlir_module_text = mlir.module_to_string(mlir_module) - logmsg = (f"version={version} " - f"lowering_platforms={lowering_platforms} " - f"disabled_checks={disabled_checks}") - logging.info("Lowered JAX module: %s\n", logmsg) - for l in mlir_module_text.splitlines(): - logging.info(l) - - _check_module(mlir_module, - allow_non_replicated_sharding=allow_non_replicated_sharding, - disabled_checks=disabled_checks) - - return Exported( - fun_name=fun_name, - in_tree=lowered.in_tree, - out_tree=lowered.out_tree, - in_avals=tuple(args_avals_flat), - out_avals=tuple(out_avals_flat), - in_shardings=lowering.compile_args["in_shardings"], - out_shardings=lowering.compile_args["out_shardings"], - lowering_platform=lowering_platforms[0], # TODO: remove - lowering_platforms=lowering_platforms, - disabled_checks=tuple(disabled_checks), - mlir_module_serialized=mlir_module_serialized, - module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=shape_poly_state.uses_dim_vars, - serialization_version=version, # type: ignore - _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported)) - - return do_export - - -def _serialize_module(module: ir.Module) -> bytes: - mlir_str = mlir.module_to_bytecode(module) - if hlo.get_api_version() < 4: - target_version = hlo.get_earliest_forward_compatible_version() - else: - # `target_version` is used to manage situations when a StableHLO producer - # (in this case, jax2tf) and a StableHLO consumer were built using - # different versions of StableHLO. - # - # Each StableHLO version `producer_version` has a compatibility window, - # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], - # where StableHLO portable artifacts serialized by `producer_version` - # can be deserialized by `consumer_version` within the window. - # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md - # for the exact extent of these compatibility guarantees. - # - # `hlo.get_minimum_version()` returns `consumer_version_min` - # for the current version of StableHLO. We are using it here to maximize - # forward compatibility, i.e. to maximize how far into the past we can go - # and still have the payloads produced by `serialize_portable_artifact` - # compatible with potential consumers from the past. - target_version = hlo.get_minimum_version() - module_serialized = xla_client._xla.mlir.serialize_portable_artifact( - mlir_str, target_version) - return module_serialized - - -def _wrap_main_func( - module: ir.Module, - args_avals_flat: Sequence[core.ShapedArray], - *, - args_kwargs_tree: tree_util.PyTreeDef, - has_platform_index_argument: bool, -) -> ir.Module: - """Wraps the lowered module with a new "main" handling dimension arguments. - - See calling convention documentation for `jax_export.Exported`. - - Args: - module: the HLO module as obtained from lowering. See the calling convention - for inner functions in `jax_export.Exported`. - args_avals_flat: the avals for all the arguments of the lowered function, - which correspond to the array arguments of the `module`. - args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error - messages. - - Returns the wrapped module, without dimension and token arguments. 
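  For illustration only (a sketch, not the exact text produced by the
  lowering): assuming a single dimension variable "b", no platform index
  argument, and no tokens, an inner function

     func private _wrapped_jax_export_main(b: i32, arg: f32[?, 4]) -> f32[?]

  is wrapped by a new public main

     func public main(arg: f32[?, 4]) -> f32[?] {
       b = hlo.get_dimension_size(arg, 0)
       return call _wrapped_jax_export_main(b, arg)
     }

  so that callers of the module pass only the array arguments.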
- """ - dim_vars = shape_poly.all_dim_vars(args_avals_flat) - context = mlir.make_ir_context() - with context, ir.Location.unknown(context): - # Make a copy, do not mutate because it may be cached - wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) - symbol_table = ir.SymbolTable(wrapped_module.operation) - orig_main = symbol_table["main"] - orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") - symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main") - orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value - - def is_token(attrs): - try: - return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value - except KeyError: - return False - - orig_input_types = orig_main.type.inputs - arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) - # The order of args: platform_index_arg, dim args, token args, array args. - nr_platform_index_args = 1 if has_platform_index_argument else 0 - nr_dim_args = len(dim_vars) - nr_token_args = sum(1 for attrs in arg_attrs if is_token(attrs)) - nr_array_args = len(orig_input_types) - nr_platform_index_args - nr_dim_args - nr_token_args - assert nr_array_args >= 0 - assert not any(is_token(attrs) for attrs in arg_attrs[-nr_array_args:]) - (platform_input_types, dim_var_input_types, - token_input_types, array_input_types) = util.split_list( - orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) - new_main_input_types = platform_input_types + array_input_types - orig_output_types = orig_main.type.results - result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) - nr_token_results = sum(1 for attrs in result_attrs if is_token(attrs)) - nr_array_results = len(orig_output_types) - nr_token_results - assert nr_array_results >= 0 - assert not any( - is_token(attrs) for attrs in result_attrs[-nr_array_results:]) - new_main_output_types = orig_output_types[-nr_array_results:] - new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types) - new_main_op = func_dialect.FuncOp( - "main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body)) - new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public") - try: - new_main_op.arg_attrs = ir.ArrayAttr.get(arg_attrs[0:nr_platform_index_args] + arg_attrs[-nr_array_args:]) - except KeyError: - pass # TODO: better detection if orig_main.arg_attrs does not exist - try: - new_main_op.result_attrs = ir.ArrayAttr.get( - result_attrs[-nr_array_results:]) - except KeyError: - pass - symbol_table.insert(new_main_op) - entry_block = new_main_op.add_entry_block() - with ir.InsertionPoint(entry_block): - module_context = mlir.ModuleContext( - "cpu", "cpu", sharding_impls.ShardingContext([]), - source_info_util.new_name_stack(), - [], itertools.count(1), [], module=wrapped_module, context=context) - ctx = mlir.LoweringRuleContext( - module_context=module_context, primitive=None, - avals_in=args_avals_flat, avals_out=None, - tokens_in=mlir.TokenSet(), tokens_out=None) - new_main_op_array_args = new_main_op.arguments[nr_platform_index_args:] - dim_values = mlir.lower_fun( - functools.partial(shape_poly.compute_dim_vars_from_arg_shapes, - args_avals_flat, args_kwargs_tree=args_kwargs_tree), - multiple_results=True)(ctx, *new_main_op_array_args) - # The arguments to pass to the call to orig_main - orig_main_args: list[ir.Value] = [] - # The platform index and the dimension variables - for arg, arg_type in zip( - list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values), - platform_input_types + 
dim_var_input_types): - if arg.type != arg_type: - orig_main_args.append(hlo.ConvertOp(arg_type, arg).result) - else: - orig_main_args.append(arg) - # Then the token arguments - orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args) - # Then the array arguments. We insert a ConvertOp as the only use of - # an input argument. This helps the downstream shape refinement because - # it will set the type of input arguments to static shapes, and this - # can invalidate the module if the argument is used as the result of a - # function, or if it appears as the input to a custom_call with - # output_operand_alias attribute. See b/287386268. - for a in new_main_op_array_args: - orig_main_args.append(hlo.ConvertOp(a.type, a).result) - call = func_dialect.CallOp(orig_output_types, - ir.FlatSymbolRefAttr.get(orig_main_name), - orig_main_args) - func_dialect.ReturnOp(call.results[-nr_array_results:]) - symbol_table.set_symbol_name(new_main_op, "main") - - return wrapped_module - -def _check_lowering(lowering) -> None: - if not isinstance(lowering, pxla.MeshComputation): - raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") - - if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: - raise NotImplementedError("serialization of host_callbacks is not yet implemented") - # Check that we do not see new compile_args. When we add a compile_args it is - # safe to add it to the allowed_compile_args if it does not change the semantics - # or the calling convention of the lowered module. - allowed_compile_args = [ - "backend", "mesh", "global_in_avals", - "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", - "spmd_lowering", "auto_spmd_lowering", - "tuple_args", "ordered_effects", "unordered_effects", - "keepalive", "host_callbacks", "pmap_nreps", "committed", - "device_assignment", "jaxpr_debug_info", "shape_poly_state"] - for compile_arg in lowering.compile_args.keys(): - if compile_arg not in allowed_compile_args: - raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") - - # We have not implemented support for some of the compile_args. Check here that - # the compile_args have the values that have been implemented. - not_implemented_msgs = [] - for compile_arg, check_value, err_msg in ( - ("spmd_lowering", lambda v: v, "True"), - ("auto_spmd_lowering", lambda v: not v, "False"), - # tuple_args is a compilation flag, does not affect lowering. - ("tuple_args", lambda v: True, "N/A"), - # unordered_effects do not change the calling convention. Those from - # jax.debug will also result in keepalive being non-empty and unsupported - # custom calls. The CallTfEffect is an exception, but we want to allow - # that one. - ("unordered_effects", lambda v: True, "N/A"), - # ordered_effects are allowed and we ensure that the calling convention is - # unmodified by passing dummy tokens in the main function wrapper. - ("ordered_effects", lambda v: True, "N/A"), - # used for TPU jax.debug, send/recv. Not supported yet. - ("host_callbacks", lambda v: not v, "empty"), - # used on all platforms for callbacks. Not supported yet. 
- ("keepalive", lambda v: not v, "empty"), - ("pmap_nreps", lambda v: v == 1, "1"), - ("shape_poly_state", lambda v: True, "N/A"), - ): - if compile_arg in lowering.compile_args: - if not check_value(lowering.compile_args[compile_arg]): - not_implemented_msgs.append( - f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}") - if not_implemented_msgs: - raise NotImplementedError( - "serialization error, unimplemented lowered.compile_args:\n" + - "\n".join(not_implemented_msgs)) - -# These are the JAX custom call target names that are guaranteed to be stable. -# Their backwards compatibility is tested by back_compat_test.py. -_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { - "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", - # cholesky on CPU - "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", - # eigh on CPU - "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", - # eigh on GPU - "cusolver_syevj", "cusolver_syevd", - # eigh on TPU - "Eigh", - # eig on CPU - "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # qr on CPU - "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", - # householder product on CPU - "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", - # svd on CPU - "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", - # qr on GPU - "cusolver_geqrf", "cublas_geqrf_batched", - "cusolver_geqrf", "cusolver_orgqr", - # qr and svd on TPU - "Qr", "ProductOfElementaryHouseholderReflectors", - # triangular_solve on CPU - "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", - # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU - # # lu on CPU - "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", - # schur on CPU - "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", - # # lu on GPU - # "cublas_getrf_batched", "cusolver_getrf", - # "hipblas_getrf_batched", "hipsolver_getrf", - # lu on TPU - "LuDecomposition", - # ApproxTopK on TPU - "ApproxTopK", - "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) - "tpu_custom_call", # Pallas/TPU kernels - # TODO(burmako): maintain backwards compatibility for these, until they - # are upstreamed to StableHLO. - # See https://github.com/openxla/stablehlo/issues/8. - "stablehlo.dynamic_reduce_window", - "stablehlo.dynamic_rng_bit_generator", - "stablehlo.dynamic_top_k", - "shape_assertion", # Used by shape_poly to evaluate assertions -} - - -def _check_module(mod: ir.Module, *, - allow_non_replicated_sharding: bool, - disabled_checks: Sequence[DisabledSafetyCheck]) -> None: - """Run a number of checks on the module. - - Args: - allow_non_replicated_sharding: whether the module is allowed to contain - non_replicated sharding annotations. - disabled_checks: the safety checks that are disabled. 
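  For example (with an illustrative target name): a module containing
  `stablehlo.custom_call @my_target`, where "my_target" is not among
  `_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`, is rejected unless
  `DisabledSafetyCheck.custom_call("my_target")` is present in
  `disabled_checks`.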
- """ - sharding_attr = ir.StringAttr.get("Sharding", mod.context) - shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context) - allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) - for dc in disabled_checks: - target = dc.is_custom_call() - if target is not None: - allowed_custom_call_targets.add(target) - - allowed_custom_call_targets_attrs = { - ir.StringAttr.get(target, mod.context) - for target in allowed_custom_call_targets} - disallowed_custom_call_ops: list[str] = [] - def check_sharding(op: ir.Operation, loc: ir.Location): - if not allow_non_replicated_sharding: - try: - sharding = op.attributes["mhlo.sharding"] - except KeyError: - pass - else: - if ir.StringAttr(sharding).value not in ["{replicated}", ""]: - raise ValueError( - "Lowered function does not have a top-level pjit but it has" - f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee" - " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning" - " for a discussion." - ) - - def check_op(op: ir.Operation): - op_name = op.operation.name - if op_name == "func.func": - check_sharding(op.operation, op.location) - - elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call": - call_target_name_attr = op.operation.attributes["call_target_name"] - if (call_target_name_attr not in allowed_custom_call_targets_attrs): - disallowed_custom_call_ops.append(f"{op} at {op.location}") - if call_target_name_attr == sharding_attr: - check_sharding(op, op.location) - elif call_target_name_attr == shape_assertion_attr: - assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks) - - def walk_operations(op): - check_op(op) - for region in op.operation.regions: - for block in region: - for op in block: - walk_operations(op) - - walk_operations(mod) - if disallowed_custom_call_ops: - disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) - msg = ("Cannot serialize code with custom calls whose targets have no " - "compatibility guarantees. Examples are:\n" - f"{disallowed_custom_call_ops_str}.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") - raise ValueError(msg) - - -def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: - # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp - - # Since jax.vjp does not handle kwargs, it is easier to do all the work - # here with flattened functions. - def fun_vjp_jax(*args_and_out_cts_flat_jax): - # Takes a flat list of primals and output cotangents - def flattened_primal_fun_jax(*args_flat): - args, kwargs = primal.in_tree.unflatten(args_flat) - res = primal_fun_jax(*args, **kwargs) - res_flat, res_tree = tree_util.tree_flatten(res) - assert res_tree == primal.out_tree - return res_flat - - args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, - [len(primal.in_avals)]) - _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) - return pullback_jax(out_cts_flat_jax) - - vjp_in_avals = list( - itertools.chain(primal.in_avals, - map(lambda a: a.at_least_vspace(), primal.out_avals))) - - # Expand in_shardings to all in_avals even not kept ones. 
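  # (Arguments that were dropped from the module, i.e., not in
  # `module_kept_var_idx`, have no recorded sharding; they stay UNSPECIFIED.)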
- all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals) - for idx, in_s in zip(sorted(primal.module_kept_var_idx), - primal.in_shardings): # type: ignore - all_in_shardings[idx] = in_s # type: ignore - all_shardings = all_in_shardings + list(primal.out_shardings) # type: ignore - # Cannot mix unspecified and specified shardings. Make the unspecified - # ones replicated. - specified_shardings = [ - s for s in all_shardings if not sharding_impls.is_unspecified(s)] - - vjp_in_shardings: Any # The primal inputs followed by output cotangents - vjp_out_shardings: Any # The primal output cotangents - if 0 == len(specified_shardings): - vjp_in_shardings = sharding_impls.UNSPECIFIED - vjp_out_shardings = sharding_impls.UNSPECIFIED - else: - if len(specified_shardings) < len(all_shardings): - # There are some specified, but not all; pjit front-end does not liwk - in_s = specified_shardings[0] # pjit will enforce that all have same devices - assert isinstance(in_s, sharding.XLACompatibleSharding) - replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) - all_shardings = [ - s if not sharding_impls.is_unspecified(s) else replicated_s - for s in all_shardings] - - vjp_in_shardings = tuple(all_shardings) - vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)]) - if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings): - vjp_out_shardings = sharding_impls.UNSPECIFIED - - fun_vjp_jax = pjit.pjit(fun_vjp_jax, - in_shardings=vjp_in_shardings, - out_shardings=vjp_out_shardings) - - return export(fun_vjp_jax, - lowering_platform=primal.lowering_platform, - disabled_checks=primal.disabled_checks)(*vjp_in_avals) - -### Importing - -def call_exported(exported: Exported) -> Callable[..., jax.Array]: - - @jax.custom_vjp - def f_flat(*args_flat): - return call_exported_p.bind(*args_flat, exported=exported) - - def f_flat_vjp_fwd(*args_flat): - # Return the primal arguments as the residual - # TODO: keep as residuals only the arguments that are needed - return f_flat(*args_flat), args_flat - - def f_flat_vjp_bwd(residual, ct_res_flat): - args_flat = residual # residual is the primal argument flat tuple - exp_vjp = exported.vjp() - in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_flat) - return in_ct_flat - - f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd) - - def f_imported(*args, **kwargs): - # since custom_vjp does not support kwargs, flatten the function first. - args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) - if in_tree != exported.in_tree: - # Give errors with the precise tree difference; use fake leaves so we can - # use tree_util.equality_errors. - in_args = in_tree.unflatten([0] * in_tree.num_leaves) - exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves) - - msg = ( - "The invocation args and kwargs must have the same pytree structure " - f"as when the function '{exported.fun_name}' was exported, but they " - "have the following structural differences:\n" + - ("\n".join( - f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " - f"{thing2} when exported, so {explanation}.\n" - for path, thing1, thing2, explanation - in tree_util.equality_errors(in_args, exp_in_args)))) - raise ValueError(msg) - - res_flat = f_flat(*args_flat) - return exported.out_tree.unflatten(res_flat) - return f_imported - - -# A JAX primitive for invoking a serialized JAX function. 
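# For example (an illustrative sketch, not an API commitment):
#
#   import numpy as np
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental.export import export as jax_export
#
#   exp = jax_export.export(jnp.sin)(np.ones((4,), dtype=np.float32))
#   sin_again = jax_export.call_exported(exp)
#   y = jax.jit(sin_again)(np.arange(4, dtype=np.float32))
#
# Tracing `sin_again` binds the `call_exported_p` primitive defined below,
# whose abstract evaluation and lowering rules follow.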
-call_exported_p = core.Primitive("call_exported") -call_exported_p.multiple_results = True - -@util.cache() -def _call_exported_abstract_eval(*in_avals: core.AbstractValue, - exported: Exported) -> tuple[core.AbstractValue, ...]: - exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals) - assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure - # Check that the expected shapes match the actual ones - for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): - def pp_arg_dim(dim_idx: Optional[int]) -> str: - return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, - arg_idx, dim_idx) - if len(exp_aval.shape) != len(actual_aval.shape): - raise ValueError( - f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} " - f"and called with {actual_aval.shape}") - if exp_aval.dtype != actual_aval.dtype: - raise ValueError( - f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} " - f"and called with {actual_aval.dtype}") - for dim_idx, aval_d in enumerate(exp_aval.shape): - # If the exp_aval has a constant dimension then the actual argument must have - # a matching constant dimension. - if core.is_constant_dim(aval_d): - if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or - aval_d != actual_aval.shape[dim_idx]): - raise ValueError( - f"Shape mismatch for {pp_arg_dim(dim_idx)} " - "(expected same constant): " - f"expected {exp_aval.shape} and called with {actual_aval.shape}") - - # Must express the exported_dim_vars in terms of the shapes in in_avals. - solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( - exported.in_avals, args_kwargs_tree=exported.in_tree) - synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) - # We discharge all the constraints statically. This results in much simpler - # composability (because we do not have to worry about the constraints of the - # Exported called recursively; we only need to worry about entry-point - # constraints). This also makes sense from a composibility point of view, - # because we get the same errors if we invoke the exported module, or if we - # trace the exported function. Consider for example, an exported module with - # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument - # `f32[c, d]` it is better to fail because `c == d` is inconclusive, than - # succeed and add a compile-time check that `c == d`. In the latter case, - # it would be ambiguous whether we should continue tracing with a result - # a type `f32[c]` or `f32[d]`. 
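  # For example (a sketch): if the Exported has signature `f32[2*h] -> f32[h]`
  # and is invoked with an argument of shape `f32[8]`, the solution gives
  # h = 8 // 2 = 4, the constraint `mod(8, 2) == 0` is checked statically
  # here, and the result aval is refined to `f32[4]`.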
- shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) - for var in exported_dim_vars] - return tuple( - core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, - *exported_dim_values), - dtype=out_aval.dtype, weak_type=out_aval.weak_type, - named_shape=out_aval.named_shape) - for out_aval in exported.out_avals) - - -call_exported_p.def_abstract_eval(_call_exported_abstract_eval) - -def _call_exported_impl(*args, exported: Exported): - return dispatch.apply_primitive(call_exported_p, *args, exported=exported) - -call_exported_p.def_impl(_call_exported_impl) - -def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, - platform: str, - exported: Exported): - # TODO: implement true multi-platform lowering for call_exported - if (platform not in exported.lowering_platforms and - DisabledSafetyCheck.platform() not in exported.disabled_checks): - raise ValueError( - f"The exported function '{exported.fun_name}' was lowered for " - f"platforms '{exported.lowering_platforms}' but it is used " - f"on '{platform}'.") - - if exported.uses_shape_polymorphism: - ctx.module_context.shape_poly_state.uses_dim_vars = True - - submodule = ir.Module.parse(exported.mlir_module()) - symtab = ir.SymbolTable(submodule.operation) - # The called function may have been exported with polymorphic shapes and called - # now with more refined shapes. We insert hlo.ConvertOp to ensure the module - # is valid. - def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value: - new_ir_type = mlir.aval_to_ir_type(new_aval) - if x.type != new_ir_type: - return mlir.convert_hlo(ctx, x, x_aval, new_aval) - else: - return x - - callee_type = symtab["main"].type - # TODO: maybe cache multiple calls - fn = mlir.merge_mlir_modules(ctx.module_context.module, - f"call_exported_{exported.fun_name}", - submodule) - kept_args = [ - convert_shape(a, a_aval, exported_in_aval) - for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals)) - if i in exported.module_kept_var_idx] - if len(exported.lowering_platforms) > 1: - # The exported module takes a platform index argument - # TODO: implement proper handling of the platform_index when we are - # in a multi-platform lowering context. - platform_index = exported.lowering_platforms.index(platform) - arg_width = callee_type.inputs[0].element_type.width - assert arg_width in [32, 64] - platform_index = np.int32(platform_index) if arg_width == 32 else np.int64(platform_index) # type: ignore - kept_args = [mlir.ir_constant(platform_index)] + kept_args - call = func_dialect.CallOp(callee_type.results, - ir.FlatSymbolRefAttr.get(fn), - kept_args) - # The ctx.avals_out already contain the abstract values refined by - # _call_exported_abstract_eval. 
- return tuple( - convert_shape(out, out_aval, refined_out_aval) - for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out)) - - -for _p in ("cpu", "tpu", "cuda", "rocm"): - mlir.register_lowering(call_exported_p, - functools.partial(_call_exported_lowering, platform=_p), - platform=_p) +# TODO(necula): Remove these stubs +from jax.experimental.export.export import ( + default_lowering_platform, +) diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index a66497b74..90f35a744 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -11,1583 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Shape polymorphism support. +"""This is just a backwards-compatibility stub. -We introduce a set of dimension variables at the top-level of a `jit` function. -They are introduced implicitly by way of specifying for each dimension of each -argument a symbolic dimension expression in terms of some dimension variables. -All dimension variables are assumed to range over integers greater or equal to 1. - -Symbolic dimensions overload some integer operations, such as -add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been -touched up to be sensitive to handling shapes that contain symbolic dimensions. -This enables many JAX programs to be traced with symbolic dimensions -in some dimensions. A priority has been to enable the batch -dimension in neural network examples to be polymorphic. - -This was built initially for jax2tf, but it is now customizeable to be -independent of TF. The best documentation at the moment is in the -jax2tf.convert docstring, and the -[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +The main functionality has been moved to jax.experimental.export.shape_poly """ -import collections -from collections.abc import Iterable, Sequence -import dataclasses -from enum import Enum -import functools -import itertools -import io -import math -import operator as op -import threading -import tokenize -from typing import Any, Optional, Union - -import numpy as np -import opt_einsum - -import jax -from jax import config -from jax.interpreters import xla - -from jax._src import core -from jax._src import dtypes -from jax._src import effects -from jax._src.lax import lax -from jax._src.lib import version as jaxlib_version -from jax._src.interpreters import mlir -from jax._src.numpy import lax_numpy -from jax._src import tree_util -from jax._src import util -from jax._src.typing import DimSize, Shape - - -TfVal = Any -DimVarEnv = dict[str, jax.Array] -DType = Any - -class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): - """Raised when we cannot conclusively compute with symbolic dimensions.""" - - _help_msg = """ -This error arises for comparison operations with shapes that -are non-constant, and the result of the operation cannot be represented as -a boolean value for all values of the symbolic dimensions involved. - -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables -for more details. 
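For example, with a dimension variable `v` (assumed to range over integers >= 1),
a comparison such as `v >= 2` is inconclusive: it may be true or false
depending on the value of `v`.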
-""" - - def __init__(self, message: str): - error_msg = f"{message}\n{InconclusiveDimensionOperation._help_msg}" - # https://github.com/python/mypy/issues/5887 - super().__init__(error_msg) # type: ignore - -class _ShapePolyThreadLocalState(threading.local): - - def __init__(self): - # TODO(necula): this does not play well with some lowering caches, because - # this state is not part of the cache key. - self.enable_shape_assertions = True - -thread_local_state = _ShapePolyThreadLocalState() - -class _DimAtom: - """Represents an atom in a symbolic dimension expression. - - Atoms are either variables, or expressions of the form floordiv(E1, E2) or - mod(E1, E2). Atoms are multiplied to form monomials (see _DimMon), and - monomials are added to form symbolic expressions (see _DimExpr). - - Args: - * var: if specified then the atom is a dimension variable. `operation` - must be `None`. - * operation: if specified then the atom is an operation applied to - `operands`. One of `FLOORDIR` or `MOD` or `NON_NEGATIVE`. `var` must be `None` - * operands: the operands to which the operation is applied. - """ - # The supported operations - FLOORDIV = "floordiv" - MOD = "mod" - NON_NEGATIVE = "non_negative" # The max of the operand and 0 - - def __init__(self, *operands: '_DimExpr', - var: Optional[str] = None, - operation: Optional[str] = None): - if var is not None: - assert operation is None - assert not operands - else: - assert operation is not None - self.var = var - self.operation = operation - self.operands = operands - - @classmethod - def from_var(cls, v: str) -> '_DimAtom': - return _DimAtom(var=v) - - def to_var(self) -> Optional[str]: - return self.var - - def get_vars(self) -> set[str]: - # All the vars that appear - if self.var is not None: - return {self.var} - else: - acc = set() - for opnd in self.operands: - acc.update(opnd.get_vars()) - return acc - - @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimAtom': - return _DimAtom(*operands, operation=operation) - - def __str__(self): - if self.var is not None: - return self.var - opnd_str = ", ".join([str(opnd) for opnd in self.operands]) - return f"{self.operation}({opnd_str})" - __repr__ = __str__ - - def __hash__(self): - return hash((self.var, self.operation, *self.operands)) - - def __eq__(self, other: Any): - # Used only for hashing - if not isinstance(other, _DimAtom): return False - if (self.var is None) != (other.var is None): return False - if self.var is not None: - return self.var == other.var - else: - def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool: - try: - return e1 == e2 - except InconclusiveDimensionOperation: - return False - return (self.operation == other.operation and - all(symbolic_equal(self_o, other_o) - for self_o, other_o in zip(self.operands, other.operands))) - - def __lt__(self, other: '_DimAtom'): - """ - Comparison to another atom in graded reverse lexicographic order. - Used only for determining a sorting order, does not relate to the - comparison of the values of the atom. 
- """ - if self.var is not None and other.var is not None: - return self.var < other.var - elif self.var is not None: - return True - elif other.var is not None: - return True - elif self.operation != other.operation: - return self.operation < other.operation # type: ignore - else: - return id(self) < id(other) - - def bounds(self) -> tuple[float, float]: - """Returns the lower and upper bounds, or -+ inf.""" - if self.var is not None: - return (1, np.inf) # variables are assumed to be >= 1 - opnd_bounds = [opnd.bounds() for opnd in self.operands] - if self.operation == _DimAtom.FLOORDIV: # a // b - (a_l, a_u), (b_l, b_u) = opnd_bounds - def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf - assert b != 0 - if not np.isinf(b): # divisor is finite - return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf - elif not np.isinf(a): # dividend is finite and divisor is infinite - return -1 if (a >= 0) != (b >= 0) else 0 - else: # both dividend and divisor are infinite - return -np.inf if (a >= 0) != (b >= 0) else np.inf - - # Same reasoning as for multiplication: the bounds are among the cross-product - # of the bounds. - bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u), - math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)] - return (min(*bound_candidates), max(*bound_candidates)) - - elif self.operation == _DimAtom.MOD: - _, (b_l, b_u) = opnd_bounds - if b_l > 0: # positive divisor - return (0, b_u - 1) - elif b_u < 0: # negative divisor - return (b_l + 1, 0) - else: - return (-np.inf, np.inf) - - elif self.operation == _DimAtom.NON_NEGATIVE: - (b_l, b_h), = opnd_bounds - return (max(0, b_l), max(0, b_h)) - - else: - assert False - - def evaluate(self, env: DimVarEnv): - if self.var is not None: - try: - return env[self.var] - except KeyError: - err_msg = ( - f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") - raise KeyError(err_msg) - else: - operand_values = [opnd.evaluate(env) for opnd in self.operands] - if self.operation == _DimAtom.FLOORDIV: - return divmod(*operand_values)[0] # type: ignore - elif self.operation == _DimAtom.MOD: - return divmod(*operand_values)[1] # type: ignore - elif self.operation == _DimAtom.NON_NEGATIVE: - return lax.max(operand_values[0], 0) - else: - assert False, self.operation - -class _DimMon(dict): - """Represents a multiplication of atoms. - - The representation is a dictionary mapping _DimAtom to exponent. - The exponents are integers >= 1. - """ - def __hash__(self): - return hash(frozenset(self.items())) - - def __str__(self): - return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key) - for key, exponent in sorted(self.items())) - - @classmethod - def from_var(cls, v: str) -> '_DimMon': - return _DimMon({_DimAtom.from_var(v): 1}) - - @classmethod - def from_atom(clscls, a: _DimAtom, aexp: int): - return _DimMon({a: aexp}) - - def to_var(self) -> Optional[str]: - """Extract the variable name "x", from a monomial "x". 
- Return None, if the monomial is not a single variable.""" - items = self.items() - if len(items) != 1: - return None - (a, aexp), = items - if aexp != 1: - return None - return a.to_var() - - def get_vars(self) -> set[str]: - # All the vars that appear in the monomial - acc = set() - for a in self.keys(): - acc.update(a.get_vars()) - return acc - - @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimMon': - return _DimMon({_DimAtom.from_operation(operation, *operands): 1}) - - @property - def degree(self): - return sum(self.values()) - - def __lt__(self, other: '_DimMon'): - """ - Comparison to another monomial in graded reverse lexicographic order. - Used only for determining a sorting order, does not relate to the - comparison of the values of the monomial. - """ - self_key = -self.degree, tuple(sorted(self)) - other_key = -other.degree, tuple(sorted(other)) - return self_key > other_key - - def mul(self, other: '_DimMon') -> '_DimMon': - """ - Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m. - """ - return _DimMon(collections.Counter(self) + collections.Counter(other)) - - def divide(self, divisor: '_DimMon') -> '_DimMon': - """ - Divides by another monomial. Raises a InconclusiveDimensionOperation - if the result is not a monomial. - For example, (n^3 * m) // n == n^2*m, but n // m fails. - """ - d = collections.Counter(self) - for key, exponent in divisor.items(): - diff = self.get(key, 0) - exponent - if diff < 0: - raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") - elif diff == 0: del d[key] - elif diff > 0: d[key] = diff - return _DimMon(d) - - def bounds(self) -> tuple[float, float]: - """Returns the lower and upper bounds, or -+inf.""" - # The bounds of a product are among the product of bounds. - bounds = [] - for a, exp in self.items(): - a_l, a_u = a.bounds() - assert a_l <= a_u - bounds.append((a_l ** exp, a_u ** exp)) - - candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)] - return (min(*candidates), max(*candidates)) # type: ignore - - - def evaluate(self, env: DimVarEnv): - prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) - def pow_opt(v, p: int): - return v if p == 1 else prod([v] * p) - return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()]) - - -class _DimExpr(): - """Symbolic expression in terms of dimension variables. - - A dimension expression is an addition of products (_DimMon) - of atoms (_DimAtom). - - We overload integer operations, but we do that soundly, raising - :class:`InconclusiveDimensionOperation` when the result is not - representable as a _DimExpr. - - The representation of a _DimExpr is as a dictionary mapping _DimMon to - integer coefficients. The special monomial `_DimMon()` is mapped to the - free integer coefficient of the expression. - """ - - __array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray - def __init__(self, coeffs: dict[_DimMon, int]): - # Do not construct _DimExpr directly, unless you are sure that coeffs is - # normalized; Use _DimExpr.normalize. 
- # Takes ownership of coeffs - self._coeffs = coeffs or {_DimMon(): 0} - - def monomials(self) -> Iterable[tuple[_DimMon, int]]: - return self._coeffs.items() - - @classmethod - def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int): - """Do `coeffs[mon] += coeff` but remove 0 coefficients.""" - old_c = coeffs.get(mon) - if old_c is None: - if coeff != 0: coeffs[mon] = coeff - else: - new_c = old_c + coeff - if new_c == 0: - del coeffs[mon] - else: - coeffs[mon] = new_c - - @classmethod - def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize: - """The main constructor for _DimExpr. - - Ensures that the symbolic dimension is normalized, e.g., - it is represented as a Python int if it is known to be a constant. - """ - # TODO(necula): profile and optimize this - has_non_zero_degree = False - free_const = 0 - new_coeffs: dict[_DimMon, int] = {} - for mon, coeff in coeffs.items(): - if coeff == 0: continue - if mon.degree == 0: # A constant, there can be a single one - free_const = coeff - else: - has_non_zero_degree = True - - new_coeffs[mon] = new_coeffs.get(mon, 0) + coeff - - if has_non_zero_degree: - return _DimExpr(new_coeffs) - else: - return int(free_const) - - @classmethod - def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize: - # Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes - # up when handling strided convolution. - for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV): - # e = factor * floordiv(operands)^exp * rest_monomial + rest_expr - if dec.exp != 1: - continue - if dec.rest_monomial == 1 and dec.factor == 1: - continue - m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1]) - if m_remainder == 0: - return m_trimmed * (dec.operands[0] - _DimExpr.from_operation(_DimAtom.MOD, *dec.operands)) + dec.rest_expr - return _DimExpr.normalize(coeffs) - - @classmethod - def from_monomial(cls, mon: _DimMon, exp: int): - return _DimExpr.normalize({mon: exp}) - - @classmethod - def from_var(cls, v: str) -> '_DimExpr': - return _DimExpr({_DimMon.from_var(v): 1}) - - @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimExpr': - return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1) - - def to_var(self) -> Optional[str]: - """Extract the variable name "x", from a symbolic expression.""" - items = self.monomials() - if len(items) != 1: # type: ignore - return None - (mon, mon_count), = items - if mon_count != 1: - return None - return mon.to_var() - - def get_vars(self) -> set[str]: - """The variables that appear in a symbolic dimension.""" - acc = set() - for mon, _ in self.monomials(): - acc.update(mon.get_vars()) - return acc - - def eq(self, other: DimSize) -> bool: - lb, ub = _ensure_poly(self - other, "eq").bounds() - if lb == ub == 0: - return True - if lb > 0 or ub < 0: - return False - # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported - return False - - def inconclusive_comparison(self, operation: str, op: Any) -> Exception: - return InconclusiveDimensionOperation( - f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.") - - def ge(self, other: DimSize) -> bool: - lb, ub = _ensure_poly(self - other, "ge").bounds() - if lb >= 0: - return True - if ub < 0: - return False - 
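    # Neither bound is decisive (lb < 0 <= ub), so the comparison cannot be
    # resolved for all values of the dimension variables.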
raise self.inconclusive_comparison(">=", other) - - def __hash__(self): - return hash(tuple(sorted(self.monomials()))) - - def __str__(self): - def _one_monomial(mon, c): - if mon.degree == 0: - return str(c) - if c == 1: - return str(mon) - return f"{c}*{mon}" - return " + ".join(_one_monomial(mon, c) - for mon, c in sorted(self.monomials(), reverse=True)) - - def __repr__(self): - return str(self) - - # We overload +, -, *, because they are fully defined for _DimExpr. - def __add__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__add__(other) - - other = _ensure_poly(other, "add") - coeffs = self._coeffs.copy() - for mon, coeff in other.monomials(): - _DimExpr._add_coeffs(coeffs, mon, coeff) - return _DimExpr.normalize_floordiv_times_divisor(coeffs) - - def __radd__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__radd__(other) - return _ensure_poly(other, "add").__add__(self) - - def __sub__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__sub__(other) - return self + -_ensure_poly(other, "sub") - - def __rsub__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__rsub__(other) - return _ensure_poly(other, "sub").__sub__(self) - - def __neg__(self) -> '_DimExpr': - return _DimExpr({mon: -coeff for mon, coeff in self.monomials()}) - - def __mul__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__mul__(other) - other = _ensure_poly(other, "mul") - coeffs: dict[_DimMon, int] = {} - for mon1, coeff1 in self.monomials(): - for mon2, coeff2 in other.monomials(): - mon = mon1.mul(mon2) - _DimExpr._add_coeffs(coeffs, mon, coeff1 * coeff2) - return _DimExpr.normalize_floordiv_times_divisor(coeffs) - - def __rmul__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__rmul__(other) - return _ensure_poly(other, "mul").__mul__(self) - - def __pow__(self, power, modulo=None): - assert modulo is None - try: - power = int(power) - except: - raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") - return functools.reduce(op.mul, [self] * power) - - def __floordiv__(self, divisor): - if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): - return self.__jax_array__().__floordiv__(divisor) - return self.divmod(_ensure_poly(divisor, "floordiv"))[0] - - def __rfloordiv__(self, other): - if isinstance(other, core.Tracer) or not _convertible_to_poly(other): - return self.__jax_array__().__rfloordiv__(other) - return _ensure_poly(other, "floordiv").__floordiv__(self) - - def __truediv__(self, divisor): - # Used for "/", which always returns a float - return self.__jax_array__().__truediv__(divisor) - - def __rtruediv__(self, dividend): - # Used for "/", when dividend is not a _DimExpr - return self.__jax_array__().__rtruediv__(dividend) - - def __mod__(self, divisor): - if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): - return self.__jax_array__().__mod__(divisor) - return self.divmod(_ensure_poly(divisor, "mod"))[1] - - def __rmod__(self, dividend): - if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): - return self.__jax_array__().__rmod__(dividend) - return _ensure_poly(dividend, 
"mod").__mod__(self) - - def __divmod__(self, divisor): - if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): - return self.__jax_array__().__divmod__(divisor) - return self.divmod(_ensure_poly(divisor, "divmod")) - - def __rdivmod__(self, dividend): - if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend): - return self.__jax_array__().__rdivmod__(dividend) - return _ensure_poly(dividend, "divmod").__divmod__(self) - - def __int__(self): - if self.is_constant: - return op.index(next(iter(self._coeffs.values()))) - else: - raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant") - - # We must overload __eq__ and __ne__, or else we get unsound defaults. - __eq__ = eq - def __ne__(self, other: DimSize) -> bool: - return not self.eq(other) - - __ge__ = ge - - def __le__(self, other: DimSize): - try: - return _ensure_poly(other, "le").__ge__(self) - except InconclusiveDimensionOperation as e: - raise self.inconclusive_comparison("<=", other) from e - - def __gt__(self, other: DimSize): - try: - return not _ensure_poly(other, "gt").__ge__(self) - except InconclusiveDimensionOperation as e: - raise self.inconclusive_comparison(">", other) from e - - def __lt__(self, other: DimSize): - try: - return not self.__ge__(other) - except InconclusiveDimensionOperation as e: - raise self.inconclusive_comparison("<", other) from e - - def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]: - """ - Floor division with remainder (divmod) generalized to polynomials. - If the `divisor` is not a constant, the remainder must be 0. - If the `divisor` is a constant, the remainder may be non 0, for consistency - with integer divmod. - - :return: Quotient resulting from polynomial division and integer remainder. - """ - assert isinstance(divisor, _DimExpr) - try: - dmon, dcount = divisor.leading_term - dividend, quotient = self, 0 - # invariant: self = dividend + divisor * quotient - # quotient and dividend are changed in the loop; the leading term of - # dividend decreases at each iteration. 
- while is_poly_dim(dividend) and not dividend.is_constant: - mon, count = dividend.leading_term - try: - qmon = mon.divide(dmon) - except InconclusiveDimensionOperation: - raise InconclusiveDimensionOperation("") - qcount, rcount = divmod(count, dcount) - if rcount != 0: - raise InconclusiveDimensionOperation("") - - q = _DimExpr.from_monomial(qmon, qcount) - quotient += q - dividend -= q * divisor # type: ignore[assignment] - - dividend = int(dividend) # type: ignore[assignment] - if divisor.is_constant: - q, r = divmod(dividend, int(divisor)) # type: ignore - quotient += q - remainder = r - else: - if dividend != 0: - raise InconclusiveDimensionOperation("") - remainder = 0 - - if config.jax_enable_checks: - assert self == divisor * quotient + remainder - return quotient, remainder - except InconclusiveDimensionOperation: - return (_DimExpr.from_operation(_DimAtom.FLOORDIV, self, divisor), # type: ignore - _DimExpr.from_operation(_DimAtom.MOD, self, divisor)) - - def bounds(self) -> tuple[float, float]: - """Returns the lower and upper bounds, or -+inf.""" - lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient - for mon, coeff in self.monomials(): - if mon.degree == 0: continue # We already included the free coefficient - m_l, m_u = mon.bounds() - assert m_l <= m_u and coeff != 0 - item_l, item_u = coeff * m_l, coeff * m_u - lb = lb + min(item_l, item_u) # type: ignore - ub = ub + max(item_l, item_u) # type: ignore - - if lb != -np.inf or ub != np.inf: - return lb, ub - # Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0 - # TODO(necula): add more principled support for floordiv and mod - # For example, this will miss "1 + a - mod(b, a)" - for dec in _decompose_expr(self, _DimAtom.MOD): - # E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr - if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]: - try: - if dec.operands[1] <= 0: - continue - except InconclusiveDimensionOperation: - continue - if dec.factor > 0: - return (-np.inf, -1) - else: - return (1, np.inf) - - return lb, ub - - @property - def is_constant(self): - return len(self._coeffs) == 1 and next(iter(self._coeffs)).degree == 0 - - @property - def leading_term(self) -> tuple[_DimMon, int]: - """Returns the highest degree term that comes first lexicographically.""" - return max(self.monomials()) - - def evaluate(self, env: DimVarEnv): - # Evaluates as a value of dtype=core.dim_value_dtype() - terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) - for mon, coeff in self.monomials()] - return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] - - def non_negative(self) -> "_DimExpr": - return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self) - - @staticmethod - def get_aval(dim: "_DimExpr"): - return core.dim_value_aval() - - def dimension_as_value(self): - """Turns a dimension size into a Jax value that we can compute with.""" - return _dim_as_value(self) - - def __jax_array__(self): - # Used for implicit coercions of polynomials as JAX arrays - return _dim_as_value(self) - -@dataclasses.dataclass -class _Decomposition: - """Decomposition of an expression around an operation atom. 
- - E = factor * mod(*operands)^exp * rest_monomial + rest_expr - """ - factor: int - operands: Sequence[_DimExpr] - exp: int - rest_monomial: _DimExpr - rest_expr: _DimExpr - - -def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]: - for m, m_factor in e.monomials(): - atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation] - if atoms: - e_minus_m_coeffs = e._coeffs.copy() - del e_minus_m_coeffs[m] - for a, aexp in atoms: - yield _Decomposition( - factor=m_factor, - operands=a.operands, - exp=aexp, - rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}), - rest_expr=_DimExpr(e_minus_m_coeffs)) - -core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval -xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval -dtypes._weak_types.append(_DimExpr) - -def _convertible_to_int(p: DimSize) -> bool: - try: - op.index(p) - return True - except: - return False - -def _ensure_poly(p: DimSize, - operation_name: str) -> _DimExpr: - if isinstance(p, _DimExpr): return p - if _convertible_to_int(p): - return _DimExpr({_DimMon(): op.index(p)}) - raise TypeError(f"Symbolic dimension {operation_name} not supported for {p}.") - -def _convertible_to_poly(p: DimSize) -> bool: - return isinstance(p, _DimExpr) or _convertible_to_int(p) - -def is_poly_dim(p: DimSize) -> bool: - return isinstance(p, _DimExpr) - -dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] - -def _einsum_contract_path(*operands, **kwargs): - """Like opt_einsum.contract_path, with support for DimExpr shapes. - - We use opt_einsum.contract_path to compute the schedule, using a fixed - constant for all dimension variables. This is safe because we throw an - error if there are more than one contraction. Essentially, we just use - opt_einsum.contract_path to parse the specification. - """ - - # Replace the polymorphic shapes with some concrete shapes for calling - # into opt_einsum.contract_path, because the latter wants to compute the - # sizes of operands and intermediate results. - fake_ops = [] - for operand in operands: - # We replace only array operands - if not hasattr(operand, "dtype"): - fake_ops.append(operand) - else: - shape = np.shape(operand) - def fake_dim(d): - if core.is_constant_dim(d): - return d - else: - if not isinstance(d, _DimExpr): - raise TypeError(f"Encountered unexpected shape dimension {d}") - # It is OK to replace all polynomials with the same value. We may miss - # here some errors due to non-equal dimensions, but we catch them - # later. - return 8 - fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), - operand.dtype)) - - contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, - **kwargs) - contract_operands = [] - for operand in contract_fake_ops: - idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op) - assert len(idx) == 1 - contract_operands.append(operands[idx[0]]) - return contract_operands, contractions - -lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path - -# To implement shape-constraint checking we use a shape assertion primitive. -# shape_assertion_p.bind(assert_what: bool, *error_message_inputs, -# error_message="...{0}...{1}") -# where "{0}" refers to error_message_inputs[0], etc.
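# Editor's note -- an illustrative bind only, assuming `a` and `b` are dimension values already computed from the actual argument shapes (the names are made up):
#   shape_assertion_p.bind(b == 2 * a, b, a,
#                          error_message="Expected dimension {0} to be twice {1}")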
-shape_assertion_p = core.Primitive("shape_assertion") -shape_assertion_p.multiple_results = True -shape_assertion_p.def_effectful_abstract_eval( - lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore - -def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext, - assert_what: mlir.ir.Value, - *error_message_inputs: mlir.ir.Value, - error_message: str): - op = mlir.custom_call( - "shape_assertion", - result_types=[], # No results - operands=[assert_what, *error_message_inputs], - has_side_effect=True, - extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message)) - ) - return op.results - -mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule) - -class ShapeAssertionEffect(effects.Effect): - __str__ = lambda _: "ShapeAssertionEffect" - -shape_assertion_effect = ShapeAssertionEffect() - -effects.lowerable_effects.add_type(ShapeAssertionEffect) -effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect) -effects.remat_allowed_effects.add_type(ShapeAssertionEffect) -effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) - -def shape_assertion(assert_what: jax.Array, - *error_message_inputs: jax.Array, - error_message: str) -> None: - """Adds a shape assertion to the code. - - Args: - assert_what: a boolean asserted to be true. Must be computed based only - on dimension expressions, so that it can be evaluated after shape - refinement. - error_message_inputs: integer expressions whose values can be referenced - in the `error_message`. Must be computed based only - on dimension expressions, so that they can be evaluated after shape - refinement. - error_message: an error message, possibly containing format specifiers - {0}, {1}, ..., referencing the values of the `error_message_inputs`. - The format specifiers are sometimes processed with Python's - `str.format` method, and sometimes with `llvm::formatv`. - """ - if thread_local_state.enable_shape_assertions: - shape_assertion_p.bind(assert_what, *error_message_inputs, - error_message=error_message) - -# A JAX primitive with no array arguments but with a dimension parameter -# that is a DimExpr. The value of the primitive is the value of the dimension, -# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype()) -dim_as_value_p = core.Primitive("dim_as_value") -dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval()) - -def dim_as_value_impl(dim: DimSize): - raise NotImplementedError( - "Evaluation rule for 'dim_as_value' is not implemented. " - "It seems that you are using shape polymorphism outside jax2tf.") - -dim_as_value_p.def_impl(dim_as_value_impl) -def _dim_as_value(dim: DimSize): - return dim_as_value_p.bind(dim=dim) - -def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, - dim): - res, = mlir.eval_dynamic_shape(ctx, (dim,)) - out_type = mlir.aval_to_ir_type(ctx.avals_out[0]) - if out_type != res.type: # type: ignore - return mlir.hlo.ConvertOp(out_type, res).results - else: - return [res] - -mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering) - - -class PolyShape(tuple): - """Tuple of polymorphic dimension specifications. - - See docstring of :func:`jax2tf.convert`.
- """ - - def __init__(self, *dim_specs): - tuple.__init__(dim_specs) - - def __new__(cls, *dim_specs): - for ds in dim_specs: - if not isinstance(ds, (int, str)) and ds != ...: - msg = (f"Invalid polymorphic shape element: {repr(ds)}; must be a string " - "representing a dimension variable, or an integer, or ...") - raise ValueError(msg) - return tuple.__new__(PolyShape, dim_specs) - - def __str__(self): - return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" - - -def _parse_spec(shape_spec: Union[str, PolyShape, None], - arg_shape: Sequence[Optional[int]]) -> Sequence[DimSize]: - """Parses the shape polymorphic specification for one array argument. - - We have to be able to parse all strings produced by str(_DimExpr) because - sometimes the output polymorphic shapes of one function become the input - polymorphic shapes of another. - - Args: - shape_spec: a shape polymorphic specification. None stands for "...". - arg_shape: an actual shape, possibly containing unknown dimensions (None). - We use `arg_shape` to fill-in the placeholders `_` and `...` in - the `shape_spec`. The dimensions of `arg_shape` that are used for filling - must be known (not `None`). If a dimension in `arg_shape` is known and - the corresponding dimension in `shape_spec` is a constant then they - must be equal. - - See the README.md for usage. - """ - shape_spec_repr = repr(shape_spec) - if shape_spec is None: - shape_spec = "..." - elif isinstance(shape_spec, PolyShape): - shape_spec = str(shape_spec) - elif not isinstance(shape_spec, str): - raise ValueError("polymorphic shape spec should be None or a string. " - f"Found {shape_spec_repr}.") - return _Parser(shape_spec, arg_shape, shape_spec_repr).parse() - -class _Parser: - def __init__(self, - shape_spec: str, - arg_shape: Sequence[Optional[int]], - shape_spec_repr: str): - self.shape_spec = shape_spec - self.shape_spec_repr = shape_spec_repr # For error messages - self.arg_shape = arg_shape - self.dimensions: list[DimSize] = [] # dimensions we have parsed - - def parse(self) -> Sequence[DimSize]: - self.tokstream = tokenize.tokenize( - io.BytesIO(self.shape_spec.encode("utf-8")).readline) - tok = self.consume_token(self.next_tok(), tokenize.ENCODING) # Always 1st - sh, tok = self.shape(tok) - self.expect_token(tok, [tokenize.ENDMARKER]) - return sh - - def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo): - if expr is None: - raise self.parse_err(tok, - ("unexpected placeholder for unknown dimension " - f"for argument shape {self.arg_shape}")) - arg_shape_dim = self.arg_shape[len(self.dimensions)] - if core.is_constant_dim(expr) and arg_shape_dim is not None: - if expr != arg_shape_dim: - raise self.parse_err(tok, - (f"different size {expr} for known dimension " - f"for argument shape {self.arg_shape}")) - self.dimensions.append(expr) - - def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception: - msg = ( - f"syntax error in polymorphic shape {self.shape_spec_repr} " - f"in dimension {len(self.dimensions)}: {detail}. ") - if tok is not None: - msg += f"Parsed '{tok.line[:tok.start[1]]}', remaining '{tok.line[tok.start[1]:]}'." 
- return ValueError(msg) - - def next_tok(self) -> tokenize.TokenInfo: - while True: - try: - t = next(self.tokstream) - except StopIteration: - raise self.parse_err(None, "unexpected end of string") - if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]: - return t - - def expect_token(self, tok: tokenize.TokenInfo, expected: Sequence[int]) -> None: - if tok.exact_type not in expected: - msg = ("expecting one of {" + - ", ".join(tokenize.tok_name[t] for t in expected) + "} but found " + - tokenize.tok_name[tok.exact_type]) - raise self.parse_err(tok, msg) - - def consume_token(self, tok: tokenize.TokenInfo, expected: int) -> tokenize.TokenInfo: - self.expect_token(tok, [expected]) - return self.next_tok() - - def integer(self, tok: tokenize.TokenInfo) -> tuple[int, tokenize.TokenInfo]: - self.expect_token(tok, [tokenize.NUMBER]) - try: - val = int(tok.string) - except Exception: - raise self.parse_err(tok, f"expecting integer, found {tok.string}") - return val, self.next_tok() - - # What can follow a shape? - FOLLOW_SHAPE = [tokenize.ENDMARKER, tokenize.RPAR] - def shape(self, tok: tokenize.TokenInfo) -> tuple[Sequence[DimSize], tokenize.TokenInfo]: - # A comma-separated list of _DimExpr, or "_", possibly ended with ... - if tok.exact_type == tokenize.LPAR: - res, tok = self.shape(self.next_tok()) - tok = self.consume_token(tok, tokenize.RPAR) - return res, tok - - while True: - if tok.exact_type in self.FOLLOW_SHAPE: - break - if tok.exact_type == tokenize.ELLIPSIS: - to_add = self.arg_shape[len(self.dimensions):] - for ad in to_add: - self.add_dim(ad, tok) - tok = self.next_tok() - break - if len(self.dimensions) >= len(self.arg_shape): - raise self.parse_err(tok, - f"too many dimensions, arg_shape has {len(self.arg_shape)}") - if tok.exact_type == tokenize.NAME and tok.string == "_": - e = self.arg_shape[len(self.dimensions)] - tok = self.next_tok() - else: - e, tok = self.expr(tok) - self.add_dim(e, tok) - if tok.exact_type in self.FOLLOW_SHAPE: - break - tok = self.consume_token(tok, tokenize.COMMA) - - return tuple(self.dimensions), tok - - # What token can follow a _DimExpr - FOLLOW_EXPR = FOLLOW_SHAPE + [tokenize.COMMA] - - def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: - # A sum of monomials - next_m_negated = False - acc = 0 - while True: - m, tok = self.mon(tok) - acc = acc + (- m if next_m_negated else m) - if tok.exact_type in self.FOLLOW_EXPR: - return acc, tok - next_m_negated = (tok.exact_type == tokenize.MINUS) - self.expect_token(tok, [tokenize.PLUS, tokenize.MINUS]) - tok = self.next_tok() - - FOLLOW_MON = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS] - def mon(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: - # A monomial is product of atoms. Each atom may be raised to an integer power. 
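# Editor's note -- illustrative examples of specs this grammar accepts (not part of the original source): "b", "2*b", "b^2", "mod(b, 2)", "floordiv(b, 2)", and sums such as "2*b + 1" or "b1*b2 - 1".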
- acc = 1 - while True: - a, tok = self.atom(tok) - if tok.exact_type == tokenize.CIRCUMFLEX: - tok = self.next_tok() - self.expect_token(tok, [tokenize.NUMBER]) - power, tok = self.integer(tok) - a = a ** power - - acc = acc * a - if tok.exact_type in self.FOLLOW_MON: - return acc, tok - tok = self.consume_token(tok, tokenize.STAR) - - def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: - if tok.exact_type == tokenize.NAME: - if tok.string == _DimAtom.MOD: - return self.binary_op(_DimAtom.MOD, self.next_tok()) - if tok.string == _DimAtom.FLOORDIV: - return self.binary_op(_DimAtom.FLOORDIV, self.next_tok()) - if tok.string == _DimAtom.NON_NEGATIVE: - return self.unary_op(_DimAtom.NON_NEGATIVE, self.next_tok()) - return _DimExpr.from_var(tok.string), self.next_tok() - number_sign = 1 - if tok.exact_type == tokenize.MINUS: # -k are negative constants - number_sign = -1 - tok = self.next_tok() - self.expect_token(tok, [tokenize.NUMBER]) - if tok.exact_type == tokenize.NUMBER: - v, tok = self.integer(tok) - return v * number_sign, tok - self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER]) - assert False - - def unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: - tok = self.consume_token(tok, tokenize.LPAR) - e1, tok = self.expr(tok) - tok = self.consume_token(tok, tokenize.RPAR) - return _DimExpr.from_operation(op, e1), tok # type: ignore - - def binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]: - tok = self.consume_token(tok, tokenize.LPAR) - e1, tok = self.expr(tok) - tok = self.consume_token(tok, tokenize.COMMA) - e2, tok = self.expr(tok) - tok = self.consume_token(tok, tokenize.RPAR) - return _DimExpr.from_operation(op, e1, e2), tok # type: ignore - - -def _evaluate_add(v1, v2): - try: - if op.index(v1) == 0: - return v2 - except: - pass - try: - if op.index(v2) == 0: - return v1 - except: - pass - return v1 + v2 - -def _evaluate_multiply(v1, v2): - try: - if op.index(v1) == 1: - return v2 - except: - pass - try: - if op.index(v2) == 1: - return v1 - except: - pass - return v1 * v2 - -# dimension_size(operand, dimension=i) get the operand.shape[i] as a -# value of type shape_poly.dim_as_value_dtype(). -dimension_size_p = core.Primitive("dimension_size") -def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue: - return core.dim_value_aval() - -dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval) - -def _dimension_size_impl(arg, *, dimension): - return core.dim_constant(arg.shape[dimension]) -dimension_size_p.def_impl(_dimension_size_impl) - -def _dimension_size_lowering_rule(ctx, arg, *, dimension): - dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension) - dim_type = mlir.aval_to_ir_type(core.dim_value_aval()) - if dim_size.result.type != dim_type: - dim_size = mlir.hlo.ConvertOp(dim_type, dim_size) - return dim_size.results - -mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) - - -def arg_aval( - arg_shape: Sequence[Optional[int]], - arg_jax_dtype: DType, - polymorphic_shape: Optional[Union[str, PolyShape]]) -> core.ShapedArray: - """Computes abstract values. - - Args: - arg_shape: the shape for the argument, possibly having None dimensions. - arg_dtype: the inferred JAX dtype for the arg. - polymorphic_shape: the polymorphic specification for the argument. - Returns: the JAX abstract value for the argument. 
- """ - aval_shape = _parse_spec(polymorphic_shape, arg_shape) - return core.ShapedArray(aval_shape, arg_jax_dtype) - -def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]: - dim_vars: set[str] = set() - for a in args_avals: - for d in a.shape: - if is_poly_dim(d): - dim_vars = dim_vars.union(d.get_vars()) - return sorted(tuple(dim_vars)) - - -class CachingShapeEvaluator: - def __init__(self, **env): - self.env = env - - @functools.lru_cache(128) - def evaluate(self, e: DimSize): - if core.is_constant_dim(e): - res = op.index(e) - else: - res = e.evaluate(self.env) # type: ignore - return res - - -@dataclasses.dataclass(frozen=True) -class ShapeConstraint: - class Comparator(Enum): - EQ = 1 - GEQ = 2 - - comp: Comparator - left: DimSize - right: DimSize - # `error_message_pieces` is a list of strings and DimSize. The error message - # is formed by evaluating the DimSize and concatenating the sequence. - error_message_pieces: Sequence[Union[str, DimSize]] - - def check_statically(self, eval: CachingShapeEvaluator) -> None: - """Evaluates a constraint statically.""" - left, right = eval.evaluate(self.left), eval.evaluate(self.right) - try: - if self.comp == ShapeConstraint.Comparator.EQ: - ok = (left == right) - elif self.comp == ShapeConstraint.Comparator.GEQ: - ok = (left >= right) - else: - assert False # We are in a context where we know we can evaluate - # all symbolic expressions to constants. - except InconclusiveDimensionOperation as e: - raise self.make_error(eval) from e - if not ok: - raise self.make_error(eval) - - def compute(self, eval: CachingShapeEvaluator) -> Optional[jax.Array]: - """Computes if the constraint is satisfied. - - If the constraint can be resolved statically returns None - or raises ValueError otherwise. If the constraint cannot be - resolved statically, returns a value representing if the - constraint is satisfied. - """ - left, right = eval.evaluate(self.left), eval.evaluate(self.right) - # Try to evaluate the constraint statically. - if core.is_constant_shape((left, right)): - left_int, right_int = op.index(left), op.index(right) - if self.comp == ShapeConstraint.Comparator.EQ: - if not (left_int == right_int): - raise self.make_error(eval) - elif self.comp == ShapeConstraint.Comparator.GEQ: - if not (left_int >= right_int): - raise self.make_error(eval) - else: assert False - return None - - if self.comp == ShapeConstraint.Comparator.EQ: - is_ok = lax.eq(left, right) - elif self.comp == ShapeConstraint.Comparator.GEQ: - is_ok = lax.ge(left, right) - else: assert False - return is_ok - - def __str__(self): - return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}" - f" ({self.error_message_pieces})") - __repr__ = __str__ - - def error_message_and_inputs( - self, - eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]: - """Forms the error_message and error message_inputs. - See shape_assertion. - """ - # There is currenly a limitation in the shape assertion checker that - # it supports at most 32 error_message_inputs. We try to stay within the - # limit, reusing a format specifier if possible. 
- if jaxlib_version <= (0, 4, 14): - max_error_message_inputs = 4 - else: - max_error_message_inputs = 32 - format_specifiers: dict[DimSize, str] = {} - error_message_inputs: list[Any] = [] - error_message_strings: list[str] = [] - for e in self.error_message_pieces: - if isinstance(e, str): - error_message_strings.append(e) - continue - cached_spec = format_specifiers.get(e) - if cached_spec is not None: - error_message_strings.append(cached_spec) - continue - if len(error_message_inputs) >= max_error_message_inputs: - error_message_strings.append("N/A") - continue - spec = "{" + str(len(error_message_inputs)) + "}" - format_specifiers[e] = spec - error_message_strings.append(spec) - error_message_inputs.append(eval.evaluate(e)) - return ("".join(error_message_strings), - error_message_inputs) - - def make_error(self, eval: CachingShapeEvaluator) -> Exception: - error_message, error_message_inputs = self.error_message_and_inputs(eval) - return ValueError(error_message.format(*error_message_inputs)) - - -class ShapeConstraints: - def __init__(self): - self.constraints: list[ShapeConstraint] = [] - - def add_constraint(self, - comp: ShapeConstraint.Comparator, - left: DimSize, right: DimSize, - error_message_pieces: Sequence[Union[str, DimSize]]): - c = ShapeConstraint(comp, left, right, error_message_pieces) - self.constraints.append(c) - - def check_statically(self, eval: CachingShapeEvaluator) -> None: - """Evaluates all the constraints statically. - - If the static checking of any constraint fails, raises ValueError. - """ - for constraint in self.constraints: - constraint.check_statically(eval) - - def shape_assertions(self, eval: CachingShapeEvaluator) -> None: - """Computes the shape assertions for the set of constraints. - - See jax_export._wrap_main_func docstring. - """ - # We want to report the errors in the same order as `check_statically`. - # So, we process them in order, in case some fail statically, and we - # generate the shape assertions in the same order. - for constraint in self.constraints: - is_ok = constraint.compute(eval) - if is_ok is None: continue # Was resolved statically - error_message, error_message_inputs = constraint.error_message_and_inputs(eval) - shape_assertion( - is_ok, *error_message_inputs, - error_message=error_message) - -@dataclasses.dataclass -class _DimEquation: - # Encodes that `aval_dim_expr`, which is a symbolic expressions containing - # unknown dimension variables from the abstract values, is the specification - # for dimension named `dim_name` (e.g., "args[0].field.shape[2]"). - aval_dim_expr: _DimExpr - dim_name: str - - def __str__(self): - return f"Dimension size of {self.dim_name} with specification '{self.aval_dim_expr}'" - __repr__ = __str__ - - -def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str: - # String description of `args` or `kwargs`, assuming the path for a tree for - # the tuple `(args, kwargs)`. 
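# Editor's note -- illustrative: for the tree (args, kwargs), the key path
# (SequenceKey(0), SequenceKey(2)) renders as "args[2]", and
# (SequenceKey(1), DictKey('x')) renders as "kwargs['x']".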
- if path[0] == tree_util.SequenceKey(0): - return f"args{tree_util.keystr(path[1:])}" - elif path[0] == tree_util.SequenceKey(1): - return f"kwargs{tree_util.keystr(path[1:])}" - else: - assert False - -@functools.lru_cache(128) -def _cached_pretty_print_dimension_descriptor( - args_kwargs_tree: tree_util.PyTreeDef, - flat_arg_idx: int) -> str: - args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path( - args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves)) - arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0]) - return arg_str - -def pretty_print_dimension_descriptor( - args_kwargs_tree: tree_util.PyTreeDef, - flat_arg_idx: int, dim_idx: Optional[int]) -> str: - arg_str = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx) - if dim_idx is not None: - arg_str += f".shape[{dim_idx}]" - return arg_str - -@util.cache() -def solve_dim_vars( - args_avals: Sequence[core.AbstractValue], - args_kwargs_tree: tree_util.PyTreeDef, - ) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]: - """Solves dimension variables in a called function's avals in terms of actual argument shapes. - - For example, given: - - args_avals = [ShapedArray((3, a, a + b), f32)] - - we introduce fresh "synthetic" dimension variables to represent the actual - dimension size of actual arguments for each non-constant dimension. - Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.: - - synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)] - - and then we express the solution for the unknown dimension variables {a, b} - as symbolic expressions in terms of the synthetic variables: - - dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1]) - - Not all equations are solvable. For now, we solve first the linear - uni-variate equations, then the solved variables are used to simplify the - remaining equations to linear uni-variate equations, and the process - continues until all dimension variables are solved. - - Args: - args_avals: the abstract values of the `args`, with shapes that may - include unknown dimension variables. - args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` - from which the flat sequence `args_avals` is extracted. Used for - describing args and kwargs in synthetic variable names and in - error messages. - - Returns: a 3-tuple with: (a) the solution for the unknown dimension variables - (b) a list of constraints that must be satisfied for the solution to be a - valid one, and (c) and the list of synthetic variables that may appear in - the solution and the constraints. - - Raises ValueError if it cannot solve some dimension variable. 
- """ - dim_equations: list[_DimEquation] = [] - synth_dimension_vars: list[tuple[str, int, int]] = [] - # tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b')) - polymorphic_shape_specs: list[tuple[str, str]] = [] - for arg_idx, aval in enumerate(args_avals): - if all(not is_poly_dim(d) for d in aval.shape): - continue - polymorphic_shape_specs.append( - (pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None), - str(aval.shape))) - for dim_idx, aval_d in enumerate(aval.shape): - if is_poly_dim(aval_d): - synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree, - arg_idx, dim_idx) - synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx)) - dim_equations.append( - _DimEquation(aval_dim_expr=_ensure_poly(aval_d, "solve_dim_vars"), - dim_name=synth_dim_var)) - - solution, shape_constraints = _solve_dim_equations(dim_equations, - polymorphic_shape_specs) - return solution, shape_constraints, synth_dimension_vars - - -def compute_dim_vars_from_arg_shapes( - args_avals: Sequence[core.AbstractValue], - *actual_args: jax.Array, - args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: - """Computes values of dimension variables to unify args_avals with actual arguments. - - Like `solve_dim_vars` except that here we express the solution as - JAX arrays that reference the `actual_args`. This function can be used to - generate the code for computing the dimension variables. It also generates - the shape assertions. - - Returns: the values of the dimension variables, in the order determined by - `all_dim_vars(args_avals)`. - """ - dim_vars = all_dim_vars(args_avals) - solution, shape_constraints, synth_dim_vars = solve_dim_vars( - tuple(args_avals), args_kwargs_tree=args_kwargs_tree) - - # Replace the synthetic vars with the dynamic shape of the actual arg - synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], - dimension=dim_idx) - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = CachingShapeEvaluator(**synthetic_env) - shape_constraints.shape_assertions(synthetic_eval) - dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] - return tuple(dim_values) - -def _solve_dim_equations( - eqns: list[_DimEquation], - polymorphic_shape_specs: Sequence[tuple[str, str]] -) -> tuple[DimVarEnv, ShapeConstraints]: - # Returns a shape environment and the shape constraints if it can solve all - # dimension variables. Raises an exception if it cannot. - shapeenv: DimVarEnv = {} - solution_error_message_pieces: list[Union[str, _DimExpr]] = [ - " Obtained dimension variables: " - ] # Error message describing the solution - # Prepare error message piece describing the polymorphic shape specs - poly_specs_err_msg = ( - " Using the following polymorphic shapes specifications: " + - ",".join(f"{arg_name}.shape = {arg_spec}" - for arg_name, arg_spec in polymorphic_shape_specs)) + "." - solution_err_msg_trailer_errors = ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." - - shape_constraints = ShapeConstraints() # accumulate shape constraints - - def process_one_eqn(eqn: _DimEquation) -> bool: - # We start with a DimEquation of the form `dim_expr = dim_value` - # Try to rewrite the equation as `var * factor_var = dim_value_2` (a linear - # uni-variate equation). Returns `False` if this rewrite fails. - # Otherwise, compute the `var` value as `dim_value_2 // factor`, add it to - # `shapeenv` and return `True`. 
- # - # Invariant: - # var * factor_var + remaining_monomials_from_dim_expr = dim_value - var, factor_var = None, None - dim_value = _DimExpr.from_var(eqn.dim_name) - - for mon, factor in eqn.aval_dim_expr.monomials(): - # Perhaps we can already evaluate this monomial (all vars solved) - try: - mon_value = mon.evaluate(shapeenv) - except KeyError: - # `mon` still uses some variables not yet solved. We handle only the - # case when `mon` is a single variable. - v = mon.to_var() - if v is not None and var is None: - var, factor_var = v, factor - continue - else: - dim_value = dim_value + core.dim_constant(-1) * _evaluate_multiply(mon_value, core.dim_constant(factor)) - continue - return False # This equation cannot yet be used to solve a variable - - if var is not None: - if factor_var == 1: - var_value = dim_value - else: - var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore - shape_constraints.add_constraint( - ShapeConstraint.Comparator.EQ, var_remainder, 0, - error_message_pieces=([ - "Input shapes do not match the polymorphic shapes specification. " - "Division had remainder ", var_remainder, - f" when computing the value of '{var}'." + poly_specs_err_msg - ] + solution_error_message_pieces + [ - solution_err_msg_trailer_errors])) - - if not isinstance(var_value, _DimExpr): - assert var_value.dtype == core.dim_value_dtype() - shapeenv[var] = var_value # type: ignore - solution_error_message_pieces.extend([ - f"'{var}' = ", var_value, - f" from specification '{eqn.aval_dim_expr}' " - f"for dimension {eqn.dim_name} (= ", _DimExpr.from_var(eqn.dim_name), - "), "]) - - shape_constraints.add_constraint( - ShapeConstraint.Comparator.GEQ, var_value, 1, - error_message_pieces=[ - "Input shapes do not match the polymorphic shapes specification. " - f"Expected value >= 1 for dimension variable '{var}'." + - poly_specs_err_msg - ] + solution_error_message_pieces + [ - solution_err_msg_trailer_errors]) - - return True - else: - # All variables are resolved for this equation, we emit an assertion - shape_constraints.add_constraint( - ShapeConstraint.Comparator.EQ, - _DimExpr.from_var(eqn.dim_name), - eqn.aval_dim_expr.evaluate(shapeenv), - error_message_pieces=([ - "Input shapes do not match the polymorphic shapes specification. " - f"Found inconsistency between dimension size {eqn.dim_name} (= ", - _DimExpr.from_var(eqn.dim_name), - f") and the specification '{eqn.aval_dim_expr}' (= ", - eqn.aval_dim_expr.evaluate(shapeenv), - ")." + poly_specs_err_msg] + solution_error_message_pieces + - [solution_err_msg_trailer_errors]) - ) - return True - - while True: - nr_eqns = len(eqns) - eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)] - if not eqns: - return shapeenv, shape_constraints # SUCCESS - elif len(eqns) >= nr_eqns: - break - - # We have some equations that we cannot solve further - unsolved_vars: set[str] = set() - unsolved_polys: list[_DimExpr] = [] - for eqn in eqns: - unsolved_vars = unsolved_vars.union(eqn.aval_dim_expr.get_vars()) - unsolved_polys.append(eqn.aval_dim_expr) - unsolved_vars = unsolved_vars.difference(shapeenv.keys()) - err_msg = ( - f"Cannot solve for values of dimension variables {unsolved_vars}. " - "We can only solve linear uni-variate constraints." + poly_specs_err_msg + - " Unprocessed specifications: " + - ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" - for eqn in eqns) + - ". 
Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." - ) - raise ValueError(err_msg) +# TODO(necula): Remove these stubs +from jax.experimental.export.shape_poly import ( + InconclusiveDimensionOperation, + + # PolyShape used in tensorflowjs/converters/jax_conversion.py + PolyShape, + # is_poly_dim is used by maths/qec. + is_poly_dim, +) diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index f0f59a8f7..50733a050 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -28,7 +28,7 @@ import numpy as np import jax from jax import config from jax import lax -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax.experimental.jax2tf.tests import back_compat_test_util as bctu from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft @@ -96,7 +96,7 @@ class CompatTest(bctu.CompatTestBase): def test_custom_call_coverage(self): """Tests that the back compat tests cover all the targets declared stable.""" - targets_to_cover = set(jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ diff --git a/jax/experimental/jax2tf/tests/back_compat_test_util.py b/jax/experimental/jax2tf/tests/back_compat_test_util.py index 6b317fbe9..23361ceab 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test_util.py +++ b/jax/experimental/jax2tf/tests/back_compat_test_util.py @@ -21,7 +21,7 @@ The tests in this file refer to the test data in ./back_compat_testdata. There is one test for each version of a custom call target, e.g., `test_ducc_fft` tests the FFT custom calls on CPU. Only custom call targets tested here should be listed in -jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom +export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom call targets will result in an error when encountered during serialization. Once we stop using a custom call target in JAX, you can remove it from the @@ -78,7 +78,7 @@ from numpy import array, float32 import jax from jax import tree_util -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax.experimental import pjit @@ -281,12 +281,12 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( a string (for debugging), and (c) the module serialization version. """ # Use the native exporter, to make sure we get the proper serialization. 
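# Editor's note -- an illustrative sketch of the renamed API that these test updates exercise (module path per this patch; the function and shapes are made up):
#   from jax.experimental.export import export
#   exp = export.export(jnp.sin)(jax.ShapeDtypeStruct((4,), np.float32))
#   res = export.call_exported(exp)(np.arange(4, dtype=np.float32))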
- args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes) - exported = jax_export.export( + args_specs = export.poly_specs(data.inputs, polymorphic_shapes) + exported = export.export( jax.jit(func), lowering_platform=self.default_jax_backend(), disabled_checks=tuple( - jax_export.DisabledSafetyCheck.custom_call(target) + export.DisabledSafetyCheck.custom_call(target) for target in allow_unstable_custom_call_targets) )(*args_specs) @@ -297,13 +297,13 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( def run_serialized(self, data: CompatTestData, polymorphic_shapes: Optional[Sequence[str]] = None): - args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes) + args_specs = export.poly_specs(data.inputs, polymorphic_shapes) def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: return core.ShapedArray(a.shape, a.dtype) in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs) # TODO: we ought to ensure that out_avals are polymorphic if need be. We # could either save the in/out_avals (but we need to first implement that - # support in jax_export), or we can just re-use them from the current + # support in export), or we can just re-use them from the current # exported. out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs) # in_tree must be for (args, kwargs) @@ -312,7 +312,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( def _get_vjp(_): assert False # We do not have and do not need VJP - exported = jax_export.Exported( + exported = export.Exported( fun_name="run_serialized", in_tree=in_tree, in_avals=tuple(in_avals), @@ -331,4 +331,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. - return pjit.pjit(jax_export.call_exported(exported))(*data.inputs) + return pjit.pjit(export.call_exported(exported))(*data.inputs) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 0d6eca9f2..40834962a 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -39,7 +39,7 @@ from jax._src.lib.mlir.dialects import hlo import jax._src.xla_bridge from jax import config from jax.experimental import jax2tf -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map @@ -1522,7 +1522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): stack.enter_context(mesh) # Run the JAX native version, to check it works, and to fill caches. 
_ = func_to_convert(*args) - exported = jax_export.export( + exported = export.export( func_to_convert, lowering_platform='tpu' )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 7efa85a3b..ea9b0130d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -31,8 +31,8 @@ import re import jax from jax.experimental import jax2tf -from jax.experimental.jax2tf import shape_poly -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export +from jax.experimental.export import shape_poly from jax.experimental import pjit from jax import lax import jax.numpy as jnp @@ -72,7 +72,6 @@ expect_error_associative_scan = ( "associative scan over axis of non-constant size")) - class DimExprTest(tf_test_util.JaxToTfTestCase): def sampled_assert_equal(self, @@ -585,7 +584,7 @@ class PolyHarness(Harness): len(polymorphic_shapes), len(args), f"polymorphic_shapes {polymorphic_shapes} of length " f"{len(polymorphic_shapes)} must match number of arguments {len(args)}") - args_specs = jax_export.poly_specs(args, polymorphic_shapes) + args_specs = export.poly_specs(args, polymorphic_shapes) input_signature = [ tf.TensorSpec( [d if isinstance(d, int) else None for d in a.shape], diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 738620f16..82c40fca7 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -31,7 +31,7 @@ from jax import tree_util from jax import config from jax.experimental import jax2tf -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax._src import xla_bridge import numpy as np import tensorflow as tf # type: ignore[import] @@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, jax_legacy_prng_key="allow") class JaxToTfTestCase(jtu.JaxTestCase): # We want most tests to use the maximum available version, from the locally - # installed tfxla module and jax_export. + # installed tfxla module and export. 
use_max_serialization_version = True def setUp(self): @@ -181,16 +181,16 @@ class JaxToTfTestCase(jtu.JaxTestCase): self.addCleanup(functools.partial(config.update, "jax_serialization_version", version)) if self.use_max_serialization_version: - # Use the largest supported by both jax_export and tfxla.call_module - version = min(jax_export.maximum_supported_serialization_version, + # Use the largest supported by both export and tfxla.call_module + version = min(export.maximum_supported_serialization_version, tfxla.call_module_maximum_supported_version()) self.assertGreaterEqual(version, - jax_export.minimum_supported_serialization_version) + export.minimum_supported_serialization_version) config.update("jax_serialization_version", version) logging.info( - "Using JAX serialization version %s (jax_export.max_version %s, tf.XlaCallModule max version %s)", + "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, - jax_export.maximum_supported_serialization_version, + export.maximum_supported_serialization_version, tfxla.call_module_maximum_supported_version()) with contextlib.ExitStack() as stack: diff --git a/tests/BUILD b/tests/BUILD index 0f897e257..ef4a8b5d0 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1173,6 +1173,18 @@ py_test( ], ) +jax_test( + name = "export_test", + srcs = ["export_test.py"], + enable_configs = [ + "tpu_df_2x2", + ], + tags = [], + deps = [ + "//jax/experimental/export", + ], +) + exports_files( [ "api_test.py", diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/tests/export_test.py similarity index 84% rename from jax/experimental/jax2tf/tests/jax_export_test.py rename to tests/export_test.py index 940017a5c..79966b759 100644 --- a/jax/experimental/jax2tf/tests/jax_export_test.py +++ b/tests/export_test.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import contextlib -import math import functools import logging +import math import re from typing import Optional import unittest @@ -24,7 +24,7 @@ import jax from jax import numpy as jnp from jax import tree_util from jax.config import config -from jax.experimental.jax2tf import jax_export +from jax.experimental.export import export from jax._src import core from jax._src import test_util as jtu @@ -56,14 +56,14 @@ class JaxExportTest(jtu.JaxTestCase): super().setUp() # Run tests with the maximum supported version by default self.override_serialization_version( - jax_export.maximum_supported_serialization_version) + export.maximum_supported_serialization_version) def test_basic_export_only(self): def my_fun(x): return jnp.sin(x) - exp = jax_export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) + exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) self.assertEqual("my_fun", exp.fun_name) - self.assertEqual(jax_export.default_lowering_platform(), exp.lowering_platform) + self.assertEqual(export.default_lowering_platform(), exp.lowering_platform) self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals) @@ -74,7 +74,7 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = jax_export.export(f, lowering_platform="cpu")((a, b), a=a, b=b) + exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) self.assertEqual(exp.lowering_platform, "cpu") @@ -90,9 +90,9 @@ class JaxExportTest(jtu.JaxTestCase): def f(a, b): # a: f32[2w,h] b: f32[w,h] return jnp.concatenate([a, b], axis=0) - exp = jax_export.export(f)( - jax_export.poly_spec(a.shape, a.dtype, "(2*w, h)"), - jax_export.poly_spec(a.shape, a.dtype, "(w, h)")) + exp = export.export(f)( + export.poly_spec(a.shape, a.dtype, "(2*w, h)"), + export.poly_spec(a.shape, a.dtype, "(w, h)")) self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(w, h)", str(exp.in_avals[1].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) @@ -102,25 +102,25 @@ class JaxExportTest(jtu.JaxTestCase): def f(a0, a1, *, ak): return jnp.concatenate([a0, a1, ak], axis=0) - a_poly_spec = jax_export.poly_spec(a.shape, a.dtype, "(w, h)") - exp = jax_export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) + a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)") + exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) self.assertEqual("(w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) def test_basic(self): f = jnp.sin x = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(x) + exp_f = export.export(f)(x) - f1 = jax_export.call_exported(exp_f) + f1 = export.call_exported(exp_f) self.assertAllClose(f(x), f1(x)) def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name f = lambda x: jnp.sin(x) x = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(x) - f1 = jax_export.call_exported(exp_f) + exp_f = export.export(f)(x) + f1 = export.call_exported(exp_f) self.assertAllClose(f(x), f1(x)) def test_call_twice_exported(self): @@ -129,8 +129,8 @@ class JaxExportTest(jtu.JaxTestCase): @jax.jit def f1(x): - exp_f = jax_export.export(f)(x) - return 
jax_export.call_exported(exp_f)(x) + jax_export.call_exported(exp_f)(x) + exp_f = export.export(f)(x) + return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x) self.assertAllClose(2. * f(x), f1(x)) @@ -138,9 +138,9 @@ class JaxExportTest(jtu.JaxTestCase): f = lambda x, y: jnp.sin(x) x = np.arange(4, dtype=np.float32) y = np.arange(6, dtype=np.float32) - exp_f = jax_export.export(f)(x, y) + exp_f = export.export(f)(x, y) - f1 = jax_export.call_exported(exp_f) + f1 = export.call_exported(exp_f) self.assertAllClose(f(x, y), f1(x, y)) def test_pytree(self): @@ -149,8 +149,8 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp_f = jax_export.export(f)((a, b), a=a, b=b) - f1 = jax_export.call_exported(exp_f) + exp_f = export.export(f)((a, b), a=a, b=b) + f1 = export.call_exported(exp_f) self.assertAllClose(f((a, b), a=a, b=b), f1((a, b), a=a, b=b)) @@ -158,34 +158,34 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c a = b = c = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)((a, b), c=c) + exp_f = export.export(f)((a, b), c=c) with self.assertRaisesRegex( ValueError, "The invocation args and kwargs must have the same pytree structure"): - jax_export.call_exported(exp_f)(a, b, c=(a, b)) + export.call_exported(exp_f)(a, b, c=(a, b)) def test_error_wrong_avals(self): def f(a, *, b): # a: f32[4] and b: f32[4] return jnp.sin(a) + jnp.cos(b) f32_4 = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(f32_4, b=f32_4) + exp_f = export.export(f)(f32_4, b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for args\[0\].shape\[0\]"): - jax_export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) + export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for kwargs\['b'\].shape\[0\]"): - jax_export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) + export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) with self.assertRaisesRegex(ValueError, r"Rank mismatch for args\[0\]"): - jax_export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) + export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) with self.assertRaisesRegex(ValueError, r"Dtype mismatch for args\[0\]"): - jax_export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) + export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) @jtu.parameterized_filterable( testcase_name=lambda kw: kw["platform"], @@ -194,19 +194,19 @@ class JaxExportTest(jtu.JaxTestCase): def test_error_wrong_platform(self, platform): a = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(jnp.sin, lowering_platform=platform)(a) + exp_f = export.export(jnp.sin, lowering_platform=platform)(a) if xb.canonicalize_platform(jtu.device_under_test()) == platform: raise unittest.SkipTest("Uninteresting scenario") with self.assertRaisesRegex( ValueError, "The exported function .* was lowered for platform"): - jax_export.call_exported(exp_f)(a) + export.call_exported(exp_f)(a) # Now try with the platform check disabled - exp_f_no_platform_check = jax_export.export( + exp_f_no_platform_check = export.export( jnp.sin, lowering_platform=platform, - disabled_checks=[jax_export.DisabledSafetyCheck.platform()])(a) - res = jax_export.call_exported(exp_f_no_platform_check)(a) + disabled_checks=[export.DisabledSafetyCheck.platform()])(a) + res = 
export.call_exported(exp_f_no_platform_check)(a) self.assertAllClose(res, jnp.sin(a)) @jtu.parameterized_filterable( @@ -230,23 +230,23 @@ class JaxExportTest(jtu.JaxTestCase): a = np.arange(3, dtype=np.float32) with self.assertRaisesRegex(ValueError, "Cannot serialize code with custom calls whose targets .*"): - jax_export.export( + export.export( lambda a: a + test_primitive.bind(a) )(a) # Now try again with the safety check disabled - exp = jax_export.export( + exp = export.export( lambda a: a + test_primitive.bind(a), - disabled_checks=[jax_export.DisabledSafetyCheck.custom_call("disallowed_call_target")] + disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")] )(a) self.assertIn("disallowed_call_target", exp.mlir_module()) def test_grad(self): f = lambda x: jnp.sum(jnp.sin(x)) x = np.arange(4, dtype=np.float32) - exp_f = jax_export.export(f)(x) + exp_f = export.export(f)(x) - f1 = jax_export.call_exported(exp_f) + f1 = export.call_exported(exp_f) self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) def test_pytree_vjp(self): @@ -256,14 +256,14 @@ class JaxExportTest(jtu.JaxTestCase): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) - exp_f = jax_export.export(f)((a, b), a=a, b=b) + exp_f = export.export(f)((a, b), a=a, b=b) out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent def f1_jax(a, b): # For VJP, make a function without kwargs res = f((a, b), a=a, b=b) return res def f1_exp(a, b): # For VJP, make a function without kwargs - res = jax_export.call_exported(exp_f)((a, b), a=a, b=b) + res = export.call_exported(exp_f)((a, b), a=a, b=b) return res jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) @@ -273,33 +273,33 @@ class JaxExportTest(jtu.JaxTestCase): def f1(x): return jnp.sin(x) a = np.arange(4, dtype=np.float32) - exp_f1 = jax_export.export(f1)(a) + exp_f1 = export.export(f1)(a) def f2(x): - res1 = jax_export.call_exported(exp_f1)(x) - res2 = jax_export.call_exported(exp_f1)(res1) + res1 = export.call_exported(exp_f1)(x) + res2 = export.call_exported(exp_f1)(res1) return jnp.cos(res2) - exp_f2 = jax_export.export(f2)(a) + exp_f2 = export.export(f2)(a) self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), - jax_export.call_exported(exp_f2)(a)) + export.call_exported(exp_f2)(a)) @jtu.parameterized_filterable( #one_containing="", kwargs=[ dict(v=v) - for v in range(jax_export.minimum_supported_serialization_version - 1, - jax_export.maximum_supported_serialization_version + 2)]) + for v in range(export.minimum_supported_serialization_version - 1, + export.maximum_supported_serialization_version + 2)]) def test_shape_poly_basic_versions(self, v: int): self.override_serialization_version(v) with contextlib.ExitStack() as e: - if not (jax_export.minimum_supported_serialization_version <= v - <= jax_export.maximum_supported_serialization_version): + if not (export.minimum_supported_serialization_version <= v + <= export.maximum_supported_serialization_version): e.enter_context(self.assertRaisesRegex( ValueError, f"The requested jax_serialization version {v} is outside the range of supported versions")) - exp = jax_export.export(jnp.sin)( - jax_export.poly_spec((3, 4), np.float32, "w, h")) + exp = export.export(jnp.sin)( + export.poly_spec((3, 4), np.float32, "w, h")) # Peek at the module module_str = exp.mlir_module() self.assertEqual(config.jax_serialization_version >= 7, @@ -307,11 +307,11 @@ class JaxExportTest(jtu.JaxTestCase): self.assertIn("jax.uses_shape_polymorphism 
= true", module_str) x = np.arange(30, dtype=np.float32).reshape((5, 6)) - res = jax_export.call_exported(exp)(x) + res = export.call_exported(exp)(x) self.assertAllClose(res, np.sin(x)) # A function is exported with f32[poly_spec] and is called with different arg - # shapes. We use jax_export.call_exported and we also run the shape check + # shapes. We use export.call_exported and we also run the shape check # module. @jtu.parameterized_filterable( testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore @@ -348,8 +348,8 @@ class JaxExportTest(jtu.JaxTestCase): return jnp.reshape(x, (-1, x.shape[1])) disabled_checks = () - exp_f = jax_export.export(f, disabled_checks=disabled_checks)( - jax_export.poly_spec((3, 4, 12), np.float32, poly_spec)) + exp_f = export.export(f, disabled_checks=disabled_checks)( + export.poly_spec((3, 4, 12), np.float32, poly_spec)) self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12] @@ -359,7 +359,7 @@ class JaxExportTest(jtu.JaxTestCase): stack.push(self.assertRaisesRegex(Exception, expect_error)) assert core.is_constant_shape(arg.shape) - res = jax_export.call_exported(exp_f)(arg) + res = export.call_exported(exp_f)(arg) if not expect_error: self.assertAllClose(res, f(arg)) @@ -450,23 +450,23 @@ class JaxExportTest(jtu.JaxTestCase): arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = jax_export.export(inner)( - jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec)) + inner_exp = export.export(inner)( + export.poly_spec((3, 4, 12), np.float32, inner_poly_spec)) self.assertEqual(inner_exp.uses_shape_polymorphism, (inner_poly_spec != "3,4,12")) def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the # result of the call_exported. 
- return jax_export.call_exported(inner_exp)(x) + inner(x) + return export.call_exported(inner_exp)(x) + inner(x) with contextlib.ExitStack() as stack: if expect_error_outer_exp is not None: stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = jax_export.export(outer)( - jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec)) + outer_exp = export.export(outer)( + export.poly_spec(arg.shape, arg.dtype, outer_poly_spec)) if expect_error_outer_exp is not None: return @@ -478,7 +478,7 @@ class JaxExportTest(jtu.JaxTestCase): if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = jax_export.call_exported(outer_exp)(arg) + res = export.call_exported(outer_exp)(arg) if expect_error_run is not None: return @@ -543,9 +543,9 @@ class JaxExportTest(jtu.JaxTestCase): with contextlib.ExitStack() as stack: if expect_error is not None: stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) - exp = jax_export.export(f_jax)( - jax_export.poly_spec(x.shape, x.dtype, poly_spec)) - jax_export.call_exported(exp)(x) + exp = export.export(f_jax)( + export.poly_spec(x.shape, x.dtype, poly_spec)) + export.call_exported(exp)(x) def test_multi_platform(self): if jtu.device_under_test() == "gpu": @@ -553,10 +553,10 @@ class JaxExportTest(jtu.JaxTestCase): raise unittest.SkipTest("Not intended for running on GPU") x = np.arange(5, dtype=np.float32) # TODO: use a function with different behavior for different platforms - exp = jax_export.export(jnp.sin, + exp = export.export(jnp.sin, lowering_platforms=('cpu', 'tpu'))(x) self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu')) - res = jax_export.call_exported(exp)(x) + res = export.call_exported(exp)(x) self.assertAllClose(res, np.sin(x)) def test_multi_platform_nested(self): @@ -565,7 +565,7 @@ class JaxExportTest(jtu.JaxTestCase): raise unittest.SkipTest("Not intended for running on TPU") x = np.arange(5, dtype=np.float32) # TODO: use a function with different behavior for different platforms - exp = jax_export.export(jnp.sin, + exp = export.export(jnp.sin, lowering_platforms=('cpu', 'tpu', 'cuda'))(x) self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda')) @@ -573,9 +573,9 @@ class JaxExportTest(jtu.JaxTestCase): # lowering platforms, but included in the lowering platforms for the # nested exported. 
# TODO: improve this test once we implement true multi-platform lowering - exp2 = jax_export.export(jax_export.call_exported(exp), + exp2 = export.export(export.call_exported(exp), lowering_platforms=('cpu', 'cuda'))(x) - res2 = jax_export.call_exported(exp2)(x) + res2 = export.call_exported(exp2)(x) self.assertAllClose(res2, np.sin(x)) def test_multi_platform_and_poly(self): @@ -583,16 +583,16 @@ class JaxExportTest(jtu.JaxTestCase): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") # TODO: use a function with different behavior for different platforms - exp = jax_export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)), + exp = export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)), lowering_platforms=('cpu', 'tpu'))( - jax_export.poly_spec((5, 6), np.float32, "b1, b2") + export.poly_spec((5, 6), np.float32, "b1, b2") ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) - res = jax_export.call_exported(exp)(x) + res = export.call_exported(exp)(x) self.assertAllClose(res, np.sin(x).reshape((-1,))) # Now serialize the call to the exported - exp2 = jax_export.export(jax_export.call_exported(exp))(x) - res2 = jax_export.call_exported(exp2)(x) + exp2 = export.export(export.call_exported(exp))(x) + res2 = export.call_exported(exp2)(x) self.assertAllClose(res2, np.sin(x).reshape((-1,)))