mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #21828 from gnecula:exp_calling_convention
PiperOrigin-RevId: 642977662
This commit is contained in:
commit
a9edaeb38e
@ -220,7 +220,7 @@ present on the exporting machine:
|
||||
|
||||
```python
|
||||
>>> from jax import export
|
||||
>>> export.default_lowering_platform()
|
||||
>>> export.default_export_platform()
|
||||
'cpu'
|
||||
|
||||
```
|
||||
@ -242,15 +242,15 @@ on multiple platforms.
|
||||
>>> from jax import export
|
||||
>>> from jax import lax
|
||||
|
||||
>>> # You can specify the lowering_platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
|
||||
>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
|
||||
>>> # even if the current machine does not have that accelerator.
|
||||
>>> exp = export.export(jax.jit(lax.cos), lowering_platforms=['tpu'])(1.)
|
||||
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)
|
||||
|
||||
>>> # But you will get an error if you try to compile `exp`
|
||||
>>> # on a machine that does not have TPUs.
|
||||
>>> exp.call(1.) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
ValueError: The exported function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
|
||||
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
|
||||
|
||||
>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
|
||||
>>> # parameter to `export`, e.g., because you have reasons to believe
|
||||
@ -449,52 +449,52 @@ on multiple devices, and the compiler will shard the function appropriately:
|
||||
|
||||
```
|
||||
|
||||
## Module serialization versions
|
||||
## Calling convention versions
|
||||
|
||||
The JAX export support has evolved over time, e.g., to support
|
||||
effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees))
|
||||
we maintain a serialization version for each `Exported`.
|
||||
As of June 2024, all modules are serialized with version 9
|
||||
(the latest, see [all serialization versions](#serialization-version-numbers)):
|
||||
we maintain a calling convention version for each `Exported`.
|
||||
As of June 2024, all functions are exported with version 9
|
||||
(the latest, see [all calling convention versions](#calling-convention-versions)):
|
||||
|
||||
```python
|
||||
>>> from jax import export
|
||||
>>> exp: export.Exported = export.export(jnp.cos)(1.)
|
||||
>>> exp.mlir_module_serialization_version
|
||||
>>> exp.calling_convention_version
|
||||
9
|
||||
|
||||
```
|
||||
|
||||
At any given time, the export APIs may support a range
|
||||
of serialization versions. You can control which serialization
|
||||
version to use using the `--jax-serialization-version` flag
|
||||
or the `JAX_SERIALIZATION_VERSION` environment variable:
|
||||
of calling convention versions. You can control which calling convention
|
||||
version to use using the `--jax-export-calling-convention-version` flag
|
||||
or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable:
|
||||
|
||||
```python
|
||||
>>> from jax import export
|
||||
>>> (export.minimum_supported_serialization_version, export.maximum_supported_serialization_version)
|
||||
>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version)
|
||||
(9, 9)
|
||||
|
||||
>>> from jax._src import config
|
||||
>>> with config.jax_serialization_version(9):
|
||||
>>> with config.jax_export_calling_convention_version(9):
|
||||
... exp = export.export(jnp.cos)(1.)
|
||||
... exp.mlir_module_serialization_version
|
||||
... exp.calling_convention_version
|
||||
9
|
||||
|
||||
```
|
||||
|
||||
We reserve the right to remove support for
|
||||
generating or consuming serialization versions older than 6 months.
|
||||
generating or consuming calling convention versions older than 6 months.
|
||||
|
||||
### Module calling convention
|
||||
|
||||
The `Exported.mlir_module` has a `main` function that takes an optional first
|
||||
platform index argument if the module supports multiple platforms
|
||||
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
|
||||
(`len(platforms) > 1`), followed by the token arguments corresponding
|
||||
to the ordered effects, followed by the kept array
|
||||
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
|
||||
The platform index is an i32 or i64 scalar encoding the index of the current
|
||||
compilation platform into the `lowering_platforms` sequence.
|
||||
compilation platform into the `platforms` sequence.
|
||||
|
||||
Inner functions use a different calling convention: an optional
|
||||
platform index argument, optional dimension variable arguments
|
||||
@ -542,15 +542,15 @@ The signature of the `_wrapped_jax_export_main` is:
|
||||
arg: f32[?, ?]) -> (stablehlo.token, ...)
|
||||
```
|
||||
|
||||
Prior to serialization version 9 the calling convention for effects was
|
||||
Prior to calling convention version 9 the calling convention for effects was
|
||||
different: the `main` function does not take or return a token. Instead
|
||||
the function creates dummy tokens of type `i1[0]` and passes them to the
|
||||
`_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
|
||||
takes dummy tokens of type `i1[0]` and will create internally real
|
||||
tokens to pass to the inner functions. The inner functions use real
|
||||
tokens (both before and after serialization version 9)
|
||||
tokens (both before and after calling convention version 9)
|
||||
|
||||
Also starting with serialization version 9, function arguments that contain
|
||||
Also starting with calling convention version 9, function arguments that contain
|
||||
the platform index or the dimension variable values have a
|
||||
`jax.global_constant` string attribute whose value is the name of the
|
||||
global constant, either `_platform_index` or a dimension variable name.
|
||||
@ -584,11 +584,11 @@ scalar operands corresponding to the format specifiers.
|
||||
error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
|
||||
```
|
||||
|
||||
(export-serialization-version)=
|
||||
(export-calling-convention-version)=
|
||||
|
||||
### Serialization version numbers
|
||||
### Calling convention versions
|
||||
|
||||
We list here a history of the serialization version numbers:
|
||||
We list here a history of the calling convention version numbers:
|
||||
|
||||
* Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
|
||||
* Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
|
||||
@ -622,7 +622,7 @@ We list here a history of the serialization version numbers:
|
||||
and the default since October 21st, 2023 (JAX 0.4.20).
|
||||
* Version 9 adds support for effects.
|
||||
See the docstring for `export.Exported` for the precise calling convention.
|
||||
In this serialization version we also tag the platform index and the
|
||||
In this calling convention version we also tag the platform index and the
|
||||
dimension variables arguments with `jax.global_constant` attributes.
|
||||
Supported by XlaCallModule since October 27th, 2023,
|
||||
available in JAX since October 20th, 2023 (JAX 0.4.20),
|
||||
|
@ -20,9 +20,9 @@ Functions
|
||||
|
||||
export
|
||||
deserialize
|
||||
minimum_supported_serialization_version
|
||||
maximum_supported_serialization_version
|
||||
default_lowering_platform
|
||||
minimum_supported_calling_convention_version
|
||||
maximum_supported_calling_convention_version
|
||||
default_export_platform
|
||||
|
||||
Functions related to shape polymorphism
|
||||
---------------------------------------
|
||||
@ -40,8 +40,8 @@ Constants
|
||||
|
||||
.. data:: jax.export.minimum_supported_serialization_version
|
||||
|
||||
The minimum supported serialization version; see :ref:`export-serialization-version`.
|
||||
The minimum supported serialization version; see :ref:`export-calling-convention-version`.
|
||||
|
||||
.. data:: jax.export.maximum_supported_serialization_version
|
||||
|
||||
The maximum supported serialization version; see :ref:`export-serialization-version`.
|
||||
The maximum supported serialization version; see :ref:`export-calling-convention-version`.
|
||||
|
@ -912,13 +912,21 @@ jax2tf_default_native_serialization = define_bool_state(
|
||||
|
||||
jax_serialization_version = define_int_state(
|
||||
name='jax_serialization_version',
|
||||
# Note: bump the default serialization version at least one month after
|
||||
default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default.
|
||||
help=(
|
||||
'DEPRECATED: use jax_export_calling_convention_version.'
|
||||
)
|
||||
)
|
||||
|
||||
jax_export_calling_convention_version = define_int_state(
|
||||
name='jax_export_calling_convention_version',
|
||||
# Note: bump the default calling convention version at least one month after
|
||||
# we update XlaCallModule to support the new version, so that serialized
|
||||
# modules are forward compatible with deployed versions of XlaCallModule.
|
||||
# Version 9 of XlaCallModule is supported since October 27th, 2023.
|
||||
default=int_env('JAX_SERIALIZATION_VERSION', 9),
|
||||
default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9),
|
||||
help=(
|
||||
'The version number to use for native serialization. This must be '
|
||||
'The calling convention version number to use for exporting. This must be '
|
||||
'within the range of versions supported by the tf.XlaCallModule '
|
||||
'used in your deployment environment. '
|
||||
'See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.'
|
||||
|
@ -63,17 +63,10 @@ Shape = jax._src.core.Shape
|
||||
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
|
||||
HloSharding = xla_client.HloSharding
|
||||
|
||||
"""The minimum supported serialization version.
|
||||
|
||||
See https://jax.readthedocs.io/en/latest/export/export.html#serialization-version-numbers
|
||||
"""
|
||||
minimum_supported_serialization_version = 9
|
||||
|
||||
"""The maximum supported serialization version.
|
||||
|
||||
See https://jax.readthedocs.io/en/latest/export/export.html#serialization-version-numbers
|
||||
"""
|
||||
maximum_supported_serialization_version = 9
|
||||
# The minimum and maximum supported calling convention version.
|
||||
# See https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions
|
||||
minimum_supported_calling_convention_version = 9
|
||||
maximum_supported_calling_convention_version = 9
|
||||
|
||||
|
||||
class DisabledSafetyCheck:
|
||||
@ -168,23 +161,25 @@ class Exported:
|
||||
the mesh. See `out_shardings_jax` for a way to turn these
|
||||
into sharding specification that can be used with JAX APIs.
|
||||
nr_devices: the number of devices that the module has been lowered for.
|
||||
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
|
||||
platforms: a tuple containing at least one of 'tpu', 'cpu',
|
||||
'cuda', 'rocm'. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
|
||||
for the calling convention for when
|
||||
there are multiple lowering platforms.
|
||||
there are multiple export platforms.
|
||||
ordered_effects: the ordered effects present in the serialized module.
|
||||
This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
|
||||
for the calling convention in presence of ordered effects.
|
||||
unordered_effects: the unordered effects present in the serialized module.
|
||||
This is present from serialization version 9.
|
||||
mlir_module_serialized: the serialized lowered VHLO module.
|
||||
mlir_module_serialization_version: a version number for the serialized module.
|
||||
See more versioning details at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#module-serialization-versions.
|
||||
calling_convention_version: a version number for the calling
|
||||
convention of the exported module.
|
||||
See more versioning details at https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions.
|
||||
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
|
||||
must be passed to the module. The other arguments have been dropped
|
||||
because they are not used.
|
||||
uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
|
||||
polymorphism. This may be because `in_avals` contains dimension
|
||||
uses_global_constants: whether the `mlir_module_serialized` uses shape
|
||||
polymorphism or multi-platform export.
|
||||
This may be because `in_avals` contains dimension
|
||||
variables, or due to inner calls of Exported modules that have
|
||||
dimension variables or platform index arguments. Such modules need
|
||||
shape refinement before XLA compilation.
|
||||
@ -197,7 +192,7 @@ class Exported:
|
||||
for each primal output. It returns a tuple with the cotangents
|
||||
corresponding to the flattened primal inputs.
|
||||
|
||||
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module_calling_convention).
|
||||
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention).
|
||||
"""
|
||||
fun_name: str
|
||||
in_tree: tree_util.PyTreeDef
|
||||
@ -208,15 +203,15 @@ class Exported:
|
||||
in_shardings_hlo: tuple[HloSharding | None, ...]
|
||||
out_shardings_hlo: tuple[HloSharding | None, ...]
|
||||
nr_devices: int
|
||||
lowering_platforms: tuple[str, ...]
|
||||
platforms: tuple[str, ...]
|
||||
ordered_effects: tuple[effects.Effect, ...]
|
||||
unordered_effects: tuple[effects.Effect, ...]
|
||||
disabled_safety_checks: Sequence[DisabledSafetyCheck]
|
||||
|
||||
mlir_module_serialized: bytes
|
||||
mlir_module_serialization_version: int
|
||||
calling_convention_version: int
|
||||
module_kept_var_idx: tuple[int, ...]
|
||||
uses_shape_polymorphism: bool
|
||||
uses_global_constants: bool
|
||||
|
||||
_get_vjp: Callable[[Exported], Exported] | None
|
||||
|
||||
@ -286,6 +281,33 @@ class Exported:
|
||||
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
|
||||
for s in self.out_shardings_hlo)
|
||||
|
||||
# For backwards compatibility
|
||||
# TODO(necula): remove after September 2024.
|
||||
@property
|
||||
def lowering_platforms(self):
|
||||
"""DEPRECATED."""
|
||||
warnings.warn("lowering_platform is deprecated. Use .platforms instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.platforms
|
||||
|
||||
# For backwards compatibility
|
||||
# TODO(necula): remove after September 2024.
|
||||
@property
|
||||
def mlir_module_serialization_version(self):
|
||||
"""DEPRECATED."""
|
||||
warnings.warn("mlir_module_serialization_version is deprecated. Use .calling_convention_version instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.calling_convention_version
|
||||
|
||||
# For backwards compatibility
|
||||
# TODO(necula): remove after September 2024.
|
||||
@property
|
||||
def uses_shape_polymorphism(self):
|
||||
"""DEPRECATED."""
|
||||
warnings.warn("uses_shape_polymorphism is deprecated. Use .uses_global_constants instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.uses_global_constants
|
||||
|
||||
def has_vjp(self) -> bool:
|
||||
"""Returns if this Exported supports VJP."""
|
||||
return self._get_vjp is not None
|
||||
@ -331,14 +353,16 @@ def deserialize(blob: bytearray) -> Exported:
|
||||
return deserialize(blob)
|
||||
|
||||
|
||||
def default_lowering_platform() -> str:
|
||||
"""Retrieves the default lowering platform.
|
||||
def default_export_platform() -> str:
|
||||
"""Retrieves the default export platform.
|
||||
|
||||
One of: `tpu`, `cpu`, `cuda`, `rocm`.
|
||||
"""
|
||||
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
|
||||
return xb.canonicalize_platform(jax.default_backend())
|
||||
|
||||
default_lowering_platform = default_export_platform
|
||||
|
||||
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
|
||||
"""Returns the shape and dtype of a jax.Array or a jax.ShapeDtypeStruct."""
|
||||
if isinstance(a, jax.ShapeDtypeStruct):
|
||||
@ -415,7 +439,7 @@ def export_back_compat(
|
||||
if lowering_platforms is not None:
|
||||
actual_lowering_platforms = tuple(lowering_platforms)
|
||||
else:
|
||||
actual_lowering_platforms = (default_lowering_platform(),)
|
||||
actual_lowering_platforms = (default_export_platform(),)
|
||||
|
||||
# TODO: move to `lower`
|
||||
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore]
|
||||
@ -459,6 +483,7 @@ def export_back_compat(
|
||||
def export(
|
||||
fun_jit: stages.Wrapped,
|
||||
*,
|
||||
platforms: Sequence[str] | None = None,
|
||||
lowering_platforms: Sequence[str] | None = None,
|
||||
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
||||
) -> Callable[..., Exported]:
|
||||
@ -466,13 +491,14 @@ def export(
|
||||
|
||||
Args:
|
||||
fun_jit: the function to export. Should be the result of `jax.jit`.
|
||||
lowering_platforms:
|
||||
platforms:
|
||||
Optional sequence containing a subset of 'tpu', 'cpu',
|
||||
'cuda', 'rocm'. If more than one platform is specified, then
|
||||
the lowered code takes an argument specifying the platform.
|
||||
the exported code takes an argument specifying the platform.
|
||||
If None, then use the default JAX backend.
|
||||
The calling convention for multiple platforms is explained at
|
||||
https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
|
||||
lowering_platforms: DEPRECATED, use `platforms`.
|
||||
disabled_checks: the safety checks to disable. See documentation for
|
||||
`jax.export.DisabledSafetyCheck`.
|
||||
|
||||
@ -501,10 +527,14 @@ def export(
|
||||
if not isinstance(fun_jit, stages.Wrapped):
|
||||
raise ValueError(
|
||||
f"Function to be exported must be the result of `jit` but is: {fun_jit}")
|
||||
if lowering_platforms is not None:
|
||||
actual_lowering_platforms = tuple(lowering_platforms)
|
||||
if platforms is not None and lowering_platforms is not None:
|
||||
raise ValueError("Cannot use both `platforms` and `lowering_platforms`")
|
||||
if platforms is None and lowering_platforms is not None:
|
||||
platforms = lowering_platforms
|
||||
if platforms is not None:
|
||||
actual_lowering_platforms = tuple(platforms)
|
||||
else:
|
||||
actual_lowering_platforms = (default_lowering_platform(),)
|
||||
actual_lowering_platforms = (default_export_platform(),)
|
||||
|
||||
def do_export(*args_specs, **kwargs_specs) -> Exported:
|
||||
# TODO: move to `lower`
|
||||
@ -542,13 +572,13 @@ def _export_lowered(
|
||||
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
||||
_device_assignment_for_internal_jax2tf_use_only = None,
|
||||
) -> Exported:
|
||||
version = config.jax_serialization_version.value
|
||||
if (version < minimum_supported_serialization_version or
|
||||
version > maximum_supported_serialization_version):
|
||||
version = config.jax_export_calling_convention_version.value
|
||||
if (version < minimum_supported_calling_convention_version or
|
||||
version > maximum_supported_calling_convention_version):
|
||||
raise ValueError(
|
||||
f"The requested jax_serialization version {version} is outside the "
|
||||
f"range of supported versions [{minimum_supported_serialization_version}"
|
||||
f"..{maximum_supported_serialization_version}]")
|
||||
f"The requested export calling convention version {version} is outside the "
|
||||
f"range of supported versions [{minimum_supported_calling_convention_version}"
|
||||
f"..{maximum_supported_calling_convention_version}]")
|
||||
|
||||
lowering = lowered._lowering
|
||||
_check_lowering(lowering)
|
||||
@ -638,7 +668,7 @@ def _export_lowered(
|
||||
apply_jit=True,
|
||||
flat_primal_fun=True)
|
||||
return export(fun_vjp_jax, # type: ignore[arg-type]
|
||||
lowering_platforms=exp_primal.lowering_platforms,
|
||||
platforms=exp_primal.platforms,
|
||||
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
|
||||
|
||||
return Exported(
|
||||
@ -650,14 +680,14 @@ def _export_lowered(
|
||||
in_shardings_hlo=in_shardings,
|
||||
out_shardings_hlo=out_shardings,
|
||||
nr_devices=nr_devices,
|
||||
lowering_platforms=lowering._platforms, # type: ignore
|
||||
platforms=lowering._platforms, # type: ignore
|
||||
ordered_effects=ordered_effects,
|
||||
unordered_effects=unordered_effects,
|
||||
disabled_safety_checks=tuple(disabled_checks),
|
||||
mlir_module_serialized=mlir_module_serialized,
|
||||
module_kept_var_idx=module_kept_var_idx,
|
||||
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
|
||||
mlir_module_serialization_version=version,
|
||||
uses_global_constants=shape_poly_state.uses_dim_vars,
|
||||
calling_convention_version=version,
|
||||
_get_vjp=_get_exported_vjp)
|
||||
|
||||
def _module_to_bytecode(module: ir.Module) -> bytes:
|
||||
@ -1237,7 +1267,7 @@ call_exported_p.def_impl(_call_exported_impl)
|
||||
|
||||
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
exported: Exported):
|
||||
if exported.uses_shape_polymorphism:
|
||||
if exported.uses_global_constants:
|
||||
ctx.module_context.shape_poly_state.uses_dim_vars = True
|
||||
submodule = ir.Module.parse(exported.mlir_module())
|
||||
|
||||
@ -1255,14 +1285,14 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
# than the function was exported for.
|
||||
err_msg = ""
|
||||
if exported.nr_devices != 1:
|
||||
err_msg = "the module was lowered for more than 1 device."
|
||||
err_msg = "the function was exported for more than 1 device."
|
||||
elif (_check_module(submodule, disabled_checks=()) or
|
||||
any(s is not None and not s.is_replicated()
|
||||
for s in exported.in_shardings_hlo + exported.out_shardings_hlo)):
|
||||
err_msg = "the module contains non-replicated sharding annotations."
|
||||
err_msg = "the function contains non-replicated sharding annotations."
|
||||
if err_msg:
|
||||
raise ValueError(
|
||||
f"Exported module {exported.fun_name} was lowered for "
|
||||
f"Function {exported.fun_name} was exported for "
|
||||
f"{exported.nr_devices} devices and is called in a context with "
|
||||
f"{num_devices} devices. This is disallowed because: {err_msg}"
|
||||
)
|
||||
@ -1296,18 +1326,18 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
|
||||
callee_lowering_platform_index: list[int] = []
|
||||
for platform in lowering_platforms:
|
||||
if platform in exported.lowering_platforms:
|
||||
if platform in exported.platforms:
|
||||
callee_lowering_platform_index.append(
|
||||
exported.lowering_platforms.index(platform))
|
||||
exported.platforms.index(platform))
|
||||
elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
|
||||
callee_lowering_platform_index.append(0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The exported function '{exported.fun_name}' was lowered for "
|
||||
f"platforms '{exported.lowering_platforms}' but it is used "
|
||||
f"Function '{exported.fun_name}' was exported for "
|
||||
f"platforms '{exported.platforms}' but it is used "
|
||||
f"on '{lowering_platforms}'.")
|
||||
|
||||
if len(exported.lowering_platforms) > 1:
|
||||
if len(exported.platforms) > 1:
|
||||
# The exported module takes a platform index argument
|
||||
if len(lowering_platforms) > 1:
|
||||
current_platform_idx = ctx.dim_var_values[0]
|
||||
|
@ -119,16 +119,16 @@ table Exported {
|
||||
in_shardings: [Sharding];
|
||||
out_shardings: [Sharding];
|
||||
|
||||
lowering_platforms: [string];
|
||||
platforms: [string];
|
||||
|
||||
ordered_effects: [Effect];
|
||||
unordered_effects: [Effect];
|
||||
disabled_checks: [DisabledSafetyCheck];
|
||||
|
||||
mlir_module_serialized: [byte];
|
||||
mlir_module_serialization_version: uint16;
|
||||
calling_convention_version: uint16;
|
||||
module_kept_var_idx: [uint16];
|
||||
uses_shape_polymorphism: bool;
|
||||
uses_global_constants: bool;
|
||||
|
||||
vjp: Exported;
|
||||
}
|
||||
|
@ -93,8 +93,8 @@ def _serialize_exported(
|
||||
disabled_safety_checks = _serialize_array(
|
||||
builder, _serialize_disabled_safety_check, exp.disabled_safety_checks
|
||||
)
|
||||
lowering_platforms = _serialize_array(
|
||||
builder, lambda b, p: b.CreateString(p), exp.lowering_platforms
|
||||
platforms = _serialize_array(
|
||||
builder, lambda b, p: b.CreateString(p), exp.platforms
|
||||
)
|
||||
mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized)
|
||||
module_kept_var_idx = builder.CreateNumpyVector(
|
||||
@ -121,17 +121,17 @@ def _serialize_exported(
|
||||
ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices)
|
||||
ser_flatbuf.ExportedAddInShardings(builder, in_shardings)
|
||||
ser_flatbuf.ExportedAddOutShardings(builder, out_shardings)
|
||||
ser_flatbuf.ExportedAddLoweringPlatforms(builder, lowering_platforms)
|
||||
ser_flatbuf.ExportedAddPlatforms(builder, platforms)
|
||||
ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects)
|
||||
ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects)
|
||||
ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks)
|
||||
ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized)
|
||||
ser_flatbuf.ExportedAddMlirModuleSerializationVersion(
|
||||
builder, exp.mlir_module_serialization_version
|
||||
ser_flatbuf.ExportedAddCallingConventionVersion(
|
||||
builder, exp.calling_convention_version
|
||||
)
|
||||
ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx)
|
||||
ser_flatbuf.ExportedAddUsesShapePolymorphism(
|
||||
builder, exp.uses_shape_polymorphism
|
||||
ser_flatbuf.ExportedAddUsesGlobalConstants(
|
||||
builder, exp.uses_global_constants
|
||||
)
|
||||
if vjp is not None:
|
||||
ser_flatbuf.ExportedAddVjp(builder, vjp)
|
||||
@ -179,9 +179,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
|
||||
out_shardings = _deserialize_tuple(
|
||||
exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding
|
||||
)
|
||||
lowering_platforms = _deserialize_tuple(
|
||||
exp.LoweringPlatformsLength,
|
||||
exp.LoweringPlatforms,
|
||||
platforms = _deserialize_tuple(
|
||||
exp.PlatformsLength,
|
||||
exp.Platforms,
|
||||
lambda v: v.decode("utf-8"),
|
||||
)
|
||||
ordered_effects = _deserialize_tuple(
|
||||
@ -197,9 +197,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
|
||||
)
|
||||
|
||||
mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes()
|
||||
mlir_module_serialization_version = exp.MlirModuleSerializationVersion()
|
||||
calling_convention_version = exp.CallingConventionVersion()
|
||||
module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist())
|
||||
uses_shape_polymorphism = exp.UsesShapePolymorphism()
|
||||
uses_global_constants = exp.UsesGlobalConstants()
|
||||
|
||||
_get_vjp = None
|
||||
if vjp := exp.Vjp():
|
||||
@ -214,14 +214,14 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
|
||||
nr_devices=nr_devices,
|
||||
in_shardings_hlo=in_shardings,
|
||||
out_shardings_hlo=out_shardings,
|
||||
lowering_platforms=lowering_platforms,
|
||||
platforms=platforms,
|
||||
ordered_effects=ordered_effects,
|
||||
unordered_effects=unordered_effects,
|
||||
disabled_safety_checks=disabled_safety_checks,
|
||||
mlir_module_serialized=mlir_module_serialized,
|
||||
mlir_module_serialization_version=mlir_module_serialization_version,
|
||||
calling_convention_version=calling_convention_version,
|
||||
module_kept_var_idx=module_kept_var_idx,
|
||||
uses_shape_polymorphism=uses_shape_polymorphism,
|
||||
uses_global_constants=uses_global_constants,
|
||||
_get_vjp=_get_vjp,
|
||||
)
|
||||
|
||||
|
@ -547,7 +547,7 @@ class Exported(object):
|
||||
return o == 0
|
||||
|
||||
# Exported
|
||||
def LoweringPlatforms(self, j):
|
||||
def Platforms(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
@ -555,14 +555,14 @@ class Exported(object):
|
||||
return ""
|
||||
|
||||
# Exported
|
||||
def LoweringPlatformsLength(self):
|
||||
def PlatformsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Exported
|
||||
def LoweringPlatformsIsNone(self):
|
||||
def PlatformsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
return o == 0
|
||||
|
||||
@ -666,7 +666,7 @@ class Exported(object):
|
||||
return o == 0
|
||||
|
||||
# Exported
|
||||
def MlirModuleSerializationVersion(self):
|
||||
def CallingConventionVersion(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(32))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
|
||||
@ -700,7 +700,7 @@ class Exported(object):
|
||||
return o == 0
|
||||
|
||||
# Exported
|
||||
def UsesShapePolymorphism(self):
|
||||
def UsesGlobalConstants(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36))
|
||||
if o != 0:
|
||||
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
||||
@ -758,10 +758,10 @@ def ExportedAddOutShardings(builder, outShardings):
|
||||
def ExportedStartOutShardingsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def ExportedAddLoweringPlatforms(builder, loweringPlatforms):
|
||||
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(loweringPlatforms), 0)
|
||||
def ExportedAddPlatforms(builder, platforms):
|
||||
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(platforms), 0)
|
||||
|
||||
def ExportedStartLoweringPlatformsVector(builder, numElems):
|
||||
def ExportedStartPlatformsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def ExportedAddOrderedEffects(builder, orderedEffects):
|
||||
@ -788,8 +788,8 @@ def ExportedAddMlirModuleSerialized(builder, mlirModuleSerialized):
|
||||
def ExportedStartMlirModuleSerializedVector(builder, numElems):
|
||||
return builder.StartVector(1, numElems, 1)
|
||||
|
||||
def ExportedAddMlirModuleSerializationVersion(builder, mlirModuleSerializationVersion):
|
||||
builder.PrependUint16Slot(14, mlirModuleSerializationVersion, 0)
|
||||
def ExportedAddCallingConventionVersion(builder, callingConventionVersion):
|
||||
builder.PrependUint16Slot(14, callingConventionVersion, 0)
|
||||
|
||||
def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx):
|
||||
builder.PrependUOffsetTRelativeSlot(15, flatbuffers.number_types.UOffsetTFlags.py_type(moduleKeptVarIdx), 0)
|
||||
@ -797,8 +797,8 @@ def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx):
|
||||
def ExportedStartModuleKeptVarIdxVector(builder, numElems):
|
||||
return builder.StartVector(2, numElems, 2)
|
||||
|
||||
def ExportedAddUsesShapePolymorphism(builder, usesShapePolymorphism):
|
||||
builder.PrependBoolSlot(16, usesShapePolymorphism, 0)
|
||||
def ExportedAddUsesGlobalConstants(builder, usesGlobalConstants):
|
||||
builder.PrependBoolSlot(16, usesGlobalConstants, 0)
|
||||
|
||||
def ExportedAddVjp(builder, vjp):
|
||||
builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0)
|
||||
|
@ -1313,7 +1313,7 @@ dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval())
|
||||
def dim_as_value_impl(dim: DimSize):
|
||||
raise NotImplementedError(
|
||||
"Evaluation rule for 'dim_as_value' is not implemented. "
|
||||
"It seems that you are using shape polymorphism outside jax2tf.")
|
||||
"It seems that you are using shape polymorphism outside jax.export.")
|
||||
|
||||
dim_as_value_p.def_impl(dim_as_value_impl)
|
||||
def _dim_as_value(dim: DimSize):
|
||||
|
@ -303,7 +303,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
|
||||
module_str = str(exported.mlir_module())
|
||||
serialized = exported.mlir_module_serialized
|
||||
module_version = exported.mlir_module_serialization_version
|
||||
module_version = exported.calling_convention_version
|
||||
nr_devices = exported.nr_devices
|
||||
return serialized, module_str, module_version, nr_devices
|
||||
|
||||
@ -332,15 +332,15 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
out_avals=tuple(out_avals),
|
||||
in_shardings_hlo=(None,) * len(in_avals),
|
||||
out_shardings_hlo=(None,) * len(out_avals),
|
||||
lowering_platforms=(data.platform,),
|
||||
platforms=(data.platform,),
|
||||
ordered_effects=(),
|
||||
unordered_effects=(),
|
||||
disabled_safety_checks=(),
|
||||
mlir_module_serialized=data.mlir_module_serialized,
|
||||
mlir_module_serialization_version=data.xla_call_module_version,
|
||||
calling_convention_version=data.xla_call_module_version,
|
||||
nr_devices=data.nr_devices,
|
||||
module_kept_var_idx=tuple(range(len(in_avals))),
|
||||
uses_shape_polymorphism=any(not core.is_constant_shape(a.shape)
|
||||
uses_global_constants=any(not core.is_constant_shape(a.shape)
|
||||
for a in in_avals),
|
||||
_get_vjp=_get_vjp)
|
||||
|
||||
|
@ -14,13 +14,13 @@
|
||||
# ==============================================================================
|
||||
|
||||
from jax._src.export._export import (
|
||||
minimum_supported_serialization_version,
|
||||
maximum_supported_serialization_version,
|
||||
Exported,
|
||||
call_exported, # TODO: deprecate
|
||||
call,
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform, # TODO: deprecate
|
||||
minimum_supported_calling_convention_version,
|
||||
maximum_supported_calling_convention_version,
|
||||
Exported,
|
||||
call_exported, # TODO: deprecate
|
||||
call,
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform, # TODO: deprecate
|
||||
)
|
||||
from jax._src.export._export import export_back_compat as export
|
||||
|
||||
|
@ -1652,10 +1652,10 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
kwargs=[dict(version=version) for version in [9]]
|
||||
)
|
||||
def test_call_tf_graph_ordered(self, *, version: int):
|
||||
with config.jax_serialization_version(version):
|
||||
with config.jax_export_calling_convention_version(version):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
jax.config.jax_export_calling_convention_version)
|
||||
|
||||
@tf.function
|
||||
def tf_print(x):
|
||||
@ -1725,10 +1725,10 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
for version in [9]]
|
||||
)
|
||||
def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int):
|
||||
with config.jax_serialization_version(version):
|
||||
with config.jax_export_calling_convention_version(version):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
jax.config.jax_export_calling_convention_version)
|
||||
def f_jax(x1, x_dead, x3):
|
||||
return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
|
||||
call_tf_graph=True)(x3))
|
||||
@ -1750,10 +1750,10 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
]
|
||||
)
|
||||
def test_call_tf_graph_polymorphic(self, ordered: bool, version: int):
|
||||
with config.jax_serialization_version(version):
|
||||
with config.jax_export_calling_convention_version(version):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
jax.config.jax_export_calling_convention_version)
|
||||
|
||||
@tf.function(jit_compile=True, autograph=False)
|
||||
@partial(jax2tf.convert,
|
||||
|
@ -180,18 +180,18 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
# We run the tests using the maximum version supported, even though
|
||||
# the default serialization version may be held back for a while to
|
||||
# ensure compatibility
|
||||
version = config.jax_serialization_version.value
|
||||
version = config.jax_export_calling_convention_version.value
|
||||
if self.use_max_serialization_version:
|
||||
# Use the largest supported by both export and tfxla.call_module
|
||||
version = min(export.maximum_supported_serialization_version,
|
||||
version = min(export.maximum_supported_calling_convention_version,
|
||||
tfxla.call_module_maximum_supported_version())
|
||||
self.assertGreaterEqual(version,
|
||||
export.minimum_supported_serialization_version)
|
||||
self.enter_context(config.jax_serialization_version(version))
|
||||
export.minimum_supported_calling_convention_version)
|
||||
self.enter_context(config.jax_export_calling_convention_version(version))
|
||||
logging.info(
|
||||
"Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)",
|
||||
version,
|
||||
export.maximum_supported_serialization_version,
|
||||
export.maximum_supported_calling_convention_version,
|
||||
tfxla.call_module_maximum_supported_version())
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
|
@ -12,23 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
|
||||
"maximum_supported_serialization_version",
|
||||
"minimum_supported_serialization_version",
|
||||
"default_lowering_platform",
|
||||
"maximum_supported_calling_convention_version",
|
||||
"minimum_supported_calling_convention_version",
|
||||
"default_export_platform",
|
||||
"SymbolicScope", "is_symbolic_dim",
|
||||
"symbolic_shape", "symbolic_args_specs"]
|
||||
|
||||
from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck
|
||||
from jax._src.export._export import Exported as Exported
|
||||
from jax._src.export._export import export as export
|
||||
from jax._src.export._export import deserialize as deserialize
|
||||
from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version
|
||||
from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version
|
||||
from jax._src.export._export import default_lowering_platform as default_lowering_platform
|
||||
from jax._src.export._export import (
|
||||
DisabledSafetyCheck,
|
||||
Exported,
|
||||
export,
|
||||
deserialize,
|
||||
maximum_supported_calling_convention_version,
|
||||
minimum_supported_calling_convention_version,
|
||||
default_export_platform)
|
||||
|
||||
from jax._src.export import shape_poly_decision # Import only to set the decision procedure
|
||||
del shape_poly_decision
|
||||
from jax._src.export.shape_poly import SymbolicScope as SymbolicScope
|
||||
from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim
|
||||
from jax._src.export.shape_poly import symbolic_shape as symbolic_shape
|
||||
from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs
|
||||
from jax._src.export.shape_poly import (
|
||||
SymbolicScope,
|
||||
is_symbolic_dim,
|
||||
symbolic_shape,
|
||||
symbolic_args_specs)
|
||||
|
@ -150,7 +150,7 @@ def get_exported(fun: Callable, vjp_order=0,
|
||||
|
||||
|
||||
# Run tests with the maximum supported version by default
|
||||
@jtu.with_config(jax_serialization_version=export.maximum_supported_serialization_version)
|
||||
@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
@classmethod
|
||||
@ -173,7 +173,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertEqual("my_fun", exp.fun_name)
|
||||
expected_lowering_platform = xb.canonicalize_platform(jax.default_backend())
|
||||
self.assertEqual((expected_lowering_platform,),
|
||||
exp.lowering_platforms)
|
||||
exp.platforms)
|
||||
self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
|
||||
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
|
||||
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
|
||||
@ -187,7 +187,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b)
|
||||
a_aval = core.ShapedArray(a.shape, a.dtype)
|
||||
b_aval = core.ShapedArray(b.shape, b.dtype)
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu",))
|
||||
self.assertEqual(exp.platforms, ("cpu",))
|
||||
args = ((a, b),)
|
||||
kwargs = dict(a=a, b=b)
|
||||
self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1])
|
||||
@ -331,12 +331,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"Dtype mismatch for args\[0\]"):
|
||||
exp_f.call(f32_4.astype(np.float16), b=f32_4)
|
||||
|
||||
def test_default_lowering_platform(self):
|
||||
def test_default_export_platform(self):
|
||||
test_platform = jtu.device_under_test()
|
||||
if test_platform == "gpu": test_platform = "cuda"
|
||||
self.assertEqual(export.default_lowering_platform(), test_platform)
|
||||
self.assertEqual(export.default_export_platform(), test_platform)
|
||||
exp = export.export(jnp.sin)(1.)
|
||||
self.assertEqual(exp.lowering_platforms, (export.default_lowering_platform(),))
|
||||
self.assertEqual(exp.platforms, (export.default_export_platform(),))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw: kw["platform"],
|
||||
@ -350,7 +350,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("Uninteresting scenario")
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The exported function .* was lowered for platform"):
|
||||
ValueError, "Function .* was exported for platform"):
|
||||
exp_f.call(a)
|
||||
|
||||
# Now try with the platform check disabled
|
||||
@ -521,7 +521,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
# Peek at the module
|
||||
module_str = exp.mlir_module()
|
||||
self.assertEqual(config.jax_serialization_version.value >= 7,
|
||||
self.assertEqual(config.jax_export_calling_convention_version.value >= 7,
|
||||
"shape_assertion" in module_str)
|
||||
self.assertIn("jax.uses_shape_polymorphism = true", module_str)
|
||||
wrapped_main_expected_re = (
|
||||
@ -596,19 +596,19 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
for v in range(export.minimum_supported_serialization_version - 1,
|
||||
export.maximum_supported_serialization_version + 2)])
|
||||
for v in range(export.minimum_supported_calling_convention_version - 1,
|
||||
export.maximum_supported_calling_convention_version + 2)])
|
||||
def test_poly_basic_versions(self, v: int):
|
||||
with config.jax_serialization_version(v):
|
||||
with config.jax_export_calling_convention_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
"Using JAX calling convention version %s",
|
||||
config.jax_export_calling_convention_version.value)
|
||||
with contextlib.ExitStack() as e:
|
||||
if not (export.minimum_supported_serialization_version <= v
|
||||
<= export.maximum_supported_serialization_version):
|
||||
if not (export.minimum_supported_calling_convention_version <= v
|
||||
<= export.maximum_supported_calling_convention_version):
|
||||
e.enter_context(self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"The requested jax_serialization version {v} is outside the range of supported versions"))
|
||||
f"The requested export calling convention version {v} is outside the range of supported versions"))
|
||||
|
||||
exp = get_exported(jnp.sin)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
|
||||
@ -656,7 +656,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
disabled_checks = ()
|
||||
exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32))
|
||||
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
|
||||
self.assertEqual(exp_f.uses_global_constants, poly_spec != "3,4,12")
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
|
||||
|
||||
@ -759,7 +759,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
inner_exp = get_exported(jax.jit(inner))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32))
|
||||
|
||||
self.assertEqual(inner_exp.uses_shape_polymorphism,
|
||||
self.assertEqual(inner_exp.uses_global_constants,
|
||||
(inner_poly_spec != "3,4,12"))
|
||||
def outer(x): # x: outer_poly_spec
|
||||
# Use an addition to test that the shapes are refined properly for the
|
||||
@ -777,7 +777,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
if expect_error_outer_exp is not None:
|
||||
return
|
||||
|
||||
self.assertEqual(outer_exp.uses_shape_polymorphism,
|
||||
self.assertEqual(outer_exp.uses_global_constants,
|
||||
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
@ -966,12 +966,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Test error reporting
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
|
||||
"Function .* was exported for 2 devices and is called in a context with 1 device"):
|
||||
_ = exp.call(a)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
|
||||
"Function .* was exported for 2 devices and is called in a context with 1 device"):
|
||||
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
|
||||
_ = jax.jit(
|
||||
exp.call,
|
||||
@ -1047,8 +1047,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Exported module .* was lowered for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* module contains "
|
||||
"Function .* was exported for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* function contains "
|
||||
"non-replicated sharding annotations"):
|
||||
exp.call(b)
|
||||
|
||||
@ -1093,8 +1093,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Exported module .* was lowered for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* module contains "
|
||||
"Function .* was exported for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* function contains "
|
||||
"non-replicated sharding annotations"):
|
||||
exp.call(b)
|
||||
|
||||
@ -1137,7 +1137,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f_r = exp.call
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"Exported module .* was lowered for 2 devices and is "
|
||||
"Function .* was exported for 2 devices and is "
|
||||
"called in a context with 1 devices"):
|
||||
_ = f_r(a) # A is all on the default device
|
||||
|
||||
@ -1311,7 +1311,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(8, dtype=np.float32)
|
||||
exp = get_exported(jax.jit(_testing_multi_platform_func),
|
||||
lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
|
||||
self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm"))
|
||||
module_str = str(exp.mlir_module())
|
||||
expected_main_re = (
|
||||
r"@main\("
|
||||
@ -1334,7 +1334,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(5, dtype=np.float32)
|
||||
exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))),
|
||||
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm"))
|
||||
self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda","rocm"))
|
||||
|
||||
# Now serialize the call to the exported using a different sequence of
|
||||
# lowering platforms, but included in the lowering platforms for the
|
||||
@ -1360,7 +1360,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(5, dtype=np.float32)
|
||||
exp = get_exported(jax.jit(_testing_multi_platform_func),
|
||||
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
|
||||
self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm"))
|
||||
|
||||
# Now serialize the call for the current platform.
|
||||
exp2 = get_exported(jax.jit(exp.call))(x)
|
||||
@ -1472,13 +1472,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
for v in range(export.minimum_supported_calling_convention_version,
|
||||
export.maximum_supported_calling_convention_version + 1)])
|
||||
def test_ordered_effects_basic(self, *, v: int):
|
||||
with config.jax_serialization_version(v):
|
||||
with config.jax_export_calling_convention_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
config.jax_export_calling_convention_version.value)
|
||||
x = np.arange(3, dtype=np.float32)
|
||||
def f_jax(x): # x: f32[3]
|
||||
# Test also the calling convention for inner functions
|
||||
@ -1550,13 +1550,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
for v in range(export.minimum_supported_calling_convention_version,
|
||||
export.maximum_supported_calling_convention_version + 1)])
|
||||
def test_ordered_effects_poly(self, *, v: int):
|
||||
with config.jax_serialization_version(v):
|
||||
with config.jax_export_calling_convention_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
config.jax_export_calling_convention_version.value)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
|
||||
@ -1587,13 +1587,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
for v in range(export.minimum_supported_calling_convention_version,
|
||||
export.maximum_supported_calling_convention_version + 1)])
|
||||
def test_ordered_effects_multi_platform_and_poly(self, *, v: int):
|
||||
with config.jax_serialization_version(v):
|
||||
with config.jax_export_calling_convention_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
config.jax_export_calling_convention_version.value)
|
||||
if jtu.device_under_test() == "gpu":
|
||||
# The export is not applicable to GPU
|
||||
raise unittest.SkipTest("Not intended for running on GPU")
|
||||
@ -1632,13 +1632,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
for v in range(export.minimum_supported_calling_convention_version,
|
||||
export.maximum_supported_calling_convention_version + 1)])
|
||||
def test_ordered_effects_with_donation(self, *, v: int):
|
||||
with config.jax_serialization_version(v):
|
||||
with config.jax_export_calling_convention_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
config.jax_export_calling_convention_version.value)
|
||||
|
||||
x = np.arange(3, dtype=np.float32)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user