Add deprecation warning for sharded_jit.

PiperOrigin-RevId: 439926957
This commit is contained in:
Yash Katariya 2022-04-06 13:53:34 -07:00 committed by jax authors
parent 3bfa6af2c8
commit 4ed06602d3
2 changed files with 5 additions and 1 deletions

View File

@ -14,7 +14,7 @@
from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union
from warnings import warn
from absl import logging
import numpy as np
@ -403,6 +403,9 @@ def sharded_jit(
Returns:
A version of ``fun`` that will be distributed across multiple devices.
"""
warn("`sharded_jit` is deprecated. Please use `pjit` instead. "
"See https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html for more information.",
DeprecationWarning)
if num_partitions is not None:
nparts = num_partitions
else:

View File

@ -21,5 +21,6 @@ filterwarnings =
# (seen on scipy 1.2.3).
ignore:`np.*` is a deprecated alias for.*:DeprecationWarning
ignore:The module numpy.dual is deprecated.*:DeprecationWarning
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"