Merge pull request #13035 from jakevdp:jnp-put

PiperOrigin-RevId: 485075125
This commit is contained in:
jax authors 2022-10-31 09:45:17 -07:00
commit f3ddd565c3
4 changed files with 42 additions and 6 deletions

View File

@ -302,6 +302,7 @@ namespace; they are listed below.
pad
percentile
piecewise
place
poly
polyadd
polyder
@ -318,6 +319,7 @@ namespace; they are listed below.
product
promote_types
ptp
put
quantile
r_
rad2deg

View File

@ -4885,6 +4885,32 @@ def _not_implemented(fun, module=None):
return wrapped
@_wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
""")
def place(*args, **kwargs):
raise NotImplementedError(
"jax.numpy.place is not implemented because JAX arrays cannot be modified in-place. "
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
@_wraps(np.put, lax_description="""
Numpy function :func:`numpy.put` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.put`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
""")
def put(*args, **kwargs):
raise NotImplementedError(
"jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
### add method and operator overloads to arraylike classes
# We add operator overloads to DeviceArray and ShapedArray. These method and

View File

@ -177,12 +177,18 @@ def _wraps(
parameters = _parse_parameters(parsed.sections['Parameters'])
if extra_params:
parameters.update(_parse_extra_params(extra_params))
parsed.sections['Parameters'] = (
"Parameters\n"
"----------\n" +
"\n".join(_versionadded.split(desc)[0].rstrip() for p, desc in parameters.items()
if (code is None or p in code.co_varnames) and p not in skip_params)
)
parameters = {p: desc for p, desc in parameters.items()
if (code is None or p in code.co_varnames)
and p not in skip_params}
if parameters:
parsed.sections['Parameters'] = (
"Parameters\n"
"----------\n" +
"\n".join(_versionadded.split(desc)[0].rstrip()
for p, desc in parameters.items())
)
else:
del parsed.sections['Parameters']
docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"

View File

@ -191,8 +191,10 @@ from jax._src.numpy.lax_numpy import (
percentile as percentile,
pi as pi,
piecewise as piecewise,
place as place,
printoptions as printoptions,
promote_types as promote_types,
put as put,
quantile as quantile,
ravel as ravel,
ravel_multi_index as ravel_multi_index,