mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove deprecated array.tile() method
This commit is contained in:
parent
3aa42775e7
commit
8378d08fcd
@ -16,6 +16,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
[JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
|
||||
* Changes
|
||||
* Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
|
||||
* Deprecations:
|
||||
* The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile`
|
||||
({jax-issue}`#11944`).
|
||||
|
||||
## jax 0.3.16
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
|
||||
|
@ -4737,14 +4737,6 @@ _diff_methods = ["choose", "conj", "conjugate", "copy", "cumprod", "cumsum",
|
||||
"ravel", "repeat", "sort", "squeeze", "std", "sum",
|
||||
"swapaxes", "take", "trace", "var"]
|
||||
|
||||
|
||||
def _deprecate_function(fun, msg):
|
||||
@functools_wraps(fun)
|
||||
def wrapped(*args, **kwargs):
|
||||
warnings.warn(msg, FutureWarning)
|
||||
return fun(*args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
# These methods are mentioned explicitly by nondiff_methods, so we create
|
||||
# _not_implemented implementations of them here rather than in __init__.py.
|
||||
# TODO(phawkins): implement these.
|
||||
@ -5080,8 +5072,6 @@ def _set_shaped_array_attributes(shaped_array):
|
||||
# Forward methods and properties using core.{aval_method, aval_property}:
|
||||
for method_name in _nondiff_methods + _diff_methods:
|
||||
setattr(shaped_array, method_name, core.aval_method(globals()[method_name]))
|
||||
# TODO(jakevdp): remove tile method after August 2022
|
||||
setattr(shaped_array, "tile", core.aval_method(_deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead.")))
|
||||
setattr(shaped_array, "reshape", core.aval_method(_reshape))
|
||||
setattr(shaped_array, "transpose", core.aval_method(_transpose))
|
||||
setattr(shaped_array, "flatten", core.aval_method(ravel))
|
||||
@ -5118,8 +5108,6 @@ def _set_device_array_base_attributes(device_array, include=None):
|
||||
maybe_setattr(f"__{operator_name}__", function)
|
||||
for method_name in _nondiff_methods + _diff_methods:
|
||||
maybe_setattr(method_name, globals()[method_name])
|
||||
# TODO(jakevdp): remove tile method after August 2022
|
||||
maybe_setattr("tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
|
||||
maybe_setattr("reshape", _reshape)
|
||||
maybe_setattr("transpose", _transpose)
|
||||
maybe_setattr("flatten", ravel)
|
||||
|
Loading…
x
Reference in New Issue
Block a user