mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13035 from jakevdp:jnp-put
PiperOrigin-RevId: 485075125
This commit is contained in:
commit
f3ddd565c3
@ -302,6 +302,7 @@ namespace; they are listed below.
|
||||
pad
|
||||
percentile
|
||||
piecewise
|
||||
place
|
||||
poly
|
||||
polyadd
|
||||
polyder
|
||||
@ -318,6 +319,7 @@ namespace; they are listed below.
|
||||
product
|
||||
promote_types
|
||||
ptp
|
||||
put
|
||||
quantile
|
||||
r_
|
||||
rad2deg
|
||||
|
@ -4885,6 +4885,32 @@ def _not_implemented(fun, module=None):
|
||||
return wrapped
|
||||
|
||||
|
||||
@_wraps(np.place, lax_description="""
|
||||
Numpy function :func:`numpy.place` is not available in JAX and will raise a
|
||||
:class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place,
|
||||
and in JAX arrays are immutable. A JAX-compatible approach to array updates
|
||||
can be found in :attr:`jax.numpy.ndarray.at`.
|
||||
""")
|
||||
def place(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"jax.numpy.place is not implemented because JAX arrays cannot be modified in-place. "
|
||||
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
|
||||
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
|
||||
|
||||
|
||||
@_wraps(np.put, lax_description="""
|
||||
Numpy function :func:`numpy.put` is not available in JAX and will raise a
|
||||
:class:`NotImplementedError`, because ``np.put`` modifies its arguments in-place,
|
||||
and in JAX arrays are immutable. A JAX-compatible approach to array updates
|
||||
can be found in :attr:`jax.numpy.ndarray.at`.
|
||||
""")
|
||||
def put(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
|
||||
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
|
||||
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
|
||||
|
||||
|
||||
### add method and operator overloads to arraylike classes
|
||||
|
||||
# We add operator overloads to DeviceArray and ShapedArray. These method and
|
||||
|
@ -177,12 +177,18 @@ def _wraps(
|
||||
parameters = _parse_parameters(parsed.sections['Parameters'])
|
||||
if extra_params:
|
||||
parameters.update(_parse_extra_params(extra_params))
|
||||
parsed.sections['Parameters'] = (
|
||||
"Parameters\n"
|
||||
"----------\n" +
|
||||
"\n".join(_versionadded.split(desc)[0].rstrip() for p, desc in parameters.items()
|
||||
if (code is None or p in code.co_varnames) and p not in skip_params)
|
||||
)
|
||||
parameters = {p: desc for p, desc in parameters.items()
|
||||
if (code is None or p in code.co_varnames)
|
||||
and p not in skip_params}
|
||||
if parameters:
|
||||
parsed.sections['Parameters'] = (
|
||||
"Parameters\n"
|
||||
"----------\n" +
|
||||
"\n".join(_versionadded.split(desc)[0].rstrip()
|
||||
for p, desc in parameters.items())
|
||||
)
|
||||
else:
|
||||
del parsed.sections['Parameters']
|
||||
|
||||
docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
|
||||
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
|
||||
|
@ -191,8 +191,10 @@ from jax._src.numpy.lax_numpy import (
|
||||
percentile as percentile,
|
||||
pi as pi,
|
||||
piecewise as piecewise,
|
||||
place as place,
|
||||
printoptions as printoptions,
|
||||
promote_types as promote_types,
|
||||
put as put,
|
||||
quantile as quantile,
|
||||
ravel as ravel,
|
||||
ravel_multi_index as ravel_multi_index,
|
||||
|
Loading…
x
Reference in New Issue
Block a user