mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add deprecation warning for sharded_jit.
PiperOrigin-RevId: 439926957
This commit is contained in:
parent
3bfa6af2c8
commit
4ed06602d3
@ -14,7 +14,7 @@
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, Optional, Tuple, Union
|
||||
|
||||
from warnings import warn
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
@ -403,6 +403,9 @@ def sharded_jit(
|
||||
Returns:
|
||||
A version of ``fun`` that will be distributed across multiple devices.
|
||||
"""
|
||||
warn("`sharded_jit` is deprecated. Please use `pjit` instead. "
|
||||
"See https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html for more information.",
|
||||
DeprecationWarning)
|
||||
if num_partitions is not None:
|
||||
nparts = num_partitions
|
||||
else:
|
||||
|
@ -21,5 +21,6 @@ filterwarnings =
|
||||
# (seen on scipy 1.2.3).
|
||||
ignore:`np.*` is a deprecated alias for.*:DeprecationWarning
|
||||
ignore:The module numpy.dual is deprecated.*:DeprecationWarning
|
||||
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
||||
|
Loading…
x
Reference in New Issue
Block a user