mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax2tf] Add native_lowering_disabled_checks parameter to jax2tf.convert.
Previously, we had a boolean `native_serialization_strict_checks` parameter that was disabling all safety checks. This mechanism had several disadvantages: * the mechanism did not differentiate between different safety checks. E.g., in order to disable checking of the custom call targets, one had to disable checking for all custom call targets, and also the checking that the serialization and execution platforms are the same. * the mechanism operated only at serialization time. Now, the XlaCallModule supports a `disabled_checks` attribute to control which safety checks should be disabled. Here we replace the `native_serialization_strict_checks` with `native_serialization_disabled_checks`, whose values are sequences of disabled check descriptors.
This commit is contained in:
parent
e1705f239c
commit
0961fb9eba
@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
|
||||
is named `cudnn89` instead of `cudnn88`.
|
||||
|
||||
* Deprecations
|
||||
* The `native_serialization_strict_checks` parameter to
|
||||
{func}`jax.experimental.jax2tf.convert` is deprecated in favor of the
|
||||
new `native_serializaation_disabled_checks` ({jax-issue}`#16347`).
|
||||
|
||||
## jaxlib 0.4.13
|
||||
|
||||
## jax 0.4.12 (June 8, 2023)
|
||||
|
@ -96,7 +96,7 @@ for `jax2tf.call_tf`.
|
||||
For more involved examples, please see examples involving:
|
||||
|
||||
* SavedModel for archival ([examples below](#usage-saved-model)), including
|
||||
saving [batch-polymorphic functions](#shape-polymorphic-conversion),
|
||||
saving [batch-polymorphic functions](#shape-polymorphic-conversion),
|
||||
* TensorFlow Lite ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)),
|
||||
* TensorFlow.js ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)),
|
||||
* TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)),
|
||||
@ -1563,8 +1563,10 @@ purposes of `call_tf`.)
|
||||
|
||||
Inside Google, you can turn on logging by using the `--vmodule` argument to
|
||||
specify the logging levels for different modules,
|
||||
e.g., `--vmodule=jax_export=3`.
|
||||
following modules are useful for debugging JAX native serialization:
|
||||
e.g., `--vmodule=jax_export=3`. You can set `TF_DUMP_GRAPH_PREFIX` to
|
||||
a directory where modules should be dumped, or to `"-"` to dump the
|
||||
modules to the log.
|
||||
The following modules are useful for debugging JAX native serialization:
|
||||
|
||||
* `jax_export=3` - will log the StableHLO module on serialization.
|
||||
* `jax2tf=3` - will log the parameters to `XlaCallModule` op on serialization.
|
||||
@ -1586,7 +1588,7 @@ TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=xla_call_module_loader=3 python ...
|
||||
```
|
||||
|
||||
In addition, `TF_DUMP_GRAPH_PREFIX` controls where the dump will be stored, `-`
|
||||
for stderr, `${SOME_DIR}` to store the dumps in the specified directory.
|
||||
for stderr, `${SOME_DIR}` to store the dumps in the specified directory.
|
||||
|
||||
## TensorFlow versions supported
|
||||
|
||||
|
@ -17,6 +17,7 @@ from jax.experimental.jax2tf.jax2tf import (
|
||||
eval_polymorphic_shape as eval_polymorphic_shape,
|
||||
dtype_of_val as dtype_of_val,
|
||||
split_to_logical_devices as split_to_logical_devices,
|
||||
DisabledSafetyCheck as DisabledSafetyCheck,
|
||||
PolyShape as PolyShape
|
||||
)
|
||||
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
|
||||
|
@ -85,6 +85,8 @@ NameStack = source_info_util.NameStack
|
||||
PolyShape = shape_poly.PolyShape
|
||||
DType = Any
|
||||
|
||||
DisabledSafetyCheck = jax_export.DisabledSafetyCheck
|
||||
|
||||
# A temporary internal flag, to enable the wrapping of jax.jit functions
|
||||
# with tf.function(jit_compile=True). See #7389. This change has triggered a
|
||||
# number of failures in TF. We keep this until we are confident that it does
|
||||
@ -232,7 +234,10 @@ def convert(fun_jax: Callable,
|
||||
enable_xla: bool = True,
|
||||
native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION,
|
||||
native_serialization_platforms: Sequence[str] = (),
|
||||
native_serialization_strict_checks: bool = True) -> Callable:
|
||||
# TODO(necula): remove native_serialization_strict_checks
|
||||
native_serialization_strict_checks: bool = True,
|
||||
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
||||
) -> Callable:
|
||||
"""Allows calling a JAX function from a TensorFlow program.
|
||||
|
||||
See
|
||||
@ -310,6 +315,9 @@ def convert(fun_jax: Callable,
|
||||
checks: (A) the lowered computation is executed on a platform for which it
|
||||
was lowered; (B) the serialized computation contains only custom calls
|
||||
with targets that are guaranteed to be stable, (more to come).
|
||||
DEPRECATED in favor of `native_serialization_disabled_checks`.
|
||||
native_serialization_disabled_checks: In conjunction with
|
||||
`native_serialization`, disable the specified safety checks.
|
||||
|
||||
Returns:
|
||||
A version of `fun_jax` that expects TfVals as arguments (or
|
||||
@ -389,7 +397,7 @@ def convert(fun_jax: Callable,
|
||||
fun_jax,
|
||||
args_specs=args_specs, kwargs_specs=kwargs_specs,
|
||||
native_serialization_platforms=native_serialization_platforms,
|
||||
native_serialization_strict_checks=native_serialization_strict_checks)
|
||||
native_serialization_disabled_checks=native_serialization_disabled_checks)
|
||||
else:
|
||||
impl = GraphSerializationImpl(
|
||||
fun_jax,
|
||||
@ -483,11 +491,11 @@ class NativeSerializationImpl(SerializationImpl):
|
||||
def __init__(self, fun_jax, *,
|
||||
args_specs, kwargs_specs,
|
||||
native_serialization_platforms: Sequence[str],
|
||||
native_serialization_strict_checks: bool):
|
||||
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
|
||||
self.fun_jax = fun_jax
|
||||
self.args_specs = args_specs
|
||||
self.kwargs_specs = kwargs_specs
|
||||
self.native_serialization_strict_checks = native_serialization_strict_checks
|
||||
self.native_serialization_disabled_checks = native_serialization_disabled_checks
|
||||
if native_serialization_platforms:
|
||||
self.lowering_platform: Optional[str] = native_serialization_platforms[0]
|
||||
else:
|
||||
@ -504,7 +512,7 @@ class NativeSerializationImpl(SerializationImpl):
|
||||
self.exported = jax_export.export(
|
||||
self.fun_jax,
|
||||
lowering_platform=self.lowering_platform,
|
||||
strict_checks=self.native_serialization_strict_checks
|
||||
disabled_checks=self.native_serialization_disabled_checks
|
||||
)(*self.args_specs, **self.kwargs_specs)
|
||||
|
||||
def after_conversion(self):
|
||||
@ -832,8 +840,15 @@ def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
|
||||
kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx]
|
||||
kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]
|
||||
|
||||
if hasattr(tfxla, "call_module_maximum_supported_version"):
|
||||
max_version_supported = tfxla.call_module_maximum_supported_version()
|
||||
else:
|
||||
max_version_supported = 5
|
||||
# TODO(necula): cleanup handling of Exported.xla_call_module_version
|
||||
assert exported.xla_call_module_version == 6
|
||||
|
||||
call_module_attrs = dict(
|
||||
version=exported.xla_call_module_version,
|
||||
version=max_version_supported,
|
||||
Tout=out_types,
|
||||
Sout=out_shapes_tf,
|
||||
function_list=[
|
||||
@ -842,11 +857,15 @@ def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
|
||||
] if _thread_local_state.call_tf_concrete_function_list is not None else [],
|
||||
)
|
||||
|
||||
if exported.xla_call_module_version >= 3:
|
||||
if exported.strict_checks:
|
||||
call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
|
||||
else:
|
||||
call_module_attrs["platforms"] = () # No platform checking
|
||||
call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
|
||||
if max_version_supported >= 6:
|
||||
call_module_attrs["disabled_checks"] = tuple(
|
||||
str(dc)
|
||||
for dc in exported.disabled_checks)
|
||||
else:
|
||||
if exported.xla_call_module_version >= 3:
|
||||
if DisabledSafetyCheck.platform() in exported.disabled_checks:
|
||||
call_module_attrs["platforms"] = () # No platform checking
|
||||
|
||||
if logging.vlog_is_on(3):
|
||||
# We already logged the MLIR module when we exported it.
|
||||
|
@ -15,10 +15,12 @@
|
||||
|
||||
This module is used with jax2tf, but has no TensorFlow dependencies.
|
||||
"""
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||
import re
|
||||
from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from absl import logging
|
||||
|
||||
@ -48,6 +50,40 @@ zip = util.safe_zip
|
||||
|
||||
DType = Any
|
||||
|
||||
class DisabledSafetyCheck:
|
||||
# Use a strings representation to aid human readability in serializations.
|
||||
_impl: str
|
||||
|
||||
def __init__(self, _impl:str):
|
||||
# Do not use directly, use builders `platform`, `custom_call`.
|
||||
self._impl = _impl
|
||||
|
||||
def __str__(self):
|
||||
return self._impl
|
||||
__repr__ = __str__
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._impl)
|
||||
|
||||
@classmethod
|
||||
def platform(cls) -> "DisabledSafetyCheck":
|
||||
"""Allows the execution platform to differ from the serialization platform."""
|
||||
return DisabledSafetyCheck("platform")
|
||||
|
||||
@classmethod
|
||||
def custom_call(cls, target_name: str) -> "DisabledSafetyCheck":
|
||||
"""Allows the serialization of a call target not known to be stable."""
|
||||
return DisabledSafetyCheck(f"custom_call:{target_name}")
|
||||
|
||||
def is_custom_call(self) -> Optional[str]:
|
||||
"""Returns the custom call target allowed by this directive."""
|
||||
m = re.match(r'custom_call:(.+)$', self._impl)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Exported:
|
||||
"""A JAX function lowered to StableHLO.
|
||||
@ -70,19 +106,16 @@ class Exported:
|
||||
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
|
||||
|
||||
mlir_module_serialized: the serialized lowered VHLO module.
|
||||
mlir_module_version: a version number for the serialized module.
|
||||
The following version numbers are valid:
|
||||
4 - mlir_module_serialized is a portable artifact.
|
||||
xla_call_module_version: a version number for the serialized module.
|
||||
See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code
|
||||
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
|
||||
must be passed to the module. The other arguments have been dropped
|
||||
because they are not used. Same length as `in_shardings`.
|
||||
module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
|
||||
polymorphic dimension variables. This may be from `in_avals` but also
|
||||
from inner calls of Exported modules.
|
||||
strict_checks: whether the module was serialized with the following safety
|
||||
checking: (A) the lowered computation can only be executed on a platform
|
||||
for which it was lowered; (B) the serialized computation contains only
|
||||
custom calls with targets that are guaranteed to be stable, (more to come).
|
||||
disabled_checks: a list of descriptors of safety checks that have been
|
||||
disabled at export time.
|
||||
_get_vjp: an optional function that takes the current exported function and
|
||||
returns the exported VJP function.
|
||||
The VJP function takes a flat list of arguments,
|
||||
@ -99,7 +132,7 @@ class Exported:
|
||||
in_shardings: Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]
|
||||
out_shardings: Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]
|
||||
lowering_platform: str
|
||||
strict_checks: bool
|
||||
disabled_checks: Sequence[DisabledSafetyCheck]
|
||||
|
||||
mlir_module_serialized: bytes
|
||||
xla_call_module_version: int
|
||||
@ -219,15 +252,15 @@ def poly_specs(
|
||||
def export(fun_jax: Callable,
|
||||
*,
|
||||
lowering_platform: Optional[str] = None,
|
||||
strict_checks: bool = True) -> Callable[..., Exported]:
|
||||
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
||||
) -> Callable[..., Exported]:
|
||||
"""Exports native serialization for a JAX function.
|
||||
|
||||
Args:
|
||||
fun_jax: the function to lower and serialize.
|
||||
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use
|
||||
the default JAX backend.
|
||||
strict_checks: whether to do strict safety checks. See Exported.strict_checks
|
||||
for more details.
|
||||
disabled_checks: the safety checks to disable.
|
||||
|
||||
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
|
||||
or values with `.shape` and `.dtype` attributes, and returns an
|
||||
@ -277,7 +310,7 @@ def export(fun_jax: Callable,
|
||||
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
|
||||
)
|
||||
|
||||
xla_call_module_version = 5
|
||||
xla_call_module_version = 6
|
||||
mlir_str = mlir.module_to_bytecode(mlir_module)
|
||||
if stablehlo.get_api_version() < 4:
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
@ -316,14 +349,14 @@ def export(fun_jax: Callable,
|
||||
mlir_module_text = mlir.module_to_string(mlir_module)
|
||||
logmsg = (f"version={xla_call_module_version} "
|
||||
f"lowering_platform={lowering_platform_str} "
|
||||
f"strict_checks={strict_checks}")
|
||||
f"disabled_checks={disabled_checks}")
|
||||
logging.info("Lowered JAX module: %s\n", logmsg)
|
||||
for l in mlir_module_text.splitlines():
|
||||
logging.info(l)
|
||||
|
||||
_check_module(mlir_module,
|
||||
allow_non_replicated_sharding=allow_non_replicated_sharding,
|
||||
allow_all_custom_calls=not strict_checks)
|
||||
disabled_checks=disabled_checks)
|
||||
|
||||
return Exported(
|
||||
fun_name=fun_name,
|
||||
@ -334,7 +367,7 @@ def export(fun_jax: Callable,
|
||||
in_shardings=lowering.compile_args["in_shardings"],
|
||||
out_shardings=lowering.compile_args["out_shardings"],
|
||||
lowering_platform=lowering_platform_str,
|
||||
strict_checks=strict_checks,
|
||||
disabled_checks=tuple(disabled_checks),
|
||||
mlir_module_serialized=mlir_module_serialized,
|
||||
module_kept_var_idx=module_kept_var_idx,
|
||||
module_uses_dim_vars=shape_poly_state.uses_dim_vars,
|
||||
@ -554,7 +587,7 @@ def _check_lowering(lowering) -> None:
|
||||
|
||||
# These are the JAX custom call target names that are guaranteed to be stable.
|
||||
# Their backwards compatibility is tested by back_compat_test.py.
|
||||
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
|
||||
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
||||
"ducc_fft", "cu_threefry2x32",
|
||||
# eigh on CPU
|
||||
@ -582,23 +615,27 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
|
||||
# ApproxTopK on TPU
|
||||
"ApproxTopK",
|
||||
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
|
||||
]
|
||||
}
|
||||
|
||||
def _check_module(mod: ir.Module, *,
|
||||
allow_non_replicated_sharding: bool,
|
||||
allow_all_custom_calls: bool):
|
||||
disabled_checks: Sequence[DisabledSafetyCheck]) -> None:
|
||||
"""Run a number of checks on the module.
|
||||
|
||||
Args:
|
||||
allow_non_replicated_sharding: whether the module is allowed to contain
|
||||
non_replicated sharding annotations.
|
||||
allow_all_custom_calls: whether we should allow all custom calls, or
|
||||
only those who we have explicitly marked as stable.
|
||||
disabled_checks: the safety checks that are disabled.
|
||||
"""
|
||||
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
|
||||
allowed_custom_call_targets_attrs = [
|
||||
allowed_custom_call_targets: Set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
||||
for dc in disabled_checks:
|
||||
target = dc.is_custom_call()
|
||||
if target is not None:
|
||||
allowed_custom_call_targets.add(target)
|
||||
allowed_custom_call_targets_attrs = set(
|
||||
ir.StringAttr.get(target, mod.context)
|
||||
for target in _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE]
|
||||
for target in allowed_custom_call_targets)
|
||||
disallowed_custom_call_ops: List[str] = []
|
||||
def check_sharding(op: ir.Operation, loc: ir.Location):
|
||||
if not allow_non_replicated_sharding:
|
||||
@ -622,8 +659,7 @@ def _check_module(mod: ir.Module, *,
|
||||
|
||||
elif op_name == "stablehlo.custom_call":
|
||||
call_target_name_attr = op.operation.attributes["call_target_name"]
|
||||
if (not allow_all_custom_calls and
|
||||
call_target_name_attr not in allowed_custom_call_targets_attrs):
|
||||
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
|
||||
disallowed_custom_call_ops.append(str(op))
|
||||
if call_target_name_attr == sharding_attr:
|
||||
check_sharding(op, op.location)
|
||||
@ -705,7 +741,7 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
|
||||
|
||||
return export(fun_vjp_jax,
|
||||
lowering_platform=primal.lowering_platform,
|
||||
strict_checks=primal.strict_checks)(*vjp_in_avals)
|
||||
disabled_checks=primal.disabled_checks)(*vjp_in_avals)
|
||||
|
||||
### Importing
|
||||
|
||||
@ -810,7 +846,8 @@ call_exported_p.def_impl(_call_exported_impl)
|
||||
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
platform: str,
|
||||
exported: Exported):
|
||||
if platform != exported.lowering_platform:
|
||||
if (platform != exported.lowering_platform and
|
||||
DisabledSafetyCheck.platform() not in exported.disabled_checks):
|
||||
raise ValueError(
|
||||
f"The exported function '{exported.fun_name}' was lowered for "
|
||||
f"platform '{exported.lowering_platform}' but it is used "
|
||||
|
@ -39,7 +39,9 @@ want, then pick some inputs, and then add this to the new test to get started.
|
||||
# inputs in `func`.
|
||||
data = dataclasses.replace(dummy_data, inputs=inputs,
|
||||
platform=default_jax_backend())
|
||||
self.run_one_test(func, data)
|
||||
self.run_one_test(func, data,
|
||||
# Temporarily allow calls to "foo"
|
||||
allow_additional_custom_call_targets=("foo",))
|
||||
|
||||
The test will fail, but will save to a file the test data you will need. The
|
||||
file name will be printed in the logs. Create a new
|
||||
@ -188,10 +190,11 @@ class CompatTest(jtu.JaxTestCase):
|
||||
|
||||
def run_one_test(self, func: Callable[..., jax.Array],
|
||||
data: CompatTestData,
|
||||
rtol = None,
|
||||
atol = None,
|
||||
rtol: Optional[float] = None,
|
||||
atol: Optional[float] = None,
|
||||
allow_additional_custom_call_targets: Sequence[str] = (),
|
||||
check_results: Optional[Callable[..., None]] = None,
|
||||
use_tf_graph = False):
|
||||
use_tf_graph: bool = False):
|
||||
"""Run one compatibility test.
|
||||
|
||||
Args:
|
||||
@ -201,6 +204,7 @@ class CompatTest(jtu.JaxTestCase):
|
||||
atol: absolute tolerance for numerical comparisons
|
||||
check_results: invoked with the results obtained from running the
|
||||
serialized code, and those stored in the test data, and the kwarg rtol.
|
||||
allow_additional_custom_call_targets: additional custom call targets to allow.
|
||||
use_tf_graph: if False (default), uses jax_export to serialize JAX
|
||||
functions and to invoke them. If True then uses tf.Graph to serialize
|
||||
and run the functions; expects that `func` contains a `jax2tf.call_tf`
|
||||
@ -241,8 +245,9 @@ class CompatTest(jtu.JaxTestCase):
|
||||
exported = jax_export.export(
|
||||
jax.jit(jax_func_to_export),
|
||||
lowering_platform=default_jax_backend(),
|
||||
# Must turn off strict checks to allow custom calls.
|
||||
strict_checks=False
|
||||
disabled_checks=tuple(
|
||||
jax_export.DisabledSafetyCheck.custom_call(target)
|
||||
for target in allow_additional_custom_call_targets)
|
||||
)(*(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in data.inputs))
|
||||
|
||||
module_str = str(exported.mlir_module)
|
||||
@ -342,7 +347,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
|
||||
out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
|
||||
lowering_platform=data.platform,
|
||||
strict_checks=True,
|
||||
disabled_checks=(),
|
||||
mlir_module_serialized=data.mlir_module_serialized,
|
||||
xla_call_module_version=data.xla_call_module_version,
|
||||
module_kept_var_idx=tuple(range(len(in_avals))),
|
||||
|
@ -1415,7 +1415,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
# Check the xla_call_module version and function_list attributes.
|
||||
xla_call_module = xla_call_module_list[0]
|
||||
self.assertEqual(xla_call_module.attr["version"].i, 5)
|
||||
self.assertGreaterEqual(xla_call_module.attr["version"].i, 5)
|
||||
self.assertIn("function_list", str(xla_call_module.attr))
|
||||
xla_call_module_list.clear()
|
||||
called_index_list.clear()
|
||||
|
@ -1536,8 +1536,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
_ = func_to_convert(*args)
|
||||
exported = jax_export.export(
|
||||
func_to_convert,
|
||||
lowering_platform='tpu',
|
||||
strict_checks=True
|
||||
lowering_platform='tpu'
|
||||
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))
|
||||
|
||||
if transform1 == "shard_map":
|
||||
@ -1574,7 +1573,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertAllClose(jnp.sin(x), f_tf(x))
|
||||
|
||||
f_tf = jax2tf.convert(jnp.sin,
|
||||
native_serialization_strict_checks=False)
|
||||
native_serialization_disabled_checks=(
|
||||
jax2tf.DisabledSafetyCheck.platform(),))
|
||||
self.assertAllClose(jnp.sin(x), f_tf(x))
|
||||
|
||||
def test_native_serialization_grad(self):
|
||||
|
@ -185,6 +185,32 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
ValueError, "The exported function .* was lowered for platform"):
|
||||
jax_export.call_exported(exp_f)(a)
|
||||
|
||||
# Now try with the platform check disabled
|
||||
exp_f_no_platform_check = jax_export.export(
|
||||
jnp.sin, lowering_platform=platform,
|
||||
disabled_checks=[jax_export.DisabledSafetyCheck.platform()])(a)
|
||||
res = jax_export.call_exported(exp_f_no_platform_check)(a)
|
||||
self.assertAllClose(res, jnp.sin(a))
|
||||
|
||||
def test_error_disallowed_custom_call(self):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
self.skipTest("Test intended for CPU only")
|
||||
# For now triangular_solve on CPU uses the unsupported "blas_strsm" target
|
||||
a = np.arange(16, dtype=np.float32).reshape((4, 4))
|
||||
b = np.arange(4, dtype=np.float32).reshape((4, 1))
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Cannot serialize code with custom calls whose targets .*"):
|
||||
jax_export.export(
|
||||
lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True),
|
||||
)(a, b)
|
||||
|
||||
# Now try again with the safety check disabled
|
||||
exp = jax_export.export(
|
||||
lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True),
|
||||
disabled_checks=(jax_export.DisabledSafetyCheck.custom_call("blas_strsm"),)
|
||||
)(a, b)
|
||||
self.assertIn("blas_strsm", exp.mlir_module)
|
||||
|
||||
def test_grad(self):
|
||||
f = lambda x: jnp.sum(jnp.sin(x))
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
@ -241,28 +267,22 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"Found 0 when solving a + b == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
|
||||
expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12",
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'c' must have integer value >= 1. "
|
||||
r"Found 0 when solving c + 4 == args[0].shape[1]")),
|
||||
# dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"), # TODO: there should be an error
|
||||
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
|
||||
# dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"), # TODO: there should be an error 5*a != c == 12
|
||||
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
|
||||
expect_error=re.escape(
|
||||
r"Found inconsistency 12 != 4 when solving "
|
||||
r"a == args[0].shape[2]")),
|
||||
# dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"), # TODO: there should be an error 12 != 4
|
||||
dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
|
||||
expect_error=r"Rank mismatch for args\[0\]"),
|
||||
dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
|
||||
expect_error=r"Dtype mismatch for args\[0\]"),
|
||||
))
|
||||
def test_poly(self, inner_poly_spec="3,a,a+b", inner_x_shape=(3, 4, 6),
|
||||
def test_poly(self, inner_poly_spec="3,a,a", inner_x_shape=(3, 4, 6),
|
||||
inner_x_dtype=np.float32,
|
||||
outer_poly_spec="3,c+4,12", outer_x_shape=(3, 4, 12),
|
||||
outer_poly_spec="3,a,a", outer_x_shape=(3, 4, 12),
|
||||
expect_error=None):
|
||||
# Polymorphic export called with static or polymorphic shapes
|
||||
def inner(x): # x: inner_poly_spec
|
||||
|
Loading…
x
Reference in New Issue
Block a user