Merge pull request #21828 from gnecula:exp_calling_convention

PiperOrigin-RevId: 642977662
This commit is contained in:
jax authors 2024-06-13 07:12:59 -07:00
commit a9edaeb38e
14 changed files with 233 additions and 193 deletions

View File

@ -220,7 +220,7 @@ present on the exporting machine:
```python
>>> from jax import export
>>> export.default_lowering_platform()
>>> export.default_export_platform()
'cpu'
```
@ -242,15 +242,15 @@ on multiple platforms.
>>> from jax import export
>>> from jax import lax
>>> # You can specify the lowering_platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), lowering_platforms=['tpu'])(1.)
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)
>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: The exported function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
@ -449,52 +449,52 @@ on multiple devices, and the compiler will shard the function appropriately:
```
## Module serialization versions
## Calling convention versions
The JAX export support has evolved over time, e.g., to support
effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees))
we maintain a serialization version for each `Exported`.
As of June 2024, all modules are serialized with version 9
(the latest, see [all serialization versions](#serialization-version-numbers)):
we maintain a calling convention version for each `Exported`.
As of June 2024, all function exported with version 9
(the latest, see [all calling convention versions](#calling-convention-versions)):
```python
>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.mlir_module_serialization_version
>>> exp.calling_convention_version
9
```
At any given time, the export APIs may support a range
of serialization versions. You can control which serialization
version to use using the `--jax-serialization-version` flag
or the `JAX_SERIALIZATION_VERSION` environment variable:
of calling convention versions. You can control which calling convention
version to use using the `--jax-export-calling-convention-version` flag
or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable:
```python
>>> from jax import export
>>> (export.minimum_supported_serialization_version, export.maximum_supported_serialization_version)
>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version)
(9, 9)
>>> from jax._src import config
>>> with config.jax_serialization_version(9):
>>> with config.jax_export_calling_convention_version(9):
... exp = export.export(jnp.cos)(1.)
... exp.mlir_module_serialization_version
... exp.calling_convention_version
9
```
We reserve the right to remove support for
generating or consuming serialization versions older than 6 months.
generating or consuming calling convention versions older than 6 months.
### Module calling convention
The `Exported.mlir_module` has a `main` function that takes an optional first
platform index argument if the module supports multiple platforms
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
(`len(platforms) > 1`), followed by the token arguments corresponding
to the ordered effects, followed by the kept array
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
The platform index is a i32 or i64 scalar encoding the index of the current
compilation platform into the `lowering_platforms` sequence.
compilation platform into the `platforms` sequence.
Inner functions use a different calling convention: an optional
platform index argument, optional dimension variable arguments
@ -542,15 +542,15 @@ The signature of the `_wrapped_jax_export_main` is:
arg: f32[?, ?]) -> (stablehlo.token, ...)
```
Prior to serialization version 9 the calling convention for effects was
Prior to calling convention version 9 the calling convention for effects was
different: the `main` function does not take or return a token. Instead
the function creates dummy tokens of type `i1[0]` and passes them to the
`_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
takes dummy tokens of type `i1[0]` and will create internally real
tokens to pass to the inner functions. The inner functions use real
tokens (both before and after serialization version 9)
tokens (both before and after calling convention version 9)
Also starting with serialization version 9, function arguments that contain
Also starting with calling convention version 9, function arguments that contain
the platform index or the dimension variable values have a
`jax.global_constant` string attribute whose value is the name of the
global constant, either `_platform_index` or a dimension variable name.
@ -584,11 +584,11 @@ scalar operands corresponding to the format specifiers.
error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")
```
(export-serialization-version)=
(export-calling-convention-version)=
### Serialization version numbers
### Calling convention versions
We list here a history of the serialization version numbers:
We list here a history of the calling convention version numbers:
* Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
* Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
@ -622,7 +622,7 @@ We list here a history of the serialization version numbers:
and the default since October 21st, 2023 (JAX 0.4.20).
* Version 9 adds support for effects.
See the docstring for `export.Exported` for the precise calling convention.
In this serialization version we also tag the platform index and the
In this calling convention version we also tag the platform index and the
dimension variables arguments with `jax.global_constant` attributes.
Supported by XlaCallModule since October 27th, 2023,
available in JAX since October 20th, 2023 (JAX 0.4.20),

View File

@ -20,9 +20,9 @@ Functions
export
deserialize
minimum_supported_serialization_version
maximum_supported_serialization_version
default_lowering_platform
minimum_supported_calling_convention_version
maximum_supported_calling_convention_version
default_export_platform
Functions related to shape polymorphism
---------------------------------------
@ -40,8 +40,8 @@ Constants
.. data:: jax.export.minimum_supported_serialization_version
The minimum supported serialization version; see :ref:`export-serialization-version`.
The minimum supported serialization version; see :ref:`export-calling-convention-version`.
.. data:: jax.export.maximum_supported_serialization_version
The maximum supported serialization version; see :ref:`export-serialization-version`.
The maximum supported serialization version; see :ref:`export-calling-convention-version`.

View File

@ -912,13 +912,21 @@ jax2tf_default_native_serialization = define_bool_state(
jax_serialization_version = define_int_state(
name='jax_serialization_version',
# Note: bump the default serialization version at least one month after
default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default.
help=(
'DEPRECATED: use jax_export_calling_convention_version.'
)
)
jax_export_calling_convention_version = define_int_state(
name='jax_export_calling_convention_version',
# Note: bump the default calling convention version at least one month after
# we update XlaCallModule to support the new version, so that serialized
# modules are forward compatible with deployed versions of XlaCallModule.
# Version 9 of XlaCallModule is supported since October 27th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 9),
default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9),
help=(
'The version number to use for native serialization. This must be '
'The calling convention version number to use for exporting. This must be '
'within the range of versions supported by the tf.XlaCallModule '
'used in your deployment environment. '
'See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.'

View File

@ -63,17 +63,10 @@ Shape = jax._src.core.Shape
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
HloSharding = xla_client.HloSharding
"""The minimum supported serialization version.
See https://jax.readthedocs.io/en/latest/export/export.html#serialization-version-numbers
"""
minimum_supported_serialization_version = 9
"""The maximum supported serialization version.
See https://jax.readthedocs.io/en/latest/export/export.html#serialization-version-numbers
"""
maximum_supported_serialization_version = 9
# The minimum and maximum supported calling convention version.
# See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#calling-conventions-versions
minimum_supported_calling_convention_version = 9
maximum_supported_calling_convention_version = 9
class DisabledSafetyCheck:
@ -168,23 +161,25 @@ class Exported:
the mesh. See `out_shardings_jax` for a way to turn these
into sharding specification that can be used with JAX APIs.
nr_devices: the number of devices that the module has been lowered for.
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
platforms: a tuple containing at least one of 'tpu', 'cpu',
'cuda', 'rocm'. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
for the calling convention for when
there are multiple lowering platforms.
there are multiple export platforms.
ordered_effects: the ordered effects present in the serialized module.
This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
for the calling convention in presence of ordered effects.
unordered_effects: the unordered effects present in the serialized module.
This is present from serialization version 9.
mlir_module_serialized: the serialized lowered VHLO module.
mlir_module_serialization_version: a version number for the serialized module.
See more versioning details at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#module-serialization-versions.
calling_convention_version: a version number for the calling
convention of the exported module.
See more versioning details at https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used.
uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
polymorphism. This may be because `in_avals` contains dimension
uses_global_constants: whether the `mlir_module_serialized` uses shape
polymorphism or multi-platform export.
This may be because `in_avals` contains dimension
variables, or due to inner calls of Exported modules that have
dimension variables or platform index arguments. Such modules need
shape refinement before XLA compilation.
@ -197,7 +192,7 @@ class Exported:
for each primal output. It returns a tuple with the cotangents
corresponding to the flattened primal inputs.
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module_calling_convention).
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention).
"""
fun_name: str
in_tree: tree_util.PyTreeDef
@ -208,15 +203,15 @@ class Exported:
in_shardings_hlo: tuple[HloSharding | None, ...]
out_shardings_hlo: tuple[HloSharding | None, ...]
nr_devices: int
lowering_platforms: tuple[str, ...]
platforms: tuple[str, ...]
ordered_effects: tuple[effects.Effect, ...]
unordered_effects: tuple[effects.Effect, ...]
disabled_safety_checks: Sequence[DisabledSafetyCheck]
mlir_module_serialized: bytes
mlir_module_serialization_version: int
calling_convention_version: int
module_kept_var_idx: tuple[int, ...]
uses_shape_polymorphism: bool
uses_global_constants: bool
_get_vjp: Callable[[Exported], Exported] | None
@ -286,6 +281,33 @@ class Exported:
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
for s in self.out_shardings_hlo)
# For backwards compatibility
# TODO(necula): remove after September 2024.
@property
def lowering_platforms(self):
"""DEPRECATED."""
warnings.warn("lowering_platform is deprecated. Use .platforms instead.",
DeprecationWarning, stacklevel=2)
return self.platforms
# For backwards compatibility
# TODO(necula): remove after September 2024.
@property
def mlir_module_serialization_version(self):
"""DEPRECATED."""
warnings.warn("mlir_module_serialization_version is deprecated. Use .calling_convention_version instead.",
DeprecationWarning, stacklevel=2)
return self.calling_convention_version
# For backwards compatibility
# TODO(necula): remove after September 2024.
@property
def uses_shape_polymorphism(self):
"""DEPRECATED."""
warnings.warn("uses_shape_polymorphism is deprecated. Use .uses_global_constants instead.",
DeprecationWarning, stacklevel=2)
return self.uses_global_constants
def has_vjp(self) -> bool:
"""Returns if this Exported supports VJP."""
return self._get_vjp is not None
@ -331,14 +353,16 @@ def deserialize(blob: bytearray) -> Exported:
return deserialize(blob)
def default_lowering_platform() -> str:
"""Retrieves the default lowering platform.
def default_export_platform() -> str:
"""Retrieves the default export platform.
One of: `tpu`, `cpu`, `cuda`, `rocm`.
"""
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
return xb.canonicalize_platform(jax.default_backend())
default_lowering_platform = default_export_platform
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
"""Returns the shape and dtype of a jax.Array or a j"""
if isinstance(a, jax.ShapeDtypeStruct):
@ -415,7 +439,7 @@ def export_back_compat(
if lowering_platforms is not None:
actual_lowering_platforms = tuple(lowering_platforms)
else:
actual_lowering_platforms = (default_lowering_platform(),)
actual_lowering_platforms = (default_export_platform(),)
# TODO: move to `lower`
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore]
@ -459,6 +483,7 @@ def export_back_compat(
def export(
fun_jit: stages.Wrapped,
*,
platforms: Sequence[str] | None = None,
lowering_platforms: Sequence[str] | None = None,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Callable[..., Exported]:
@ -466,13 +491,14 @@ def export(
Args:
fun_jit: the function to export. Should be the result of `jax.jit`.
lowering_platforms:
platforms:
Optional sequence containing a subset of 'tpu', 'cpu',
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
the exported code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained at
https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
lowering_platforms: DEPRECATED, use `platforms`.
disabled_checks: the safety checks to disable. See documentation for
of `jax.export.DisabledSafetyCheck`.
@ -501,10 +527,14 @@ def export(
if not isinstance(fun_jit, stages.Wrapped):
raise ValueError(
f"Function to be exported must be the result of `jit` but is: {fun_jit}")
if lowering_platforms is not None:
actual_lowering_platforms = tuple(lowering_platforms)
if platforms is not None and lowering_platforms is not None:
raise ValueError("Cannot use both `platforms` and `lowering_platforms`")
if platforms is None and lowering_platforms is not None:
platforms = lowering_platforms
if platforms is not None:
actual_lowering_platforms = tuple(platforms)
else:
actual_lowering_platforms = (default_lowering_platform(),)
actual_lowering_platforms = (default_export_platform(),)
def do_export(*args_specs, **kwargs_specs) -> Exported:
# TODO: move to `lower`
@ -542,13 +572,13 @@ def _export_lowered(
disabled_checks: Sequence[DisabledSafetyCheck] = (),
_device_assignment_for_internal_jax2tf_use_only = None,
) -> Exported:
version = config.jax_serialization_version.value
if (version < minimum_supported_serialization_version or
version > maximum_supported_serialization_version):
version = config.jax_export_calling_convention_version.value
if (version < minimum_supported_calling_convention_version or
version > maximum_supported_calling_convention_version):
raise ValueError(
f"The requested jax_serialization version {version} is outside the "
f"range of supported versions [{minimum_supported_serialization_version}"
f"..{maximum_supported_serialization_version}]")
f"The requested export calling convention version {version} is outside the "
f"range of supported versions [{minimum_supported_calling_convention_version}"
f"..{maximum_supported_calling_convention_version}]")
lowering = lowered._lowering
_check_lowering(lowering)
@ -638,7 +668,7 @@ def _export_lowered(
apply_jit=True,
flat_primal_fun=True)
return export(fun_vjp_jax, # type: ignore[arg-type]
lowering_platforms=exp_primal.lowering_platforms,
platforms=exp_primal.platforms,
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
return Exported(
@ -650,14 +680,14 @@ def _export_lowered(
in_shardings_hlo=in_shardings,
out_shardings_hlo=out_shardings,
nr_devices=nr_devices,
lowering_platforms=lowering._platforms, # type: ignore
platforms=lowering._platforms, # type: ignore
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
disabled_safety_checks=tuple(disabled_checks),
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
mlir_module_serialization_version=version,
uses_global_constants=shape_poly_state.uses_dim_vars,
calling_convention_version=version,
_get_vjp=_get_exported_vjp)
def _module_to_bytecode(module: ir.Module) -> bytes:
@ -1237,7 +1267,7 @@ call_exported_p.def_impl(_call_exported_impl)
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
exported: Exported):
if exported.uses_shape_polymorphism:
if exported.uses_global_constants:
ctx.module_context.shape_poly_state.uses_dim_vars = True
submodule = ir.Module.parse(exported.mlir_module())
@ -1255,14 +1285,14 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
# than the function was exported for.
err_msg = ""
if exported.nr_devices != 1:
err_msg = "the module was lowered for more than 1 device."
err_msg = "the function was exported for more than 1 device."
elif (_check_module(submodule, disabled_checks=()) or
any(s is not None and not s.is_replicated()
for s in exported.in_shardings_hlo + exported.out_shardings_hlo)):
err_msg = "the module contains non-replicated sharding annotations."
err_msg = "the function contains non-replicated sharding annotations."
if err_msg:
raise ValueError(
f"Exported module {exported.fun_name} was lowered for "
f"Function {exported.fun_name} was exported for "
f"{exported.nr_devices} devices and is called in a context with "
f"{num_devices} devices. This is disallowed because: {err_msg}"
)
@ -1296,18 +1326,18 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
callee_lowering_platform_index: list[int] = []
for platform in lowering_platforms:
if platform in exported.lowering_platforms:
if platform in exported.platforms:
callee_lowering_platform_index.append(
exported.lowering_platforms.index(platform))
exported.platforms.index(platform))
elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
callee_lowering_platform_index.append(0)
else:
raise ValueError(
f"The exported function '{exported.fun_name}' was lowered for "
f"platforms '{exported.lowering_platforms}' but it is used "
f"Function '{exported.fun_name}' was exported for "
f"platforms '{exported.platforms}' but it is used "
f"on '{lowering_platforms}'.")
if len(exported.lowering_platforms) > 1:
if len(exported.platforms) > 1:
# The exported module takes a platform index argument
if len(lowering_platforms) > 1:
current_platform_idx = ctx.dim_var_values[0]

View File

@ -119,16 +119,16 @@ table Exported {
in_shardings: [Sharding];
out_shardings: [Sharding];
lowering_platforms: [string];
platforms: [string];
ordered_effects: [Effect];
unordered_effects: [Effect];
disabled_checks: [DisabledSafetyCheck];
mlir_module_serialized: [byte];
mlir_module_serialization_version: uint16;
calling_convention_version: uint16;
module_kept_var_idx: [uint16];
uses_shape_polymorphism: bool;
uses_global_constants: bool;
vjp: Exported;
}

View File

@ -93,8 +93,8 @@ def _serialize_exported(
disabled_safety_checks = _serialize_array(
builder, _serialize_disabled_safety_check, exp.disabled_safety_checks
)
lowering_platforms = _serialize_array(
builder, lambda b, p: b.CreateString(p), exp.lowering_platforms
platforms = _serialize_array(
builder, lambda b, p: b.CreateString(p), exp.platforms
)
mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized)
module_kept_var_idx = builder.CreateNumpyVector(
@ -121,17 +121,17 @@ def _serialize_exported(
ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices)
ser_flatbuf.ExportedAddInShardings(builder, in_shardings)
ser_flatbuf.ExportedAddOutShardings(builder, out_shardings)
ser_flatbuf.ExportedAddLoweringPlatforms(builder, lowering_platforms)
ser_flatbuf.ExportedAddPlatforms(builder, platforms)
ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects)
ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects)
ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks)
ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized)
ser_flatbuf.ExportedAddMlirModuleSerializationVersion(
builder, exp.mlir_module_serialization_version
ser_flatbuf.ExportedAddCallingConventionVersion(
builder, exp.calling_convention_version
)
ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx)
ser_flatbuf.ExportedAddUsesShapePolymorphism(
builder, exp.uses_shape_polymorphism
ser_flatbuf.ExportedAddUsesGlobalConstants(
builder, exp.uses_global_constants
)
if vjp is not None:
ser_flatbuf.ExportedAddVjp(builder, vjp)
@ -179,9 +179,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
out_shardings = _deserialize_tuple(
exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding
)
lowering_platforms = _deserialize_tuple(
exp.LoweringPlatformsLength,
exp.LoweringPlatforms,
platforms = _deserialize_tuple(
exp.PlatformsLength,
exp.Platforms,
lambda v: v.decode("utf-8"),
)
ordered_effects = _deserialize_tuple(
@ -197,9 +197,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
)
mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes()
mlir_module_serialization_version = exp.MlirModuleSerializationVersion()
calling_convention_version = exp.CallingConventionVersion()
module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist())
uses_shape_polymorphism = exp.UsesShapePolymorphism()
uses_global_constants = exp.UsesGlobalConstants()
_get_vjp = None
if vjp := exp.Vjp():
@ -214,14 +214,14 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
nr_devices=nr_devices,
in_shardings_hlo=in_shardings,
out_shardings_hlo=out_shardings,
lowering_platforms=lowering_platforms,
platforms=platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
disabled_safety_checks=disabled_safety_checks,
mlir_module_serialized=mlir_module_serialized,
mlir_module_serialization_version=mlir_module_serialization_version,
calling_convention_version=calling_convention_version,
module_kept_var_idx=module_kept_var_idx,
uses_shape_polymorphism=uses_shape_polymorphism,
uses_global_constants=uses_global_constants,
_get_vjp=_get_vjp,
)

View File

@ -547,7 +547,7 @@ class Exported(object):
return o == 0
# Exported
def LoweringPlatforms(self, j):
def Platforms(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
if o != 0:
a = self._tab.Vector(o)
@ -555,14 +555,14 @@ class Exported(object):
return ""
# Exported
def LoweringPlatformsLength(self):
def PlatformsLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
if o != 0:
return self._tab.VectorLen(o)
return 0
# Exported
def LoweringPlatformsIsNone(self):
def PlatformsIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
return o == 0
@ -666,7 +666,7 @@ class Exported(object):
return o == 0
# Exported
def MlirModuleSerializationVersion(self):
def CallingConventionVersion(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(32))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
@ -700,7 +700,7 @@ class Exported(object):
return o == 0
# Exported
def UsesShapePolymorphism(self):
def UsesGlobalConstants(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
@ -758,10 +758,10 @@ def ExportedAddOutShardings(builder, outShardings):
def ExportedStartOutShardingsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
def ExportedAddLoweringPlatforms(builder, loweringPlatforms):
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(loweringPlatforms), 0)
def ExportedAddPlatforms(builder, platforms):
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(platforms), 0)
def ExportedStartLoweringPlatformsVector(builder, numElems):
def ExportedStartPlatformsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
def ExportedAddOrderedEffects(builder, orderedEffects):
@ -788,8 +788,8 @@ def ExportedAddMlirModuleSerialized(builder, mlirModuleSerialized):
def ExportedStartMlirModuleSerializedVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
def ExportedAddMlirModuleSerializationVersion(builder, mlirModuleSerializationVersion):
builder.PrependUint16Slot(14, mlirModuleSerializationVersion, 0)
def ExportedAddCallingConventionVersion(builder, callingConventionVersion):
builder.PrependUint16Slot(14, callingConventionVersion, 0)
def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx):
builder.PrependUOffsetTRelativeSlot(15, flatbuffers.number_types.UOffsetTFlags.py_type(moduleKeptVarIdx), 0)
@ -797,8 +797,8 @@ def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx):
def ExportedStartModuleKeptVarIdxVector(builder, numElems):
return builder.StartVector(2, numElems, 2)
def ExportedAddUsesShapePolymorphism(builder, usesShapePolymorphism):
builder.PrependBoolSlot(16, usesShapePolymorphism, 0)
def ExportedAddUsesGlobalConstants(builder, usesGlobalConstants):
builder.PrependBoolSlot(16, usesGlobalConstants, 0)
def ExportedAddVjp(builder, vjp):
builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0)

View File

@ -1313,7 +1313,7 @@ dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval())
def dim_as_value_impl(dim: DimSize):
raise NotImplementedError(
"Evaluation rule for 'dim_as_value' is not implemented. "
"It seems that you are using shape polymorphism outside jax2tf.")
"It seems that you are using shape polymorphism outside jax.export.")
dim_as_value_p.def_impl(dim_as_value_impl)
def _dim_as_value(dim: DimSize):

View File

@ -303,7 +303,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
module_str = str(exported.mlir_module())
serialized = exported.mlir_module_serialized
module_version = exported.mlir_module_serialization_version
module_version = exported.calling_convention_version
nr_devices = exported.nr_devices
return serialized, module_str, module_version, nr_devices
@ -332,15 +332,15 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
out_avals=tuple(out_avals),
in_shardings_hlo=(None,) * len(in_avals),
out_shardings_hlo=(None,) * len(out_avals),
lowering_platforms=(data.platform,),
platforms=(data.platform,),
ordered_effects=(),
unordered_effects=(),
disabled_safety_checks=(),
mlir_module_serialized=data.mlir_module_serialized,
mlir_module_serialization_version=data.xla_call_module_version,
calling_convention_version=data.xla_call_module_version,
nr_devices=data.nr_devices,
module_kept_var_idx=tuple(range(len(in_avals))),
uses_shape_polymorphism=any(not core.is_constant_shape(a.shape)
uses_global_constants=any(not core.is_constant_shape(a.shape)
for a in in_avals),
_get_vjp=_get_vjp)

View File

@ -14,13 +14,13 @@
# ==============================================================================
from jax._src.export._export import (
minimum_supported_serialization_version,
maximum_supported_serialization_version,
Exported,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform, # TODO: deprecate
minimum_supported_calling_convention_version,
maximum_supported_calling_convention_version,
Exported,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform, # TODO: deprecate
)
from jax._src.export._export import export_back_compat as export

View File

@ -1652,10 +1652,10 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
kwargs=[dict(version=version) for version in [9]]
)
def test_call_tf_graph_ordered(self, *, version: int):
with config.jax_serialization_version(version):
with config.jax_export_calling_convention_version(version):
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
jax.config.jax_export_calling_convention_version)
@tf.function
def tf_print(x):
@ -1725,10 +1725,10 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
for version in [9]]
)
def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int):
with config.jax_serialization_version(version):
with config.jax_export_calling_convention_version(version):
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
jax.config.jax_export_calling_convention_version)
def f_jax(x1, x_dead, x3):
return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
call_tf_graph=True)(x3))
@ -1750,10 +1750,10 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
]
)
def test_call_tf_graph_polymorphic(self, ordered: bool, version: int):
with config.jax_serialization_version(version):
with config.jax_export_calling_convention_version(version):
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
jax.config.jax_export_calling_convention_version)
@tf.function(jit_compile=True, autograph=False)
@partial(jax2tf.convert,

View File

@ -180,18 +180,18 @@ class JaxToTfTestCase(jtu.JaxTestCase):
# We run the tests using the maximum version supported, even though
# the default serialization version may be held back for a while to
# ensure compatibility
version = config.jax_serialization_version.value
version = config.jax_export_calling_convention_version.value
if self.use_max_serialization_version:
# Use the largest supported by both export and tfxla.call_module
version = min(export.maximum_supported_serialization_version,
version = min(export.maximum_supported_calling_convention_version,
tfxla.call_module_maximum_supported_version())
self.assertGreaterEqual(version,
export.minimum_supported_serialization_version)
self.enter_context(config.jax_serialization_version(version))
export.minimum_supported_calling_convention_version)
self.enter_context(config.jax_export_calling_convention_version(version))
logging.info(
"Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)",
version,
export.maximum_supported_serialization_version,
export.maximum_supported_calling_convention_version,
tfxla.call_module_maximum_supported_version())
with contextlib.ExitStack() as stack:

View File

@ -12,23 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
"maximum_supported_serialization_version",
"minimum_supported_serialization_version",
"default_lowering_platform",
"maximum_supported_calling_convention_version",
"minimum_supported_calling_convention_version",
"default_export_platform",
"SymbolicScope", "is_symbolic_dim",
"symbolic_shape", "symbolic_args_specs"]
from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck
from jax._src.export._export import Exported as Exported
from jax._src.export._export import export as export
from jax._src.export._export import deserialize as deserialize
from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version
from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version
from jax._src.export._export import default_lowering_platform as default_lowering_platform
from jax._src.export._export import (
DisabledSafetyCheck,
Exported,
export,
deserialize,
maximum_supported_calling_convention_version,
minimum_supported_calling_convention_version,
default_export_platform)
from jax._src.export import shape_poly_decision # Import only to set the decision procedure
del shape_poly_decision
from jax._src.export.shape_poly import SymbolicScope as SymbolicScope
from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim
from jax._src.export.shape_poly import symbolic_shape as symbolic_shape
from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs
from jax._src.export.shape_poly import (
SymbolicScope,
is_symbolic_dim,
symbolic_shape,
symbolic_args_specs)

View File

@ -150,7 +150,7 @@ def get_exported(fun: Callable, vjp_order=0,
# Run tests with the maximum supported version by default
@jtu.with_config(jax_serialization_version=export.maximum_supported_serialization_version)
@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
class JaxExportTest(jtu.JaxTestCase):
@classmethod
@ -173,7 +173,7 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual("my_fun", exp.fun_name)
expected_lowering_platform = xb.canonicalize_platform(jax.default_backend())
self.assertEqual((expected_lowering_platform,),
exp.lowering_platforms)
exp.platforms)
self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@ -187,7 +187,7 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b)
a_aval = core.ShapedArray(a.shape, a.dtype)
b_aval = core.ShapedArray(b.shape, b.dtype)
self.assertEqual(exp.lowering_platforms, ("cpu",))
self.assertEqual(exp.platforms, ("cpu",))
args = ((a, b),)
kwargs = dict(a=a, b=b)
self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1])
@ -331,12 +331,12 @@ class JaxExportTest(jtu.JaxTestCase):
r"Dtype mismatch for args\[0\]"):
exp_f.call(f32_4.astype(np.float16), b=f32_4)
def test_default_lowering_platform(self):
def test_default_export_platform(self):
test_platform = jtu.device_under_test()
if test_platform == "gpu": test_platform = "cuda"
self.assertEqual(export.default_lowering_platform(), test_platform)
self.assertEqual(export.default_export_platform(), test_platform)
exp = export.export(jnp.sin)(1.)
self.assertEqual(exp.lowering_platforms, (export.default_lowering_platform(),))
self.assertEqual(exp.platforms, (export.default_export_platform(),))
@jtu.parameterized_filterable(
testcase_name=lambda kw: kw["platform"],
@ -350,7 +350,7 @@ class JaxExportTest(jtu.JaxTestCase):
raise unittest.SkipTest("Uninteresting scenario")
with self.assertRaisesRegex(
ValueError, "The exported function .* was lowered for platform"):
ValueError, "Function .* was exported for platform"):
exp_f.call(a)
# Now try with the platform check disabled
@ -521,7 +521,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Peek at the module
module_str = exp.mlir_module()
self.assertEqual(config.jax_serialization_version.value >= 7,
self.assertEqual(config.jax_export_calling_convention_version.value >= 7,
"shape_assertion" in module_str)
self.assertIn("jax.uses_shape_polymorphism = true", module_str)
wrapped_main_expected_re = (
@ -596,19 +596,19 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(export.minimum_supported_serialization_version - 1,
export.maximum_supported_serialization_version + 2)])
for v in range(export.minimum_supported_calling_convention_version - 1,
export.maximum_supported_calling_convention_version + 2)])
def test_poly_basic_versions(self, v: int):
with config.jax_serialization_version(v):
with config.jax_export_calling_convention_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
"Using JAX calling convention version %s",
config.jax_export_calling_convention_version.value)
with contextlib.ExitStack() as e:
if not (export.minimum_supported_serialization_version <= v
<= export.maximum_supported_serialization_version):
if not (export.minimum_supported_calling_convention_version <= v
<= export.maximum_supported_calling_convention_version):
e.enter_context(self.assertRaisesRegex(
ValueError,
f"The requested jax_serialization version {v} is outside the range of supported versions"))
f"The requested export calling convention version {v} is outside the range of supported versions"))
exp = get_exported(jnp.sin)(
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
@ -656,7 +656,7 @@ class JaxExportTest(jtu.JaxTestCase):
disabled_checks = ()
exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)(
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32))
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
self.assertEqual(exp_f.uses_global_constants, poly_spec != "3,4,12")
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
@ -759,7 +759,7 @@ class JaxExportTest(jtu.JaxTestCase):
inner_exp = get_exported(jax.jit(inner))(
jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32))
self.assertEqual(inner_exp.uses_shape_polymorphism,
self.assertEqual(inner_exp.uses_global_constants,
(inner_poly_spec != "3,4,12"))
def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the
@ -777,7 +777,7 @@ class JaxExportTest(jtu.JaxTestCase):
if expect_error_outer_exp is not None:
return
self.assertEqual(outer_exp.uses_shape_polymorphism,
self.assertEqual(outer_exp.uses_global_constants,
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
with contextlib.ExitStack() as stack:
@ -966,12 +966,12 @@ class JaxExportTest(jtu.JaxTestCase):
# Test error reporting
with self.assertRaisesRegex(
ValueError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
"Function .* was exported for 2 devices and is called in a context with 1 device"):
_ = exp.call(a)
with self.assertRaisesRegex(
ValueError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
"Function .* was exported for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit(
exp.call,
@ -1047,8 +1047,8 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"Function .* was exported for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* function contains "
"non-replicated sharding annotations"):
exp.call(b)
@ -1093,8 +1093,8 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"Function .* was exported for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* function contains "
"non-replicated sharding annotations"):
exp.call(b)
@ -1137,7 +1137,7 @@ class JaxExportTest(jtu.JaxTestCase):
f_r = exp.call
with self.assertRaisesRegex(
Exception,
"Exported module .* was lowered for 2 devices and is "
"Function .* was exported for 2 devices and is "
"called in a context with 1 devices"):
_ = f_r(a) # A is all on the default device
@ -1311,7 +1311,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(8, dtype=np.float32)
exp = get_exported(jax.jit(_testing_multi_platform_func),
lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x)
self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm"))
module_str = str(exp.mlir_module())
expected_main_re = (
r"@main\("
@ -1334,7 +1334,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(5, dtype=np.float32)
exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))),
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm"))
self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda","rocm"))
# Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the
@ -1360,7 +1360,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(5, dtype=np.float32)
exp = get_exported(jax.jit(_testing_multi_platform_func),
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm"))
# Now serialize the call for the current platform.
exp2 = get_exported(jax.jit(exp.call))(x)
@ -1472,13 +1472,13 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
for v in range(export.minimum_supported_calling_convention_version,
export.maximum_supported_calling_convention_version + 1)])
def test_ordered_effects_basic(self, *, v: int):
with config.jax_serialization_version(v):
with config.jax_export_calling_convention_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
config.jax_export_calling_convention_version.value)
x = np.arange(3, dtype=np.float32)
def f_jax(x): # x: f32[3]
# Test also the calling convention for inner functions
@ -1550,13 +1550,13 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
for v in range(export.minimum_supported_calling_convention_version,
export.maximum_supported_calling_convention_version + 1)])
def test_ordered_effects_poly(self, *, v: int):
with config.jax_serialization_version(v):
with config.jax_export_calling_convention_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
config.jax_export_calling_convention_version.value)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
def f_jax(x): # x: f32[b1, b2]
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
@ -1587,13 +1587,13 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
for v in range(export.minimum_supported_calling_convention_version,
export.maximum_supported_calling_convention_version + 1)])
def test_ordered_effects_multi_platform_and_poly(self, *, v: int):
with config.jax_serialization_version(v):
with config.jax_export_calling_convention_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
config.jax_export_calling_convention_version.value)
if jtu.device_under_test() == "gpu":
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
@ -1632,13 +1632,13 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
for v in range(export.minimum_supported_calling_convention_version,
export.maximum_supported_calling_convention_version + 1)])
def test_ordered_effects_with_donation(self, *, v: int):
with config.jax_serialization_version(v):
with config.jax_export_calling_convention_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
config.jax_export_calling_convention_version.value)
x = np.arange(3, dtype=np.float32)