mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Finalize deprecation of ffi_call
with inline arguments.
PiperOrigin-RevId: 745261995
This commit is contained in:
parent
09fed2f643
commit
2d44f985c3
@ -42,6 +42,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
available from `jax.extend.mlir`.
|
||||
* `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
|
||||
{mod}`jax.ffi` should be used instead.
|
||||
* The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no
|
||||
longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a
|
||||
callable.
|
||||
* Several previously-deprecated APIs have been removed, including:
|
||||
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
|
||||
and `shape_from_pyval`.
|
||||
|
@ -39,11 +39,6 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
|
||||
Shape)
|
||||
|
||||
# TODO(dfm): Remove after 6 months or less because there aren't any offical
|
||||
# compatibility guarantees for jax.extend (see JEP 15856)
|
||||
# Added Oct 13, 2024
|
||||
deprecations.register("jax-ffi-call-args")
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
|
||||
|
||||
@ -325,7 +320,7 @@ def _convert_layouts_for_ffi_call(
|
||||
def ffi_call(
|
||||
target_name: str,
|
||||
result_shape_dtypes: ResultMetadata,
|
||||
*deprecated_args: ArrayLike,
|
||||
*,
|
||||
has_side_effect: bool = ...,
|
||||
vmap_method: str | None = ...,
|
||||
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
|
||||
@ -334,8 +329,7 @@ def ffi_call(
|
||||
custom_call_api_version: int = ...,
|
||||
legacy_backend_config: str | None = ...,
|
||||
vectorized: bool | DeprecatedArg = ...,
|
||||
**deprecated_kwargs: Any,
|
||||
) -> Callable[..., Array] | Array:
|
||||
) -> Callable[..., Array]:
|
||||
...
|
||||
|
||||
|
||||
@ -343,7 +337,7 @@ def ffi_call(
|
||||
def ffi_call(
|
||||
target_name: str,
|
||||
result_shape_dtypes: Sequence[ResultMetadata],
|
||||
*deprecated_args: ArrayLike,
|
||||
*,
|
||||
has_side_effect: bool = ...,
|
||||
vmap_method: str | None = ...,
|
||||
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
|
||||
@ -352,15 +346,14 @@ def ffi_call(
|
||||
custom_call_api_version: int = ...,
|
||||
legacy_backend_config: str | None = ...,
|
||||
vectorized: bool | DeprecatedArg = ...,
|
||||
**deprecated_kwargs: Any,
|
||||
) -> Callable[..., Sequence[Array]] | Sequence[Array]:
|
||||
) -> Callable[..., Sequence[Array]]:
|
||||
...
|
||||
|
||||
|
||||
def ffi_call(
|
||||
target_name: str,
|
||||
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
|
||||
*deprecated_args: ArrayLike,
|
||||
*,
|
||||
has_side_effect: bool = False,
|
||||
vmap_method: str | None = None,
|
||||
input_layouts: Sequence[FfiLayoutOptions] | None = None,
|
||||
@ -369,8 +362,7 @@ def ffi_call(
|
||||
custom_call_api_version: int = 4,
|
||||
legacy_backend_config: str | None = None,
|
||||
vectorized: bool | DeprecatedArg = DeprecatedArg(),
|
||||
**deprecated_kwargs: Any,
|
||||
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
|
||||
) -> Callable[..., Array | Sequence[Array]]:
|
||||
"""Call a foreign function interface (FFI) target.
|
||||
|
||||
See the :ref:`ffi-tutorial` tutorial for more information.
|
||||
@ -537,19 +529,7 @@ def ffi_call(
|
||||
else:
|
||||
return results[0]
|
||||
|
||||
if deprecated_args or deprecated_kwargs:
|
||||
deprecations.warn(
|
||||
"jax-ffi-call-args",
|
||||
"Calling ffi_call directly with input arguments is deprecated. "
|
||||
"Instead, ffi_call should be used to construct a callable, which can "
|
||||
"then be called with the appropriate inputs. For example,\n"
|
||||
" ffi_call('target_name', output_type, x, argument=5)\n"
|
||||
"should be replaced with\n"
|
||||
" ffi_call('target_name', output_type)(x, argument=5)",
|
||||
stacklevel=2)
|
||||
return wrapped(*deprecated_args, **deprecated_kwargs)
|
||||
else:
|
||||
return wrapped
|
||||
return wrapped
|
||||
|
||||
|
||||
# ffi_call must support some small non-hashable input arguments, like np.arrays
|
||||
|
@ -208,13 +208,6 @@ class FfiTest(jtu.JaxTestCase):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
jax.vmap(ffi_call_geqrf)(x)
|
||||
|
||||
def test_backward_compat_syntax(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
|
||||
msg = "Calling ffi_call directly with input arguments is deprecated"
|
||||
with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
|
||||
def test_input_output_aliases(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user