Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Cleanup: convert uses of `import numpy as onp` in library code (#3754)
parent 512ed18d5a
commit a7c2cdea64
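Every hunk below applies the same mechanical pattern, sketched here for orientation (an illustrative composite, not a hunk copied from any single file in this commit):

    -import numpy as onp   # old convention: host-side NumPy aliased 'onp'
    +import numpy as np    # new convention: host-side NumPy gets the usual 'np' alias

    -  >>> import jax.numpy as np    # old doctest convention: jax.numpy shadowed 'np'
    +  >>> import jax.numpy as jnp   # new doctest convention: jax.numpy is 'jnp'

The library-code rename is behavior-preserving, since `onp` and `np` name the same module object; docstrings and doctests additionally move from `np` to `jnp` for jax.numpy so that `np` can unambiguously mean plain NumPy.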
@@ -20,7 +20,7 @@ import time
 from typing import Any, Optional, Union, Callable, List, Dict

 from absl import flags
-import numpy as onp
+import numpy as np
 from tabulate import tabulate

 from jax.util import safe_zip
@@ -59,7 +59,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
   if iters is None:
     warmup = 1
   else:
-    warmup = onp.clip(1, iters // 10, 10)
+    warmup = np.clip(1, iters // 10, 10)
   for _ in range(warmup):
     f()

@@ -73,7 +73,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
     times.append(end - start)
     count += 1

-  times_arr = onp.array(times)
+  times_arr = np.array(times)
   print("---------Benchmark results for %s---------" % (name or f.__name__))
   print("mean=%f std=%f %%std=%f total=%f" %
         (times_arr.mean(), times_arr.std(), _pstd(times_arr), times_arr.sum()))
jax/api.py (136 changed lines)
@@ -32,7 +32,7 @@ import threading
 from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
 from warnings import warn

-import numpy as onp
+import numpy as np
 from contextlib import contextmanager

 from . import core
@@ -204,10 +204,10 @@ def disable_jit():
   debugging, and avoid the tracer too, we can use the :py:func:`disable_jit`
   context manager:

-  >>> import jax.numpy as np
+  >>> import jax
   >>>
   >>> with jax.disable_jit():
-  ...   print(f(np.array([1, 2, 3])))
+  ...   print(f(jax.numpy.array([1, 2, 3])))
   ...
   Value of y is [2 4 6]
   [5 7 9]
@@ -339,7 +339,7 @@ def xla_computation(fun: Callable,
     return xla.AxisEnv(nreps, names, sizes)

   def abstractify(x):
-    return ShapedArray(onp.shape(x), dtypes.result_type(x))
+    return ShapedArray(np.shape(x), dtypes.result_type(x))

   @wraps(fun)
   def computation_maker(*args, **kwargs):
@@ -474,7 +474,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
     _check_scalar(ans)
     dtype = dtypes.result_type(ans)
     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
-    g = vjp_py(onp.ones((), dtype=dtype))
+    g = vjp_py(np.ones((), dtype=dtype))
     g = g[0] if isinstance(argnums, int) else g
     if not has_aux:
       return ans, g
@@ -500,12 +500,12 @@ def _check_input_dtype_revderiv(name, holomorphic, x):
   _check_arg(x)
   aval = core.get_aval(x)
   if holomorphic:
-    if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
+    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
       msg = (f"{name} with holomorphic=True requires inputs with complex dtype, "
             f"but got {aval.dtype.name}.")
       raise TypeError(msg)
-  elif not (dtypes.issubdtype(aval.dtype, onp.floating) or
-            dtypes.issubdtype(aval.dtype, onp.complexfloating)):
+  elif not (dtypes.issubdtype(aval.dtype, np.floating) or
+            dtypes.issubdtype(aval.dtype, np.complexfloating)):
     msg = (f"{name} requires real- or complex-valued inputs (input dtype that "
           "is a sub-dtype of np.floating or np.complexfloating), "
           f"but got {aval.dtype.name}. ")
@@ -515,11 +515,11 @@ _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
 def _check_output_dtype_revderiv(name, holomorphic, x):
   aval = core.get_aval(x)
   if holomorphic:
-    if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
+    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
       msg = (f"{name} with holomorphic=True requires outputs with complex dtype, "
             f"but got {aval.dtype.name}.")
       raise TypeError(msg)
-  elif not dtypes.issubdtype(aval.dtype, onp.floating):
+  elif not dtypes.issubdtype(aval.dtype, np.floating):
     msg = (f"{name} requires real-valued outputs (output dtype that is "
           f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
           "For holomorphic differentiation, pass holomorphic=True. "
@@ -545,13 +545,13 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
   ``fun`` using forward-mode automatic differentiation.

   >>> import jax
-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
   >>> def f(x):
-  ...   return jax.numpy.asarray(
-  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
+  ...   return jnp.asarray(
+  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
   ...
-  >>> print(jax.jacfwd(f)(np.array([1., 2., 3.])))
+  >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
   [[ 1. 0. 0. ]
    [ 0. 0. 5. ]
    [ 0. 16. -2. ]
@@ -575,12 +575,12 @@ def _check_input_dtype_jacfwd(holomorphic, x):
   _check_arg(x)
   aval = core.get_aval(x)
   if holomorphic:
-    if not (dtypes.issubdtype(aval.dtype, onp.complexfloating) and
-            not dtypes.issubdtype(aval.dtype, onp.floating)):
+    if not (dtypes.issubdtype(aval.dtype, np.complexfloating) and
+            not dtypes.issubdtype(aval.dtype, np.floating)):
       msg = ("jacfwd with holomorphic=True requires inputs with complex dtype, "
             f"but got {aval.dtype.name}.")
       raise TypeError(msg)
-  elif not dtypes.issubdtype(aval.dtype, onp.floating):
+  elif not dtypes.issubdtype(aval.dtype, np.floating):
     msg = ("jacfwd requires real-valued inputs (input dtype that is "
           f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
           "For holomorphic differentiation, pass holomorphic=True. "
@@ -591,7 +591,7 @@ def _check_input_dtype_jacfwd(holomorphic, x):
 def _check_output_dtype_jacfwd(holomorphic, x):
   aval = core.get_aval(x)
   if holomorphic:
-    if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
+    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
       msg = ("jacfwd with holomorphic=True requires outputs with complex dtype, "
             f"but got {aval.dtype.name}.")
       raise TypeError(msg)
@@ -613,13 +613,13 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
   ``fun`` using reverse-mode automatic differentiation.

   >>> import jax
-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
   >>> def f(x):
-  ...   return jax.numpy.asarray(
-  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
+  ...   return jnp.asarray(
+  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
   ...
-  >>> print(jax.jacrev(f)(np.array([1., 2., 3.])))
+  >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
   [[ 1. 0. 0. ]
    [ 0. 0. 5. ]
    [ 0. 16. -2. ]
@@ -711,23 +711,23 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,

 def _std_basis(pytree):
   leaves, _ = tree_flatten(pytree)
-  ndim = sum(map(onp.size, leaves))
+  ndim = sum(map(np.size, leaves))
   # TODO(mattjj): use a symbolic identity matrix here
   dtype = dtypes.result_type(*leaves)
-  flat_basis = onp.eye(ndim, dtype=dtype)
+  flat_basis = np.eye(ndim, dtype=dtype)
   return _unravel_array_into_pytree(pytree, 1, flat_basis)

 def _unravel_array_into_pytree(pytree, axis, arr):
   leaves, treedef = tree_flatten(pytree)
   axis = axis % arr.ndim
-  shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
-  parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
-  reshaped_parts = [onp.reshape(x, shape) for x, shape in zip(parts, shapes)]
+  shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
+  parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis)
+  reshaped_parts = [np.reshape(x, shape) for x, shape in zip(parts, shapes)]
   return tree_unflatten(treedef, reshaped_parts)

 def _split(x, indices, axis):
-  if isinstance(x, onp.ndarray):
-    return onp.split(x, indices, axis)
+  if isinstance(x, np.ndarray):
+    return np.split(x, indices, axis)
   else:
     return x.split(indices, axis)

@@ -771,9 +771,9 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
   For example, we can implement a matrix-matrix product using a vector dot
   product:

-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
-  >>> vv = lambda x, y: np.vdot(x, y)  #  ([a], [a]) -> []
+  >>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
   >>> mv = vmap(vv, (0, None), 0)  #  ([b,a], [a]) -> [b]      (b is the mapped axis)
   >>> mm = vmap(mv, (None, 1), 1)  #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

@@ -788,21 +788,21 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
   axes of the container elements to map over:

   >>> A, B, C, D = 2, 3, 4, 5
-  >>> x = np.ones((A, B))
-  >>> y = np.ones((B, C))
-  >>> z = np.ones((C, D))
+  >>> x = jnp.ones((A, B))
+  >>> y = jnp.ones((B, C))
+  >>> z = jnp.ones((C, D))
   >>> def foo(tree_arg):
   ...   x, (y, z) = tree_arg
-  ...   return np.dot(x, np.dot(y, z))
+  ...   return jnp.dot(x, jnp.dot(y, z))
   >>> tree = (x, (y, z))
   >>> print(foo(tree))
   [[12. 12. 12. 12. 12.]
    [12. 12. 12. 12. 12.]]
   >>> from jax import vmap
   >>> K = 6  # batch size
-  >>> x = np.ones((K, A, B))  # batch axis in different locations
-  >>> y = np.ones((B, K, C))
-  >>> z = np.ones((C, D, K))
+  >>> x = jnp.ones((K, A, B))  # batch axis in different locations
+  >>> y = jnp.ones((B, K, C))
+  >>> z = jnp.ones((C, D, K))
   >>> tree = (x, (y, z))
   >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
   >>> print(vfoo(tree).shape)
@@ -811,7 +811,7 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
   Here's another example using container types in ``in_axes``, this time a
   dictionary, to specify the elements of the container to map over:

-  >>> dct = {'a': 0., 'b': np.arange(5.)}
+  >>> dct = {'a': 0., 'b': jnp.arange(5.)}
   >>> x = 1.
   >>> def foo(dct, x):
   ...   return dct['a'] + dct['b'] + x
@@ -824,13 +824,13 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
   element mapped and the second unmapped. Only for unmapped results
   we can specify ``out_axes`` to be ``None`` (to keep it unmapped).

-  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(np.arange(2.), 4.))
+  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
   (DeviceArray([4., 5.], dtype=float32), 8.0)

   If the ``out_axes`` is specified for an unmapped result, the result is broadcast
   across the mapped axis:

-  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(np.arange(2.), 4.))
+  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
   (DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32))

   If the ``out_axes`` is specified for a mapped result, the result is
@@ -884,7 +884,7 @@ def _get_axis_size(name: str, i:int, shape: Tuple[int, ...], axis: int):
                      f"but axis to be mapped {axis}") from e

 def _mapped_axis_size(tree, vals, dims, name):
-  mapped_axis_sizes = {_get_axis_size(name, i, onp.shape(x), d)
+  mapped_axis_sizes = {_get_axis_size(name, i, np.shape(x), d)
                        for i, (x, d) in enumerate(zip(vals, dims))
                        if d is not None}
   try:
@@ -1005,18 +1005,18 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   For example, assuming 8 XLA devices are available, :py:func:`pmap` can be used as a
   map along a leading array axis:

-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
-  >>> out = pmap(lambda x: x ** 2)(np.arange(8))  # doctest: +SKIP
+  >>> out = pmap(lambda x: x ** 2)(jnp.arange(8))  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
   [0, 1, 4, 9, 16, 25, 36, 49]

   When the leading dimension is smaller than the number of available devices JAX
   will simply run on a subset of devices:

-  >>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
-  >>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
-  >>> out = pmap(np.dot)(x, y)  # doctest: +SKIP
+  >>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
+  >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
+  >>> out = pmap(jnp.dot)(x, y)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
   [[[ 4. 9.]
     [ 12. 29.]]
@@ -1028,14 +1028,14 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   If your leading dimension is larger than the number of available devices you
   will get an error:

-  >>> pmap(lambda x: x ** 2)(np.arange(9))  # doctest: +SKIP
+  >>> pmap(lambda x: x ** 2)(jnp.arange(9))  # doctest: +SKIP
   ValueError: ... requires 9 replicas, but only 8 XLA devices are available

   As with :py:func:`vmap`, using ``None`` in ``in_axes`` indicates that an argument
   doesn't have an extra axis and should be broadcasted, rather than mapped,
   across the replicas:

-  >>> x, y = np.arange(2.), 4.
+  >>> x, y = jnp.arange(2.), 4.
   >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
   ([4., 5.], [8., 8.])
@@ -1048,7 +1048,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   collective operations. For example:

   >>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
-  >>> out = pmap(f, axis_name='i')(np.arange(4.))  # doctest: +SKIP
+  >>> out = pmap(f, axis_name='i')(jnp.arange(4.))  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
   [ 0. 0.16666667 0.33333334 0.5 ]
   >>> print(out.sum())  # doctest: +SKIP
@@ -1073,7 +1073,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   ...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
   ...   return row_normed, col_normed, doubly_normed
   >>>
-  >>> x = np.arange(8.).reshape((4, 2))
+  >>> x = jnp.arange(8.).reshape((4, 2))
   >>> row_normed, col_normed, doubly_normed = normalize(x)  # doctest: +SKIP
   >>> print(row_normed.sum(0))  # doctest: +SKIP
   [ 1. 1.]
@@ -1087,7 +1087,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   runs on two hosts with 4 XLA devices each:

   >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
-  >>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4,8)
+  >>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4,8)
   >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
   [28 29 30 31] # on host 0
@@ -1096,7 +1096,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   Each host passes in a different length-4 array, corresponding to its 4 local
   devices, and the psum operates over all 8 values. Conceptually, the two
   length-4 arrays can be thought of as sharded length-8 array (in this example
-  equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis
+  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
   given name 'i'. The pmap call on each host then returns the corresponding
   length-4 output shard.

@@ -1114,9 +1114,9 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
   ... def f2(x):
   ...   return jax.lax.psum(x ** 2, axis_name='i')
   >>>
-  >>> print(f1(np.arange(6.)))  # doctest: +SKIP
+  >>> print(f1(jnp.arange(6.)))  # doctest: +SKIP
   [0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333]
-  >>> print(f2(np.array([2., 3.])))  # doctest: +SKIP
+  >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
   [ 13. 13.]
   """
   # axis_size is an optional integer representing the global axis size.
@@ -1301,7 +1301,7 @@ def mask(fun: Callable, in_shapes, out_shape) -> Callable:
     if in_tree != in_shapes_tree:
       raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.")
     logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
-    in_shapes = map(masking.finalize_spec, in_specs, map(onp.shape, args_flat))
+    in_shapes = map(masking.finalize_spec, in_specs, map(np.shape, args_flat))
     padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat])
     f = lu.wrap_init(fun)
     flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree)
@@ -1314,7 +1314,7 @@ def mask(fun: Callable, in_shapes, out_shape) -> Callable:
       return tuple(dim if dim is masking._monomorphic_dim else
                    masking.eval_poly(dim, padded_env) for dim in shape_spec)
     masking.check_shapes(map(padded_spec, out_specs), out_spec_tree,
-                         map(onp.shape, outs), out_tree, "Padded output")
+                         map(np.shape, outs), out_tree, "Padded output")
     return tree_unflatten(out_tree, outs)
   return wrapped_fun

@@ -1326,7 +1326,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable):
   out_specs, out_spec_tree = tree_flatten(out_shape)
   out_specs = map(masking.parse_spec, out_specs)
   flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
-  avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
+  avals = map(partial(ShapedArray, dtype=np.float32), in_shapes)
   out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
   masking.check_shapes(map(tuple, out_specs), out_spec_tree,
                        map(tuple, out_shapes), out_tree_thunk())
@@ -1439,9 +1439,9 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
   Here's a more complete example of using :py:func:`linearize`:

   >>> import jax
-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
-  >>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
+  >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
   ...
   >>> jax.jvp(f, (2.,), (3.,))
   (DeviceArray(3.26819, dtype=float32), DeviceArray(-5.00753, dtype=float32))
@@ -1483,7 +1483,7 @@ def _lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pvals, *py_args):

 def _check_inexact_input_vjp(x):
   aval = core.get_aval(x)
-  if not dtypes.issubdtype(aval.dtype, onp.inexact):
+  if not dtypes.issubdtype(aval.dtype, np.inexact):
     msg = ("Primal inputs to reverse-mode differentiation must be of float "
            "or complex type, got type {}")
     raise TypeError(msg.format(aval.dtype.name))
@@ -1703,9 +1703,9 @@ class ShapeDtypeStruct(object):
   __slots__ = ["shape", "dtype"]
   def __init__(self, shape, dtype):
     self.shape = shape
-    self.dtype = onp.dtype(dtype)
+    self.dtype = np.dtype(dtype)

-  size = property(lambda self: onp.prod(self.shape))
+  size = property(lambda self: np.prod(self.shape))
   ndim = property(lambda self: len(self.shape))

   def __len__(self):
@@ -1773,16 +1773,16 @@ def eval_shape(fun: Callable, *args, **kwargs):
   For example:

   >>> import jax
-  >>> import jax.numpy as np
+  >>> import jax.numpy as jnp
   >>>
-  >>> f = lambda A, x: np.tanh(np.dot(A, x))
+  >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
   >>> class MyArgArray(object):
   ...   def __init__(self, shape, dtype):
   ...     self.shape = shape
   ...     self.dtype = dtype
   ...
-  >>> A = MyArgArray((2000, 3000), np.float32)
-  >>> x = MyArgArray((3000, 1000), np.float32)
+  >>> A = MyArgArray((2000, 3000), jnp.float32)
+  >>> x = MyArgArray((3000, 1000), jnp.float32)
   >>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
   >>> print(out.shape)
   (2000, 1000)
@@ -1790,7 +1790,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
   float32
   """
   def abstractify(x):
-    return ShapedArray(onp.shape(x), dtypes.result_type(x))
+    return ShapedArray(np.shape(x), dtypes.result_type(x))
   args_flat, in_tree = tree_flatten((args, kwargs))
   wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
jax/core.py (14 changed lines)
@@ -26,7 +26,7 @@ from typing import (Any, Callable, ClassVar, Dict, Generator,
                     Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
                     Type, Union, cast)

-import numpy as onp
+import numpy as np

 from . import dtypes
 from .config import FLAGS
@@ -846,7 +846,7 @@ class UnshapedArray(AbstractValue):
   array_abstraction_level = 2

   def __init__(self, dtype, weak_type=False):
-    self.dtype = onp.dtype(dtypes.canonicalize_dtype(dtype))
+    self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
     self.weak_type = weak_type

   def __eq__(self, other):
@@ -858,7 +858,7 @@ class UnshapedArray(AbstractValue):

   def __hash__(self):
     # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
-    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
+    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
     # the unique character code via hash(self.dtype.char)
     return hash((self.dtype, self.weak_type))

@@ -925,7 +925,7 @@ class ShapedArray(UnshapedArray):

   def __hash__(self):
     # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
-    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
+    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
     # the unique character code via hash(self.dtype.char)
     return hash((self.shape, self.dtype, self.weak_type))

@@ -968,16 +968,16 @@ class ConcreteArray(ShapedArray):
   array_abstraction_level = 0

   def __init__(self, val, weak_type=False):
-    super(ConcreteArray, self).__init__(onp.shape(val), onp.result_type(val),
+    super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
                                         weak_type=weak_type)
     # Note: canonicalized self.dtype doesn't necessarily match self.val
     self.val = val
-    assert self.dtype != onp.dtype('O')
+    assert self.dtype != np.dtype('O')

   def __eq__(self, other):
     return (type(self) is type(other) and self.dtype == other.dtype
             and self.shape == other.shape and self.weak_type == other.weak_type
-            and onp.all(self.val == other.val))
+            and np.all(self.val == other.val))

   def __hash__(self):
     return id(self.val)
@@ -21,7 +21,7 @@ not yet fine-tuned the performance of the resulting XLA compilation!
 By default, loops and control-flow in JAX are executed and inlined during tracing.
 For example, in the following code the `for` loop is unrolled during JAX tracing::

-  arr = onp.zeros(5)
+  arr = np.zeros(5)
   for i in range(arr.shape[0]):
     arr[i] += 2.
     if i % 2 == 0:
@@ -32,7 +32,7 @@ JAX operations, which require you to express the body of the loops and
 conditionals as functions, and the array updates using a functional style that
 returns an updated array, e.g.::

-  arr = onp.zeros(5)
+  arr = np.zeros(5)
   def loop_body(i, acc_arr):
     arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
     return lax.cond(i % 2 == 0,
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as onp
+import numpy as np
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import jax
@@ -103,7 +103,7 @@ class BatchTracer(Tracer):
       return aval
     elif type(aval) is ShapedArray:
       assert 0 <= self.batch_dim < aval.ndim
-      new_shape = tuple(onp.delete(aval.shape, self.batch_dim))
+      new_shape = tuple(np.delete(aval.shape, self.batch_dim))
       return ShapedArray(new_shape, aval.dtype)
     else:
       raise TypeError(aval)
@@ -236,7 +236,7 @@ def broadcast_batcher(prim, args, dims, **params):
       either an int indicating the batch dimension, or else `not_mapped`
       indicating no batching.
   """
-  shapes = {(x.shape, d) for x, d in zip(args, dims) if onp.ndim(x)}
+  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
   if len(shapes) == 1:
     # if there's only agreeing batch dims and scalars, just call the primitive
     d = next(d for d in dims if d is not not_mapped)
@@ -245,16 +245,16 @@ def broadcast_batcher(prim, args, dims, **params):
   else:
     size, = {shape[d] for shape, d in shapes if d is not not_mapped}
     args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
-    ndim = max(onp.ndim(x) for x in args)  # special-case scalar broadcasting
+    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
     out = prim.bind(*args, **params)
     return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)

 def _handle_scalar_broadcasting(nd, x, d):
-  if d is not_mapped or nd == onp.ndim(x):
+  if d is not_mapped or nd == np.ndim(x):
     return x
   else:
-    return x.reshape(x.shape + (1,) * (nd - onp.ndim(x)))
+    return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))

 def defreducer(prim):
   primitive_batchers[prim] = partial(reducer_batcher, prim)
@@ -262,8 +262,8 @@ def defreducer(prim):
 def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
   operand, = batched_args
   bdim, = batch_dims
-  axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1)))
-  bdim_out = int(list(onp.delete(onp.arange(operand.ndim), axes)).index(bdim))
+  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
+  bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
   if 'input_shape' in params:
     params = dict(params, input_shape=operand.shape)
   return prim.bind(operand, axes=axes, **params), bdim_out
@@ -303,10 +303,10 @@ def broadcast(x, sz, axis):
   if core.get_aval(x) is core.abstract_unit:
     return core.unit
   if axis is last:
-    axis = onp.ndim(x)
-  shape = list(onp.shape(x))
+    axis = np.ndim(x)
+  shape = list(np.shape(x))
   shape.insert(axis, sz)
-  broadcast_dims = tuple(onp.delete(onp.arange(len(shape)), axis))
+  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
   return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

 def moveaxis(x, src, dst):
@@ -315,7 +315,7 @@ def moveaxis(x, src, dst):
   if src == dst:
     return x
   src, dst = src % x.ndim, dst % x.ndim
-  perm = [i for i in range(onp.ndim(x)) if i != src]
+  perm = [i for i in range(np.ndim(x)) if i != src]
   perm.insert(dst, src)
   return x.transpose(perm)
@@ -20,7 +20,7 @@ import operator as op
 import string
 from typing import Callable, Dict, Sequence, Union

-import numpy as onp
+import numpy as np

 from .. import abstract_arrays
 from .. import core, dtypes
@@ -317,7 +317,7 @@ def parse_spec(spec=''):

 def _parse_dim(spec):
   if '+' in spec:
-    return onp.sum(map(_parse_dim, spec.split('+')))
+    return np.sum(map(_parse_dim, spec.split('+')))
   elif '*' in spec:
     return prod(map(_parse_dim, spec.split('*')))
   elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
@@ -383,10 +383,10 @@ class MaskTracer(Tracer):

 class MaskTrace(Trace):
   def pure(self, val):
-    return MaskTracer(self, val, onp.shape(val))
+    return MaskTracer(self, val, np.shape(val))

   def lift(self, val):
-    return MaskTracer(self, val, onp.shape(val))
+    return MaskTracer(self, val, np.shape(val))

   def sublift(self, val):
     return MaskTracer(self, val.val, val.polymorphic_shape)
@@ -20,7 +20,7 @@ from typing import (Callable, Dict, NamedTuple, Optional, Sequence,
                     Set, Tuple, Type, Union, cast)
 from weakref import ref

-import numpy as onp
+import numpy as np

 from .. import core
 from .. import linear_util as lu
@@ -128,7 +128,7 @@ class JaxprTrace(Trace):
     if const is None:
       return tracer
     else:
-      if type(const) in core.literalable_types and onp.shape(const) == ():
+      if type(const) in core.literalable_types and np.shape(const) == ():
         return self.new_instantiated_literal(const)
       else:
         return self.new_instantiated_const(const)
@@ -138,7 +138,7 @@ class JaxprTrace(Trace):
     if const is None:
       return tracer
     else:
-      aval = raise_to_shaped(get_aval(const), onp.isscalar(const))
+      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
       return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

   def process_primitive(self, primitive, tracers, params):
@@ -37,7 +37,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
                     Type, Union)

 from absl import logging
-import numpy as onp
+import numpy as np

 from ..config import flags
 from .. import core
@@ -465,7 +465,7 @@ def _axis_index_bind(*, axis_name):
   nreps = dynamic_axis_env.nreps
   trace = frame.pmap_trace

-  out_aval = ShapedArray((), onp.int32)
+  out_aval = ShapedArray((), np.int32)
   out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
   eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                           dict(nreps=nreps, sizes=sizes,
@@ -476,19 +476,19 @@ def _axis_index_bind(*, axis_name):
   if not frame.soft_trace:
     return out_tracer
   else:
-    val_out = out_tracer * frame.soft_size + onp.arange(frame.soft_size)
+    val_out = out_tracer * frame.soft_size + np.arange(frame.soft_size)
     return SplitAxisTracer(frame.soft_trace, axis_name, val_out)

 def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
-  div = xb.constant(c, onp.array(nreps // prod(sizes), dtype=onp.uint32))
-  mod = xb.constant(c, onp.array(sizes[-1], dtype=onp.uint32))
+  div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
+  mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
   unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
-  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(onp.int32))
+  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

 axis_index_p = core.Primitive('axis_index')
 axis_index_p.def_custom_bind(_axis_index_bind)
 axis_index_p.def_abstract_eval(
-    lambda *args, **params: ShapedArray((), onp.int32))
+    lambda *args, **params: ShapedArray((), np.int32))
 xla.translations[axis_index_p] = _axis_index_translation_rule


@@ -587,7 +587,7 @@ class ShardedDeviceArray(xla.DeviceArray):
   def _value(self):
     if self._npy_value is None:
       self.copy_to_host_async()
-      npy_value = onp.empty(self.aval.shape, self.aval.dtype)
+      npy_value = np.empty(self.aval.shape, self.aval.dtype)
       for i in self.one_replica_buffer_indices:
         npy_value[self.indices[i]] = self.device_buffers[i].to_py()
       self._npy_value = npy_value
@@ -633,7 +633,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
 shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path

 def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
-  return xb.constant(c, onp.asarray(val), canonicalize_types=canonicalize_types)
+  return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types)
 xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)

 core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
@@ -838,7 +838,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
     # provided 1D list of devices).
     device_assignment = tree_map(lambda d: d.id, devices)
     # Convert to 2D in case it's 1D and we have > 1 partitions.
-    device_assignment = onp.array(device_assignment).reshape(
+    device_assignment = np.array(device_assignment).reshape(
         (num_global_replicas, num_partitions))
   compile_options = xb.get_compile_options(
       num_replicas=num_global_replicas,
@@ -933,7 +933,7 @@ def get_num_partitions(*partitions):
   if len(partition_specs) == 0:
     # Everything is specified as replicated (all Nones).
     return None
-  num_partitions_set = set(onp.prod(spec) for spec in partition_specs)
+  num_partitions_set = set(np.prod(spec) for spec in partition_specs)
   if len(num_partitions_set) > 1:
     raise ValueError(
         f"All partition specs must use the same number of total partitions, "
@@ -1157,7 +1157,7 @@ def _xla_shard(c, aval, axis_env, x):
     return x
   elif isinstance(aval, ShapedArray):
     dims = list(c.get_shape(x).dimensions())
-    zero = xb.constant(c, onp.zeros((), dtype=onp.uint32))
+    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index(c, axis_env)] + [zero] * (len(dims) - 1)
     return xops.Reshape(xops.DynamicSlice(x, idxs, [1] + dims[1:]), dims[1:])
   else:
@@ -1169,16 +1169,16 @@ def _xla_unshard(c, aval, axis_env, x, backend):
     return x
   elif isinstance(aval, ShapedArray):
     # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
-    convert_bool = (onp.issubdtype(aval.dtype, onp.bool_)
+    convert_bool = (np.issubdtype(aval.dtype, np.bool_)
                     and xb.get_backend(backend).platform in ('cpu', 'gpu'))
     if convert_bool:
-      x = xops.ConvertElementType(x, xb.dtype_to_etype(onp.float32))
+      x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))

     xla_shape = c.get_shape(x)
     dims = list(xla_shape.dimensions())
-    padded = xops.Broadcast(xb.constant(c, onp.array(0, xla_shape.numpy_dtype())),
+    padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
                             [axis_env.sizes[-1]] + dims)
-    zero = xb.constant(c, onp.zeros((), dtype=onp.uint32))
+    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
     padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
     replica_groups_protos = xc.make_replica_groups(
@@ -1187,15 +1187,15 @@ def _xla_unshard(c, aval, axis_env, x, backend):

     # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
     if convert_bool:
-      nonzero = xops.Ne(out, xb.constant(c, onp.array(0, dtype=onp.float32)))
-      out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(onp.bool_))
+      nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
+      out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
     return out
   else:
     raise TypeError((aval, c.get_shape(x)))

 def _unravel_index(c, axis_env):
-  div = xb.constant(c, onp.array(axis_env.nreps // prod(axis_env.sizes), onp.uint32))
-  mod = xb.constant(c, onp.array(axis_env.sizes[-1], onp.uint32))
+  div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
+  mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
   return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)


@@ -1278,7 +1278,7 @@ class SplitAxisTrace(core.Trace):
     if primitive is axis_index_p:
       dummy, = vals_in
       hard_idx = primitive.bind(dummy, **params)
-      val_out = hard_idx * params['soft_size'] + onp.arange(params['soft_size'])
+      val_out = hard_idx * params['soft_size'] + np.arange(params['soft_size'])
       return SplitAxisTracer(self, params['axis_name'], val_out)
     elif all(axis_name is not_mapped for axis_name in names_in):
       return primitive.bind(*vals_in, **params)
@@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Tup
 from warnings import warn

 from absl import logging
-import numpy as onp
+import numpy as np

 from ..config import flags, bool_env
 from .. import core
@@ -74,11 +74,11 @@ def identity(x): return x
 _scalar_types = dtypes.python_scalar_dtypes.keys()

 # unit representation
-def _make_unit(c): return xb.constant(c, onp.zeros((), dtype=onp.dtype('bool')))
-def _make_abstract_unit(_): return xc.Shape.array_shape(onp.dtype('bool'), ())
+def _make_unit(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
+def _make_abstract_unit(_): return xc.Shape.array_shape(np.dtype('bool'), ())
 def _device_put_unit(_, device):
   backend = xb.get_device_backend(device)
-  return backend.buffer_from_pyval(onp.zeros((), dtype=onp.dtype('bool')),
+  return backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
                                    device)
 def _make_array_shape(a):
   return xc.Shape.array_shape(a.dtype, a.shape)
@@ -143,10 +143,10 @@ def canonicalize_dtype(x):
   raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")

 def _canonicalize_ndarray_dtype(x):
-  return onp.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
+  return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))

 def _canonicalize_python_scalar_dtype(typ, x):
-  return onp.asarray(
+  return np.asarray(
       x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))

 canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
@@ -342,8 +342,8 @@ def check_nans(prim, bufs):

 def _check_nans(name, xla_shape, buf):
   assert not xla_shape.is_tuple()
-  if dtypes.issubdtype(xla_shape.element_type(), onp.inexact):
-    if onp.any(onp.isnan(buf.to_py())):
+  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
+    if np.any(np.isnan(buf.to_py())):
       raise FloatingPointError(f"invalid value (nan) encountered in {name}")

 ### compiling jaxprs
@@ -477,10 +477,10 @@ def _axis_groups(nrep, mesh_spec, mesh_axes):
   trailing_size, ragged = divmod(nrep, prod(mesh_spec))
   assert not ragged
   full_spec = list(mesh_spec) + [trailing_size]
-  iota = onp.arange(prod(full_spec)).reshape(full_spec)
-  groups = onp.reshape(
-      onp.moveaxis(iota, mesh_axes, onp.arange(len(mesh_axes))),
-      (prod(onp.take(full_spec, mesh_axes)), -1))
+  iota = np.arange(prod(full_spec)).reshape(full_spec)
+  groups = np.reshape(
+      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
+      (prod(np.take(full_spec, mesh_axes)), -1))
   return tuple(unsafe_map(tuple, groups.T))

 def jaxpr_replicas(jaxpr):
@@ -862,7 +862,7 @@ call_translations[xla_call_p] = _xla_call_translation_rule
 def zeros_like_translation_rule(c, x):
   shape = c.get_shape(x)
   assert not shape.is_tuple()
-  zero = xb.constant(c, onp.array(0, shape.element_type()))
+  zero = xb.constant(c, np.array(0, shape.element_type()))
   return xops.Broadcast(zero, shape.dimensions())
 translations[ad_util.zeros_like_p] = zeros_like_translation_rule

@@ -1018,7 +1018,7 @@ class DeviceArray(DeviceValue):

   def copy(self):
     """Returns an ndarray (backed by host memory, not device memory)."""
-    return onp.asarray(self)
+    return np.asarray(self)

   def copy_to_host_async(self):
     """Requests a copy of the buffer to the host."""
@@ -1042,10 +1042,10 @@ class DeviceArray(DeviceValue):
     self._npy_value = None

   def __repr__(self):
-    line_width = onp.get_printoptions()['linewidth']
+    line_width = np.get_printoptions()['linewidth']
     prefix = '{}('.format(self.__class__.__name__)
-    s = onp.array2string(self._value, prefix=prefix, suffix=',',
-                         separator=', ', max_line_width=line_width)
+    s = np.array2string(self._value, prefix=prefix, suffix=',',
+                        separator=', ', max_line_width=line_width)
     dtype_str = 'dtype={})'.format(self.dtype.name)
     last_line_len = len(s) - s.rfind('\n') + 1
     sep = ' '
@@ -1054,13 +1054,13 @@ class DeviceArray(DeviceValue):
     return "{}{},{}{}".format(prefix, s, sep, dtype_str)

   def item(self):
-    if dtypes.issubdtype(self.dtype, onp.complexfloating):
+    if dtypes.issubdtype(self.dtype, np.complexfloating):
       return complex(self)
-    elif dtypes.issubdtype(self.dtype, onp.floating):
+    elif dtypes.issubdtype(self.dtype, np.floating):
       return float(self)
-    elif dtypes.issubdtype(self.dtype, onp.integer):
+    elif dtypes.issubdtype(self.dtype, np.integer):
       return int(self)
-    elif dtypes.issubdtype(self.dtype, onp.bool_):
+    elif dtypes.issubdtype(self.dtype, np.bool_):
       return bool(self)
     else:
       raise TypeError(self.dtype)
@@ -1091,7 +1091,7 @@ class DeviceArray(DeviceValue):
     return format(self._value, format_spec)

   def __array__(self, dtype=None, context=None):
-    return onp.asarray(self._value, dtype=dtype)
+    return np.asarray(self._value, dtype=dtype)

   @property
   def __cuda_array_interface__(self):
@@ -1251,10 +1251,10 @@ def _remat_translation_rule(c, axis_env, in_nodes,
   Conditional."""
   del device, concrete  # Unused.
   # Fake condition which always selects True branch.
-  rng = xops.RngUniform(xb.constant(c, onp.array(0, dtype=onp.float32)),
-                        xb.constant(c, onp.array(1, dtype=onp.float32)),
+  rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
+                        xb.constant(c, np.array(1, dtype=np.float32)),
                         xc.Shape.array_shape(xc.PrimitiveType.F32, []))
-  pred = xops.Lt(rng, xb.constant(c, onp.array(2, dtype=onp.float32)))
+  pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))

   true_op = xops.Tuple(c, in_nodes)
   remat_subc = xb.make_computation_builder("remat_call_subcomputation")
@@ -1272,7 +1272,7 @@ def _remat_translation_rule(c, axis_env, in_nodes,

   def zeros(xla_shape):
     shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
-    zero = xb.constant(dummy_subc, onp.array(0, dtype=dtype))
+    zero = xb.constant(dummy_subc, np.array(0, dtype=dtype))
     return xops.Broadcast(zero, shape)
   out_nodes = [zeros(s) for s in out_node_shapes]
   dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))
jax/lax/lax.py (626 changed lines): file diff suppressed because it is too large
@@ -24,7 +24,7 @@ import itertools
 import operator
 from typing import Callable, Sequence

-import numpy as onp
+import numpy as np

 import jax
 from jax import core
@@ -273,7 +273,7 @@ def while_loop(cond_fun, body_fun, init_val):
   if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
     msg = "cond_fun must return a boolean scalar, but got pytree {}."
     raise TypeError(msg.format(cond_tree))
-  if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), onp.bool_):
+  if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
     msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
     raise TypeError(msg.format(cond_jaxpr.out_avals))

@@ -313,9 +313,9 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
                                cond_jaxpr.literals),
                           extend_name_stack(name_stack, 'cond'), *(x + z))
   if batched:
-    scalar = ShapedArray((), onp.bool_)
+    scalar = ShapedArray((), np.bool_)
     or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
-    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_,
+    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
                        list(range(cond_jaxpr.out_avals[0].ndim)))

   body_c = xb.make_computation_builder("body_computation")
@@ -560,10 +560,10 @@ def switch(index, branches: Sequence[Callable], operand):
     branches: Sequence of functions (A -> B) to be applied based on `index`.
     operand: Operand (A) input to whichever branch is applied.
   """
-  if len(onp.shape(index)) != 0:
+  if len(np.shape(index)) != 0:
     raise TypeError(
         f"Branch index must be scalar, "
-        f"got {index} of shape {onp.shape(index)}.")
+        f"got {index} of shape {np.shape(index)}.")

   try:
     index_dtype = dtypes.result_type(index)
@@ -582,9 +582,9 @@ def switch(index, branches: Sequence[Callable], operand):
   elif len(branches) == 1:
     return branches[0](operand)

-  index = lax.convert_element_type(index, onp.int32)
-  lo = onp.array(0, onp.int32)
-  hi = onp.array(len(branches) - 1, onp.int32)
+  index = lax.convert_element_type(index, np.int32)
+  lo = np.array(0, np.int32)
+  hi = np.array(len(branches) - 1, np.int32)
   index = lax.clamp(lo, index, hi)

   if (jax.api._jit_is_disabled() and
@@ -651,9 +651,9 @@ def cond(*args, **kwargs):
   return _cond(*args, **kwargs)

 def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
-  if len(onp.shape(pred)) != 0:
+  if len(np.shape(pred)) != 0:
     raise TypeError(
-        f"Pred must be a scalar, got {pred} of shape {onp.shape(pred)}.")
+        f"Pred must be a scalar, got {pred} of shape {np.shape(pred)}.")

   try:
     pred_dtype = dtypes.result_type(pred)
@@ -686,7 +686,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
       out_tree, true_jaxpr.out_avals,
       false_out_tree, false_jaxpr.out_avals)

-  index = lax.convert_element_type(pred, onp.int32)
+  index = lax.convert_element_type(pred, np.int32)

   linear = (False,) * (len(consts) + len(ops))
   out = cond_p.bind(
@@ -742,7 +742,7 @@ def _select_tree(indices, branch_vals):
   if len(branch_vals) == 1:
     return branch_vals[0]
   mid = len(branch_vals) // 2
-  mid = onp.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
+  mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
   return lax.select(lax.lt(indices, mid),
                     _select_tree(indices, branch_vals[:mid]),
                     _select_tree(indices - mid, branch_vals[mid:]))
@@ -752,7 +752,7 @@ def _cond_index_bcast_and_select_tree(indices, branch_vals):
     return branch_vals[0]
   else:
     bcast_indices = lax.broadcast_in_dim(
-        indices, onp.shape(branch_vals[0]), list(range(onp.ndim(indices))))
+        indices, np.shape(branch_vals[0]), list(range(np.ndim(indices))))
   return _select_tree(bcast_indices, branch_vals)

 def _cond_batching_rule(args, dims, branches, linear):
@@ -1066,7 +1066,7 @@ def _cond_typecheck(*avals, branches, linear):

   index_aval, *op_avals = avals
   core.typecheck_assert(
-      index_aval.dtype == onp.int32,
+      index_aval.dtype == np.int32,
       f'cond called with index of type {index_aval.dtype} instead of int32')
   core.typecheck_assert(
       all(_map(core.typecompat, jaxpr0.in_avals, op_avals)),
@@ -1859,8 +1859,8 @@ def _flatten(args):


 def _check_shapes(func_name, expected_name, actual, expected, tree):
-  actual_shapes = _map(onp.shape, actual)
-  expected_shapes = _map(onp.shape, expected)
+  actual_shapes = _map(np.shape, actual)
+  expected_shapes = _map(np.shape, expected)
   if actual_shapes != expected_shapes:
     raise ValueError('{}() output shapes must match {}, got {} and {}'
                      .format(func_name, expected_name,
@@ -2102,18 +2102,18 @@ batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
 def _interleave(a, b):
   """Given two Tensors of static shape, interleave them along the first axis."""
   # TODO(mattjj)
-  import jax.numpy as np
+  import jax.numpy as jnp
   # [a b c ...] [d e f ...] -> [a d b e c f ...]
   half_num_elems = b.shape[0]

   if a.shape[0] > b.shape[0]:
-    return np.concatenate(
-        [np.reshape(np.stack([a[: -1], b], axis=1),
-                    (2 * half_num_elems,) + a.shape[1:]),
+    return jnp.concatenate(
+        [jnp.reshape(jnp.stack([a[: -1], b], axis=1),
+                     (2 * half_num_elems,) + a.shape[1:]),
         a[-1:]], axis=0)
   else:
-    return np.reshape(np.stack([a, b], axis=1),
-                      (2 * half_num_elems,) + a.shape[1:])
+    return jnp.reshape(jnp.stack([a, b], axis=1),
+                       (2 * half_num_elems,) + a.shape[1:])

 def associative_scan(fn, elems):
   """Perform a scan with an associative binary operation, in parallel.
@@ -15,7 +15,7 @@

 from functools import partial

-import numpy as onp
+import numpy as np

 from jax.abstract_arrays import ShapedArray
 from jax.api import jit, vjp
@@ -36,24 +36,24 @@ __all__ = [
 ]

 def _promote_to_complex(arg):
-  dtype = dtypes.result_type(arg, onp.complex64)
+  dtype = dtypes.result_type(arg, np.complex64)
   # XLA's FFT op only supports C64 in jaxlib versions 0.1.47 and earlier.
   # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
-  if lib.version <= (0, 1, 47) and dtype == onp.complex128:
-    dtype = onp.complex64
+  if lib.version <= (0, 1, 47) and dtype == np.complex128:
+    dtype = np.complex64
   return lax.convert_element_type(arg, dtype)

 def _promote_to_real(arg):
-  dtype = dtypes.result_type(arg, onp.float32)
+  dtype = dtypes.result_type(arg, np.float32)
   # XLA's FFT op only supports F32.
   # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
-  if lib.version <= (0, 1, 47) and dtype == onp.float64:
-    dtype = onp.float32
+  if lib.version <= (0, 1, 47) and dtype == np.float64:
+    dtype = np.float32
   return lax.convert_element_type(arg, dtype)

 def fft(x, fft_type, fft_lengths):
   if fft_type == xla_client.FftType.RFFT:
-    if onp.iscomplexobj(x):
+    if np.iscomplexobj(x):
       raise ValueError("only real valued inputs supported for rfft")
     x = _promote_to_real(x)
   else:
@@ -67,8 +67,8 @@ def fft(x, fft_type, fft_lengths):
 def fft_impl(x, fft_type, fft_lengths):
   return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)

-_complex_dtype = lambda dtype: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
-_real_dtype = lambda dtype: onp.zeros((), dtype).real.dtype
+_complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
+_real_dtype = lambda dtype: np.zeros((), dtype).real.dtype
 _is_even = lambda x: x % 2 == 0

 def fft_abstract_eval(x, fft_type, fft_lengths):
@@ -17,7 +17,7 @@ Parallelization primitives.

 import collections

-import numpy as onp
+import numpy as np

 from jax import core
 from jax import ad_util
@@ -72,8 +72,8 @@ def psum(x, axis_name, *, axis_index_groups=None):
   """
   _validate_axis_index_groups(axis_index_groups)
   leaves, treedef = tree_util.tree_flatten(x)
-  leaves = [lax.convert_element_type(l, onp.int32)
-            if dtypes.dtype(l) == onp.bool_ else l for l in leaves]
+  leaves = [lax.convert_element_type(l, np.int32)
+            if dtypes.dtype(l) == np.bool_ else l for l in leaves]
   out_flat = psum_p.bind(*leaves, axis_name=axis_name,
                          axis_index_groups=axis_index_groups)
   return tree_util.tree_unflatten(treedef, out_flat)
@@ -327,7 +327,7 @@ def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
   out = [None] * len(args)
   replica_groups_protos = xc.make_replica_groups(replica_groups)
   for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
-    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
+    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
     n = len(dtype_args)
     if is_complex:
       dtype_args = ([xops.Real(x) for x in dtype_args] +
@@ -355,7 +355,7 @@ def _notuple_psum_translation_rule(c, *args, replica_groups):
   psum = partial(_allreduce_translation_rule, lax.add_p, c,
                  replica_groups=replica_groups)
   dtype = c.get_shape(val).numpy_dtype()
-  if dtypes.issubdtype(dtype, onp.complexfloating):
+  if dtypes.issubdtype(dtype, np.complexfloating):
     return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
   else:
     return psum(val)
@@ -507,14 +507,14 @@ def _broadcasting_papply(prim, name, size, vals, axes, **params):
   if xdim is None:
     if x.shape:
       if x.shape[ydim] == 1:
-        x = x.reshape(onp.delete(x.shape, ydim))
+        x = x.reshape(np.delete(x.shape, ydim))
       else:
         x = _drop(x, ydim, name)
     return prim.bind(x, y, **params), ydim
   elif ydim is None:
     if y.shape:
       if y.shape[xdim] == 1:
-        y = y.reshape(onp.delete(y.shape, xdim))
+        y = y.reshape(np.delete(y.shape, xdim))
       else:
         y = _drop(y, xdim, name)
     return prim.bind(x, y, **params), xdim
@@ -525,11 +525,11 @@ def _broadcasting_papply(prim, name, size, vals, axes, **params):
     y_tosplit = xdim - int(ydim <= xdim)
     if y.shape[y_tosplit] == 1:
       y = _allgather(y, ydim, size, name)
-      y = y.reshape(onp.delete(y.shape, xdim))
+      y = y.reshape(np.delete(y.shape, xdim))
       return prim.bind(x, y, **params), ydim
     elif x.shape[x_tosplit] == 1:
       x = _allgather(x, xdim, size, name)
-      x = x.reshape(onp.delete(x.shape, ydim))
+      x = x.reshape(np.delete(x.shape, ydim))
       return prim.bind(x, y, **params), ydim
     else:
       x = all_to_all(x, name, x_tosplit, xdim)
@@ -565,7 +565,7 @@ def _reducer_papply(prim, collective, name, size, vals, papply_axes, axes, **kwa
   if not axes or papply_axis in axes:
     return collective(result, axis_name=name), None
   else:
-    new_papply_axis = papply_axis - onp.sum(onp.less(other_axes, papply_axis))
+    new_papply_axis = papply_axis - np.sum(np.less(other_axes, papply_axis))
     return result, new_papply_axis

 def _defreducer(prim, collective_prim):
@@ -754,13 +754,13 @@ def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers,
 def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions):
   operand, = vals
   axis, = axes
-  old_sizes = tuple(onp.insert(operand.shape, axis, size))
+  old_sizes = tuple(np.insert(operand.shape, axis, size))

   def filter_ones(xs):
     return filter(lambda x: x != 1, xs)

   def find_new_axis(old_axis, old_sizes, new_sizes):
-    left = onp.prod(old_sizes[:old_axis])
+    left = np.prod(old_sizes[:old_axis])
     size = old_sizes[old_axis]
     prod = 1
     for i, cur_sz in enumerate(new_sizes):
@@ -829,7 +829,7 @@ def _conv_general_dilated_papply_rule(
   lhs_dim, rhs_dim = dims
   lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
   if rhs_dim is None and lhs_dim == lhs_spec_batch_dim:
-    lhs = lax.reshape(lhs, tuple(onp.insert(lhs.shape, lhs_dim, 1)))
+    lhs = lax.reshape(lhs, tuple(np.insert(lhs.shape, lhs_dim, 1)))
     out = lax.conv_general_dilated(
         lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
         dimension_numbers, feature_group_count, precision)
@@ -848,8 +848,8 @@ def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
     raise ValueError(
         "broadcast_in_dim changes hidden dimension size: {} to {}".format(
             shape[dim], shape[out_dim]))
-  sub_bdims = tuple(onp.delete(broadcast_dimensions, dim))
-  sub_shape = tuple(onp.delete(shape, out_dim))
+  sub_bdims = tuple(np.delete(broadcast_dimensions, dim))
+  sub_shape = tuple(np.delete(shape, out_dim))
   return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim


@@ -906,8 +906,8 @@ def _gather_papply_rule(
         start_index_map=dimension_numbers.start_index_map)
     out = lax.gather(operand, start_indices, dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
-    out_dim = start_indices_dim + onp.sum(
-        onp.less_equal(offset_dims, start_indices_dim))
+    out_dim = start_indices_dim + np.sum(
+        np.less_equal(offset_dims, start_indices_dim))
     return out, out_dim
   else:
     raise NotImplementedError
@@ -30,7 +30,7 @@ from absl import logging
 from ..config import flags
 from .. import util
 from .. import dtypes
-import numpy as onp  # 'onp' rather than 'np' to distinguish from autograd.numpy
+import numpy as np
 import threading

 try:
@@ -86,7 +86,7 @@ def get_compile_options(num_replicas, num_partitions, device_assignment=None,
         2,
         'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
         num_replicas, num_partitions, device_assignment)
-    device_assignment = onp.array(device_assignment)
+    device_assignment = np.array(device_assignment)

     # Allow 1D device assignment if num_partitions is 1.
     if (device_assignment.ndim == 1) and (num_partitions == 1):
@@ -288,9 +288,8 @@ def supported_numpy_dtypes():
 # TODO(mattjj,frostig): try to remove this function
 def normalize_to_xla_dtypes(val):
   """Normalize dtypes in a value."""
-  if hasattr(val, '__array__') or onp.isscalar(val):
-    return onp.asarray(val,
-                       dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
+  if hasattr(val, '__array__') or np.isscalar(val):
+    return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
   elif isinstance(val, (tuple, list)):
     return tuple(normalize_to_xla_dtypes(x) for x in val)
   raise TypeError('Can\'t convert to XLA: {}'.format(val))
@@ -361,7 +360,7 @@ def _sharding_to_proto(sharding: SpatialSharding):
   else:
     proto.type = xla_client.OpSharding.Type.OTHER
     proto.tile_assignment_dimensions = list(sharding)
-    proto.tile_assignment_devices = list(range(onp.product(sharding)))
+    proto.tile_assignment_devices = list(range(np.product(sharding)))
   return proto

 def set_sharding(builder, op, sharding: SpatialSharding):
@@ -395,7 +394,7 @@ def _ndarray_constant_handler(c, val, canonicalize_types=True):
   special handling of arrays with any strides of size zero: for those, it
   generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
   to avoid staging in large literals that might arise from np.zeros or np.ones
-  or the output of lax.broadcast (which uses onp.broadcast_to which in turn
+  or the output of lax.broadcast (which uses np.broadcast_to which in turn
   uses size-zero strides).

   Args:
@@ -407,28 +406,28 @@ def _ndarray_constant_handler(c, val, canonicalize_types=True):
     staged into the XLA Computation.
   """
   # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
-  if onp.any(onp.equal(0, val.strides)) and val.size > 0:
-    zero_stride_axes, = onp.where(onp.equal(0, val.strides))
-    other_axes, = onp.where(onp.not_equal(0, val.strides))
+  if np.any(np.equal(0, val.strides)) and val.size > 0:
+    zero_stride_axes, = np.where(np.equal(0, val.strides))
+    other_axes, = np.where(np.not_equal(0, val.strides))
     collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                               for ax in range(val.ndim))]
     xla_val = xops.Broadcast(
         _numpy_array_constant(c, collapsed_val, canonicalize_types),
-        onp.take(val.shape, zero_stride_axes))
-    permutation = onp.argsort(tuple(zero_stride_axes) + tuple(other_axes))
+        np.take(val.shape, zero_stride_axes))
+    permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
     return xops.Transpose(xla_val, permutation)
   else:
     return _numpy_array_constant(c, val, canonicalize_types)
-register_constant_handler(onp.ndarray, _ndarray_constant_handler)
+register_constant_handler(np.ndarray, _ndarray_constant_handler)


 def _scalar_constant_handler(c, val, canonicalize_types=True):
   return _numpy_array_constant(c, val, canonicalize_types)

-for scalar_type in [onp.int8, onp.int16, onp.int32, onp.int64,
-                    onp.uint8, onp.uint16, onp.uint32, onp.uint64,
-                    onp.float16, onp.float32, onp.float64, onp.float128,
-                    onp.bool_, onp.longlong]:
+for scalar_type in [np.int8, np.int16, np.int32, np.int64,
+                    np.uint8, np.uint16, np.uint32, np.uint64,
+                    np.float16, np.float32, np.float64, np.float128,
+                    np.bool_, np.longlong]:
   register_constant_handler(scalar_type, _scalar_constant_handler)

 def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
@@ -17,7 +17,7 @@ import functools
 import itertools as it
 import types

-import numpy as onp
+import numpy as np


 def safe_zip(*args):
@@ -233,7 +233,7 @@ def get_module_functions(module):
       continue
     attr = getattr(module, key)
     if isinstance(
-        attr, (types.BuiltinFunctionType, types.FunctionType, onp.ufunc)):
+        attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)):
       module_fns[key] = attr
   return module_fns