Finalize deprecation of ffi_call with inline arguments.

PiperOrigin-RevId: 745261995
This commit is contained in:
Dan Foreman-Mackey 2025-04-08 13:08:54 -07:00 committed by Charles Hofer
parent 879b72a603
commit 686144b099
3 changed files with 10 additions and 34 deletions

View File

@ -42,6 +42,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
available from `jax.extend.mlir`.
* `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
{mod}`jax.ffi` should be used instead.
* The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no
longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a
callable.
* Several previously-deprecated APIs have been removed, including:
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
and `shape_from_pyval`.

View File

@ -39,11 +39,6 @@ from jax._src.lib.mlir import ir
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
Shape)
# TODO(dfm): Remove after 6 months or less because there aren't any offical
# compatibility guarantees for jax.extend (see JEP 15856)
# Added Oct 13, 2024
deprecations.register("jax-ffi-call-args")
map, unsafe_map = util.safe_map, map
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
@ -325,7 +320,7 @@ def _convert_layouts_for_ffi_call(
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata,
*deprecated_args: ArrayLike,
*,
has_side_effect: bool = ...,
vmap_method: str | None = ...,
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
@ -334,8 +329,7 @@ def ffi_call(
custom_call_api_version: int = ...,
legacy_backend_config: str | None = ...,
vectorized: bool | DeprecatedArg = ...,
**deprecated_kwargs: Any,
) -> Callable[..., Array] | Array:
) -> Callable[..., Array]:
...
@ -343,7 +337,7 @@ def ffi_call(
def ffi_call(
target_name: str,
result_shape_dtypes: Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
*,
has_side_effect: bool = ...,
vmap_method: str | None = ...,
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
@ -352,15 +346,14 @@ def ffi_call(
custom_call_api_version: int = ...,
legacy_backend_config: str | None = ...,
vectorized: bool | DeprecatedArg = ...,
**deprecated_kwargs: Any,
) -> Callable[..., Sequence[Array]] | Sequence[Array]:
) -> Callable[..., Sequence[Array]]:
...
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
*,
has_side_effect: bool = False,
vmap_method: str | None = None,
input_layouts: Sequence[FfiLayoutOptions] | None = None,
@ -369,8 +362,7 @@ def ffi_call(
custom_call_api_version: int = 4,
legacy_backend_config: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**deprecated_kwargs: Any,
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
) -> Callable[..., Array | Sequence[Array]]:
"""Call a foreign function interface (FFI) target.
See the :ref:`ffi-tutorial` tutorial for more information.
@ -537,19 +529,7 @@ def ffi_call(
else:
return results[0]
if deprecated_args or deprecated_kwargs:
deprecations.warn(
"jax-ffi-call-args",
"Calling ffi_call directly with input arguments is deprecated. "
"Instead, ffi_call should be used to construct a callable, which can "
"then be called with the appropriate inputs. For example,\n"
" ffi_call('target_name', output_type, x, argument=5)\n"
"should be replaced with\n"
" ffi_call('target_name', output_type)(x, argument=5)",
stacklevel=2)
return wrapped(*deprecated_args, **deprecated_kwargs)
else:
return wrapped
return wrapped
# ffi_call must support some small non-hashable input arguments, like np.arrays

View File

@ -208,13 +208,6 @@ class FfiTest(jtu.JaxTestCase):
with self.assertWarns(DeprecationWarning):
jax.vmap(ffi_call_geqrf)(x)
def test_backward_compat_syntax(self):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
msg = "Calling ffi_call directly with input arguments is deprecated"
with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg):
jax.jit(fun).lower(jnp.ones(5))
def test_input_output_aliases(self):
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)