Cleanup: convert uses of import numpy as onp in library code (#3754)

Jake Vanderplas, 2020-07-14 13:05:31 -07:00 (committed by GitHub)
parent 512ed18d5a
commit a7c2cdea64
15 changed files with 527 additions and 528 deletions
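
The change is mechanical throughout: library code previously imported vanilla NumPy as onp to keep it visually distinct from jax.numpy as np; after this commit, plain NumPy is np in library code, and the docstring examples import jax.numpy as jnp instead. A minimal sketch of the new convention:

    import numpy as np        # vanilla NumPy: concrete, host-side, trace-time values
    import jax.numpy as jnp   # JAX's NumPy: traceable, device-backed arrays

    shape = np.shape([1., 2., 3.])   # static metadata computed with plain NumPy
    x = jnp.arange(3.)               # an array JAX transformations can trace
    print(shape, jnp.sin(x))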


@ -20,7 +20,7 @@ import time
from typing import Any, Optional, Union, Callable, List, Dict
from absl import flags
import numpy as onp
import numpy as np
from tabulate import tabulate
from jax.util import safe_zip
@ -59,7 +59,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
if iters is None:
warmup = 1
else:
warmup = onp.clip(1, iters // 10, 10)
warmup = np.clip(1, iters // 10, 10)
for _ in range(warmup):
f()
@ -73,7 +73,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
times.append(end - start)
count += 1
times_arr = onp.array(times)
times_arr = np.array(times)
print("---------Benchmark results for %s---------" % (name or f.__name__))
print("mean=%f std=%f %%std=%f total=%f" %
(times_arr.mean(), times_arr.std(), _pstd(times_arr), times_arr.sum()))
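
For reference, the summary line above is plain-NumPy statistics over the collected wall times; a small sketch, assuming _pstd (not shown in this hunk) is the standard deviation as a percentage of the mean:

    import numpy as np

    times_arr = np.array([0.012, 0.011, 0.013])        # hypothetical timings
    pstd = 100. * times_arr.std() / times_arr.mean()   # assumed _pstd definition
    print("mean=%f std=%f %%std=%f total=%f" %
          (times_arr.mean(), times_arr.std(), pstd, times_arr.sum()))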


@ -32,7 +32,7 @@ import threading
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
from warnings import warn
import numpy as onp
import numpy as np
from contextlib import contextmanager
from . import core
@ -204,10 +204,10 @@ def disable_jit():
debugging, and avoid the tracer too, we can use the :py:func:`disable_jit`
context manager:
>>> import jax.numpy as np
>>> import jax
>>>
>>> with jax.disable_jit():
... print(f(np.array([1, 2, 3])))
... print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
@ -339,7 +339,7 @@ def xla_computation(fun: Callable,
return xla.AxisEnv(nreps, names, sizes)
def abstractify(x):
return ShapedArray(onp.shape(x), dtypes.result_type(x))
return ShapedArray(np.shape(x), dtypes.result_type(x))
@wraps(fun)
def computation_maker(*args, **kwargs):
@ -474,7 +474,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
_check_scalar(ans)
dtype = dtypes.result_type(ans)
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
g = vjp_py(onp.ones((), dtype=dtype))
g = vjp_py(np.ones((), dtype=dtype))
g = g[0] if isinstance(argnums, int) else g
if not has_aux:
return ans, g
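
The np.ones((), dtype=dtype) above is the cotangent seed: for a scalar output, grad pulls back a scalar one through the VJP. The same mechanics via the public API:

    import jax
    import jax.numpy as jnp
    import numpy as np

    f = lambda x: jnp.sum(x ** 2)
    ans, vjp_fun = jax.vjp(f, jnp.arange(3.))
    (g,) = vjp_fun(np.ones((), dtype=ans.dtype))   # seed with a one of the output dtype
    print(g)                                       # [0. 2. 4.]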
@ -500,12 +500,12 @@ def _check_input_dtype_revderiv(name, holomorphic, x):
_check_arg(x)
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
msg = (f"{name} with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
elif not (dtypes.issubdtype(aval.dtype, onp.floating) or
dtypes.issubdtype(aval.dtype, onp.complexfloating)):
elif not (dtypes.issubdtype(aval.dtype, np.floating) or
dtypes.issubdtype(aval.dtype, np.complexfloating)):
msg = (f"{name} requires real- or complex-valued inputs (input dtype that "
"is a sub-dtype of np.floating or np.complexfloating), "
f"but got {aval.dtype.name}. ")
@ -515,11 +515,11 @@ _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
msg = (f"{name} with holomorphic=True requires outputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
elif not dtypes.issubdtype(aval.dtype, onp.floating):
elif not dtypes.issubdtype(aval.dtype, np.floating):
msg = (f"{name} requires real-valued outputs (output dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For holomorphic differentiation, pass holomorphic=True. "
@ -545,13 +545,13 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
``fun`` using forward-mode automatic differentiation.
>>> import jax
>>> import jax.numpy as np
>>> import jax.numpy as jnp
>>>
>>> def f(x):
... return jax.numpy.asarray(
... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
... return jnp.asarray(
... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(np.array([1., 2., 3.])))
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1. 0. 0. ]
[ 0. 0. 5. ]
[ 0. 16. -2. ]
@ -575,12 +575,12 @@ def _check_input_dtype_jacfwd(holomorphic, x):
_check_arg(x)
aval = core.get_aval(x)
if holomorphic:
if not (dtypes.issubdtype(aval.dtype, onp.complexfloating) and
not dtypes.issubdtype(aval.dtype, onp.floating)):
if not (dtypes.issubdtype(aval.dtype, np.complexfloating) and
not dtypes.issubdtype(aval.dtype, np.floating)):
msg = ("jacfwd with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
elif not dtypes.issubdtype(aval.dtype, onp.floating):
elif not dtypes.issubdtype(aval.dtype, np.floating):
msg = ("jacfwd requires real-valued inputs (input dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For holomorphic differentiation, pass holomorphic=True. "
@ -591,7 +591,7 @@ def _check_input_dtype_jacfwd(holomorphic, x):
def _check_output_dtype_jacfwd(holomorphic, x):
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
msg = ("jacfwd with holomorphic=True requires outputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
@ -613,13 +613,13 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
``fun`` using reverse-mode automatic differentiation.
>>> import jax
>>> import jax.numpy as np
>>> import jax.numpy as jnp
>>>
>>> def f(x):
... return jax.numpy.asarray(
... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
... return jnp.asarray(
... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(np.array([1., 2., 3.])))
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1. 0. 0. ]
[ 0. 0. 5. ]
[ 0. 16. -2. ]
@ -711,23 +711,23 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def _std_basis(pytree):
leaves, _ = tree_flatten(pytree)
ndim = sum(map(onp.size, leaves))
ndim = sum(map(np.size, leaves))
# TODO(mattjj): use a symbolic identity matrix here
dtype = dtypes.result_type(*leaves)
flat_basis = onp.eye(ndim, dtype=dtype)
flat_basis = np.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, flat_basis)
def _unravel_array_into_pytree(pytree, axis, arr):
leaves, treedef = tree_flatten(pytree)
axis = axis % arr.ndim
shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
reshaped_parts = [onp.reshape(x, shape) for x, shape in zip(parts, shapes)]
shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis)
reshaped_parts = [np.reshape(x, shape) for x, shape in zip(parts, shapes)]
return tree_unflatten(treedef, reshaped_parts)
def _split(x, indices, axis):
if isinstance(x, onp.ndarray):
return onp.split(x, indices, axis)
if isinstance(x, np.ndarray):
return np.split(x, indices, axis)
else:
return x.split(indices, axis)
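
A rough picture of what _std_basis and _unravel_array_into_pytree accomplish for a single array input (a hypothetical reconstruction, not the actual implementation):

    import jax
    import jax.numpy as jnp
    import numpy as np

    x = jnp.arange(3.)
    basis = np.eye(x.size, dtype=x.dtype)             # rows are one-hot tangents
    push = lambda v: jax.jvp(jnp.sin, (x,), (v,))[1]  # forward-mode push of one tangent
    jac = jax.vmap(push)(basis)                       # stacked JVPs form the Jacobian
    print(np.allclose(jac, jax.jacfwd(jnp.sin)(x)))   # True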
@ -771,9 +771,9 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
For example, we can implement a matrix-matrix product using a vector dot
product:
>>> import jax.numpy as np
>>> import jax.numpy as jnp
>>>
>>> vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> []
>>> vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
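
A usage sketch for the vv/mv/mm definitions above; the shapes follow the axis comments:

    import jax.numpy as jnp
    from jax import vmap

    vv = lambda x, y: jnp.vdot(x, y)       # ([a], [a]) -> []
    mv = vmap(vv, (0, None), 0)            # ([b,a], [a]) -> [b]
    mm = vmap(mv, (None, 1), 1)            # ([b,a], [a,c]) -> [b,c]

    A, B = jnp.ones((4, 3)), jnp.ones((3, 5))
    print(mm(A, B).shape)                  # (4, 5), same as (A @ B).shape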
@ -788,21 +788,21 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
axes of the container elements to map over:
>>> A, B, C, D = 2, 3, 4, 5
>>> x = np.ones((A, B))
>>> y = np.ones((B, C))
>>> z = np.ones((C, D))
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
... x, (y, z) = tree_arg
... return np.dot(x, np.dot(y, z))
... return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6 # batch size
>>> x = np.ones((K, A, B)) # batch axis in different locations
>>> y = np.ones((B, K, C))
>>> z = np.ones((C, D, K))
>>> x = jnp.ones((K, A, B)) # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
@ -811,7 +811,7 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
Here's another example using container types in ``in_axes``, this time a
dictionary, to specify the elements of the container to map over:
>>> dct = {'a': 0., 'b': np.arange(5.)}
>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
... return dct['a'] + dct['b'] + x
@ -824,13 +824,13 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
element mapped and the second unmapped. Only for unmapped results
can we specify ``out_axes`` to be ``None`` (to keep it unmapped).
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(np.arange(2.), 4.))
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), 8.0)
If the ``out_axes`` is specified for an unmapped result, the result is broadcast
across the mapped axis:
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(np.arange(2.), 4.))
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32))
If the ``out_axes`` is specified for a mapped result, the result is
@ -884,7 +884,7 @@ def _get_axis_size(name: str, i:int, shape: Tuple[int, ...], axis: int):
f"but axis to be mapped {axis}") from e
def _mapped_axis_size(tree, vals, dims, name):
mapped_axis_sizes = {_get_axis_size(name, i, onp.shape(x), d)
mapped_axis_sizes = {_get_axis_size(name, i, np.shape(x), d)
for i, (x, d) in enumerate(zip(vals, dims))
if d is not None}
try:
@ -1005,18 +1005,18 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
For example, assuming 8 XLA devices are available, :py:func:`pmap` can be used as a
map along a leading array axis:
>>> import jax.numpy as np
>>> import jax.numpy as jnp
>>>
>>> out = pmap(lambda x: x ** 2)(np.arange(8)) # doctest: +SKIP
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8)) # doctest: +SKIP
>>> print(out) # doctest: +SKIP
[0, 1, 4, 9, 16, 25, 36, 49]
When the leading dimension is smaller than the number of available devices JAX
will simply run on a subset of devices:
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y) # doctest: +SKIP
>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y) # doctest: +SKIP
>>> print(out) # doctest: +SKIP
[[[ 4. 9.]
[ 12. 29.]]
@ -1028,14 +1028,14 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
If your leading dimension is larger than the number of available devices you
will get an error:
>>> pmap(lambda x: x ** 2)(np.arange(9)) # doctest: +SKIP
>>> pmap(lambda x: x ** 2)(jnp.arange(9)) # doctest: +SKIP
ValueError: ... requires 9 replicas, but only 8 XLA devices are available
As with :py:func:`vmap`, using ``None`` in ``in_axes`` indicates that an argument
doesn't have an extra axis and should be broadcasted, rather than mapped,
across the replicas:
>>> x, y = np.arange(2.), 4.
>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y) # doctest: +SKIP
>>> print(out) # doctest: +SKIP
([4., 5.], [8., 8.])
@ -1048,7 +1048,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
collective operations. For example:
>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(np.arange(4.)) # doctest: +SKIP
>>> out = pmap(f, axis_name='i')(jnp.arange(4.)) # doctest: +SKIP
>>> print(out) # doctest: +SKIP
[ 0. 0.16666667 0.33333334 0.5 ]
>>> print(out.sum()) # doctest: +SKIP
@ -1073,7 +1073,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
... doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
... return row_normed, col_normed, doubly_normed
>>>
>>> x = np.arange(8.).reshape((4, 2))
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x) # doctest: +SKIP
>>> print(row_normed.sum(0)) # doctest: +SKIP
[ 1. 1.]
@ -1087,7 +1087,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
runs on two hosts with 4 XLA devices each:
>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4,8)
>>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4,8)
>>> out = pmap(f, axis_name='i')(data) # doctest: +SKIP
>>> print(out) # doctest: +SKIP
[28 29 30 31] # on host 0
@ -1096,7 +1096,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
Each host passes in a different length-4 array, corresponding to its 4 local
devices, and the psum operates over all 8 values. Conceptually, the two
length-4 arrays can be thought of as a sharded length-8 array (in this example
equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis
equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
given name 'i'. The pmap call on each host then returns the corresponding
length-4 output shard.
@ -1114,9 +1114,9 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
... def f2(x):
... return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(np.arange(6.))) # doctest: +SKIP
>>> print(f1(jnp.arange(6.))) # doctest: +SKIP
[0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333]
>>> print(f2(np.array([2., 3.]))) # doctest: +SKIP
>>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP
[ 13. 13.]
"""
# axis_size is an optional integer representing the global axis size.
@ -1301,7 +1301,7 @@ def mask(fun: Callable, in_shapes, out_shape) -> Callable:
if in_tree != in_shapes_tree:
raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.")
logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
in_shapes = map(masking.finalize_spec, in_specs, map(onp.shape, args_flat))
in_shapes = map(masking.finalize_spec, in_specs, map(np.shape, args_flat))
padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat])
f = lu.wrap_init(fun)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree)
@ -1314,7 +1314,7 @@ def mask(fun: Callable, in_shapes, out_shape) -> Callable:
return tuple(dim if dim is masking._monomorphic_dim else
masking.eval_poly(dim, padded_env) for dim in shape_spec)
masking.check_shapes(map(padded_spec, out_specs), out_spec_tree,
map(onp.shape, outs), out_tree, "Padded output")
map(np.shape, outs), out_tree, "Padded output")
return tree_unflatten(out_tree, outs)
return wrapped_fun
@ -1326,7 +1326,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable):
out_specs, out_spec_tree = tree_flatten(out_shape)
out_specs = map(masking.parse_spec, out_specs)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
avals = map(partial(ShapedArray, dtype=np.float32), in_shapes)
out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
masking.check_shapes(map(tuple, out_specs), out_spec_tree,
map(tuple, out_shapes), out_tree_thunk())
@ -1439,9 +1439,9 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
Here's a more complete example of using :py:func:`linearize`:
>>> import jax
>>> import jax.numpy as np
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(DeviceArray(3.26819, dtype=float32), DeviceArray(-5.00753, dtype=float32))
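
The docstring continues past this excerpt; as a sketch of the linearize call it builds up to (the printed values match the jvp output shown above):

    import jax
    import jax.numpy as jnp

    f = lambda x: 3. * jnp.sin(x) + jnp.cos(x / 2.)
    y, f_jvp = jax.linearize(f, 2.)
    print(y)          # 3.26819, the primal from jax.jvp above
    print(f_jvp(3.))  # -5.00753; the linearization is reusable per tangent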
@ -1483,7 +1483,7 @@ def _lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pvals, *py_args):
def _check_inexact_input_vjp(x):
aval = core.get_aval(x)
if not dtypes.issubdtype(aval.dtype, onp.inexact):
if not dtypes.issubdtype(aval.dtype, np.inexact):
msg = ("Primal inputs to reverse-mode differentiation must be of float "
"or complex type, got type {}")
raise TypeError(msg.format(aval.dtype.name))
@ -1703,9 +1703,9 @@ class ShapeDtypeStruct(object):
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = onp.dtype(dtype)
self.dtype = np.dtype(dtype)
size = property(lambda self: onp.prod(self.shape))
size = property(lambda self: np.prod(self.shape))
ndim = property(lambda self: len(self.shape))
def __len__(self):
@ -1773,16 +1773,16 @@ def eval_shape(fun: Callable, *args, **kwargs):
For example:
>>> import jax
>>> import jax.numpy as np
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: np.tanh(np.dot(A, x))
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> class MyArgArray(object):
... def __init__(self, shape, dtype):
... self.shape = shape
... self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), np.float32)
>>> x = MyArgArray((3000, 1000), np.float32)
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x) # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
@ -1790,7 +1790,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
float32
"""
def abstractify(x):
return ShapedArray(onp.shape(x), dtypes.result_type(x))
return ShapedArray(np.shape(x), dtypes.result_type(x))
args_flat, in_tree = tree_flatten((args, kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,


@ -26,7 +26,7 @@ from typing import (Any, Callable, ClassVar, Dict, Generator,
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
Type, Union, cast)
import numpy as onp
import numpy as np
from . import dtypes
from .config import FLAGS
@ -846,7 +846,7 @@ class UnshapedArray(AbstractValue):
array_abstraction_level = 2
def __init__(self, dtype, weak_type=False):
self.dtype = onp.dtype(dtypes.canonicalize_dtype(dtype))
self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
self.weak_type = weak_type
def __eq__(self, other):
@ -858,7 +858,7 @@ class UnshapedArray(AbstractValue):
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.dtype, self.weak_type))
@ -925,7 +925,7 @@ class ShapedArray(UnshapedArray):
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type))
@ -968,16 +968,16 @@ class ConcreteArray(ShapedArray):
array_abstraction_level = 0
def __init__(self, val, weak_type=False):
super(ConcreteArray, self).__init__(onp.shape(val), onp.result_type(val),
super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
weak_type=weak_type)
# Note: canonicalized self.dtype doesn't necessarily match self.val
self.val = val
assert self.dtype != onp.dtype('O')
assert self.dtype != np.dtype('O')
def __eq__(self, other):
return (type(self) is type(other) and self.dtype == other.dtype
and self.shape == other.shape and self.weak_type == other.weak_type
and onp.all(self.val == other.val))
and np.all(self.val == other.val))
def __hash__(self):
return id(self.val)


@ -21,7 +21,7 @@ not yet fine-tuned the performance of the resulting XLA compilation!
By default, loops and control-flow in JAX are executed and inlined during tracing.
For example, in the following code the `for` loop is unrolled during JAX tracing::
arr = onp.zeros(5)
arr = np.zeros(5)
for i in range(arr.shape[0]):
arr[i] += 2.
if i % 2 == 0:
@ -32,7 +32,7 @@ JAX operations, which require you to express the body of the loops and
conditionals as functions, and the array updates using a functional style that
returns an updated array, e.g.::
arr = onp.zeros(5)
arr = np.zeros(5)
def loop_body(i, acc_arr):
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
return lax.cond(i % 2 == 0,
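
The excerpt cuts off mid-call; a hedged completion, assuming the era's five-argument lax.cond(pred, true_operand, true_fun, false_operand, false_fun) form and jax.ops.index_update:

    import numpy as np
    from jax import lax, ops

    def loop_body(i, acc_arr):
        arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
        return lax.cond(i % 2 == 0,
                        arr1, lambda a: ops.index_update(a, i, a[i] + 1.),
                        arr1, lambda a: a)

    arr = lax.fori_loop(0, 5, loop_body, np.zeros(5))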


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as onp
import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union
import jax
@ -103,7 +103,7 @@ class BatchTracer(Tracer):
return aval
elif type(aval) is ShapedArray:
assert 0 <= self.batch_dim < aval.ndim
new_shape = tuple(onp.delete(aval.shape, self.batch_dim))
new_shape = tuple(np.delete(aval.shape, self.batch_dim))
return ShapedArray(new_shape, aval.dtype)
else:
raise TypeError(aval)
@ -236,7 +236,7 @@ def broadcast_batcher(prim, args, dims, **params):
either an int indicating the batch dimension, or else `not_mapped`
indicating no batching.
"""
shapes = {(x.shape, d) for x, d in zip(args, dims) if onp.ndim(x)}
shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
if len(shapes) == 1:
# if there's only agreeing batch dims and scalars, just call the primitive
d = next(d for d in dims if d is not not_mapped)
@ -245,16 +245,16 @@ def broadcast_batcher(prim, args, dims, **params):
else:
size, = {shape[d] for shape, d in shapes if d is not not_mapped}
args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
ndim = max(onp.ndim(x) for x in args) # special-case scalar broadcasting
ndim = max(np.ndim(x) for x in args) # special-case scalar broadcasting
args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
out = prim.bind(*args, **params)
return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
def _handle_scalar_broadcasting(nd, x, d):
if d is not_mapped or nd == onp.ndim(x):
if d is not_mapped or nd == np.ndim(x):
return x
else:
return x.reshape(x.shape + (1,) * (nd - onp.ndim(x)))
return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))
def defreducer(prim):
primitive_batchers[prim] = partial(reducer_batcher, prim)
@ -262,8 +262,8 @@ def defreducer(prim):
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
operand, = batched_args
bdim, = batch_dims
axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1)))
bdim_out = int(list(onp.delete(onp.arange(operand.ndim), axes)).index(bdim))
axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
if 'input_shape' in params:
params = dict(params, input_shape=operand.shape)
return prim.bind(operand, axes=axes, **params), bdim_out
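
The trace-time NumPy here just renumbers reduction axes once a batch dimension is inserted at bdim; a worked example with hypothetical axes:

    import numpy as np

    axes, bdim = (1, 2), 0                 # reduce axes 1,2; batch dim moved to front
    axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
    print(axes)                            # (2, 3): shifted past the batch dim
    bdim_out = int(list(np.delete(np.arange(4), axes)).index(bdim))
    print(bdim_out)                        # 0: where the batch dim lands after reducing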
@ -303,10 +303,10 @@ def broadcast(x, sz, axis):
if core.get_aval(x) is core.abstract_unit:
return core.unit
if axis is last:
axis = onp.ndim(x)
shape = list(onp.shape(x))
axis = np.ndim(x)
shape = list(np.shape(x))
shape.insert(axis, sz)
broadcast_dims = tuple(onp.delete(onp.arange(len(shape)), axis))
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
def moveaxis(x, src, dst):
@ -315,7 +315,7 @@ def moveaxis(x, src, dst):
if src == dst:
return x
src, dst = src % x.ndim, dst % x.ndim
perm = [i for i in range(onp.ndim(x)) if i != src]
perm = [i for i in range(np.ndim(x)) if i != src]
perm.insert(dst, src)
return x.transpose(perm)


@ -20,7 +20,7 @@ import operator as op
import string
from typing import Callable, Dict, Sequence, Union
import numpy as onp
import numpy as np
from .. import abstract_arrays
from .. import core, dtypes
@ -317,7 +317,7 @@ def parse_spec(spec=''):
def _parse_dim(spec):
if '+' in spec:
return onp.sum(map(_parse_dim, spec.split('+')))
return np.sum(map(_parse_dim, spec.split('+')))
elif '*' in spec:
return prod(map(_parse_dim, spec.split('*')))
elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
@ -383,10 +383,10 @@ class MaskTracer(Tracer):
class MaskTrace(Trace):
def pure(self, val):
return MaskTracer(self, val, onp.shape(val))
return MaskTracer(self, val, np.shape(val))
def lift(self, val):
return MaskTracer(self, val, onp.shape(val))
return MaskTracer(self, val, np.shape(val))
def sublift(self, val):
return MaskTracer(self, val.val, val.polymorphic_shape)


@ -20,7 +20,7 @@ from typing import (Callable, Dict, NamedTuple, Optional, Sequence,
Set, Tuple, Type, Union, cast)
from weakref import ref
import numpy as onp
import numpy as np
from .. import core
from .. import linear_util as lu
@ -128,7 +128,7 @@ class JaxprTrace(Trace):
if const is None:
return tracer
else:
if type(const) in core.literalable_types and onp.shape(const) == ():
if type(const) in core.literalable_types and np.shape(const) == ():
return self.new_instantiated_literal(const)
else:
return self.new_instantiated_const(const)
@ -138,7 +138,7 @@ class JaxprTrace(Trace):
if const is None:
return tracer
else:
aval = raise_to_shaped(get_aval(const), onp.isscalar(const))
aval = raise_to_shaped(get_aval(const), np.isscalar(const))
return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))
def process_primitive(self, primitive, tracers, params):


@ -37,7 +37,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Type, Union)
from absl import logging
import numpy as onp
import numpy as np
from ..config import flags
from .. import core
@ -465,7 +465,7 @@ def _axis_index_bind(*, axis_name):
nreps = dynamic_axis_env.nreps
trace = frame.pmap_trace
out_aval = ShapedArray((), onp.int32)
out_aval = ShapedArray((), np.int32)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
dict(nreps=nreps, sizes=sizes,
@ -476,19 +476,19 @@ def _axis_index_bind(*, axis_name):
if not frame.soft_trace:
return out_tracer
else:
val_out = out_tracer * frame.soft_size + onp.arange(frame.soft_size)
val_out = out_tracer * frame.soft_size + np.arange(frame.soft_size)
return SplitAxisTracer(frame.soft_trace, axis_name, val_out)
def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
div = xb.constant(c, onp.array(nreps // prod(sizes), dtype=onp.uint32))
mod = xb.constant(c, onp.array(sizes[-1], dtype=onp.uint32))
div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(onp.int32))
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
axis_index_p = core.Primitive('axis_index')
axis_index_p.def_custom_bind(_axis_index_bind)
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), onp.int32))
lambda *args, **params: ShapedArray((), np.int32))
xla.translations[axis_index_p] = _axis_index_translation_rule
@ -587,7 +587,7 @@ class ShardedDeviceArray(xla.DeviceArray):
def _value(self):
if self._npy_value is None:
self.copy_to_host_async()
npy_value = onp.empty(self.aval.shape, self.aval.dtype)
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
self._npy_value = npy_value
@ -633,7 +633,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path
def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
return xb.constant(c, onp.asarray(val), canonicalize_types=canonicalize_types)
return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)
core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
@ -838,7 +838,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
# provided 1D list of devices).
device_assignment = tree_map(lambda d: d.id, devices)
# Convert to 2D in case it's 1D and we have > 1 partitions.
device_assignment = onp.array(device_assignment).reshape(
device_assignment = np.array(device_assignment).reshape(
(num_global_replicas, num_partitions))
compile_options = xb.get_compile_options(
num_replicas=num_global_replicas,
@ -933,7 +933,7 @@ def get_num_partitions(*partitions):
if len(partition_specs) == 0:
# Everything is specified as replicated (all Nones).
return None
num_partitions_set = set(onp.prod(spec) for spec in partition_specs)
num_partitions_set = set(np.prod(spec) for spec in partition_specs)
if len(num_partitions_set) > 1:
raise ValueError(
f"All partition specs must use the same number of total partitions, "
@ -1157,7 +1157,7 @@ def _xla_shard(c, aval, axis_env, x):
return x
elif isinstance(aval, ShapedArray):
dims = list(c.get_shape(x).dimensions())
zero = xb.constant(c, onp.zeros((), dtype=onp.uint32))
zero = xb.constant(c, np.zeros((), dtype=np.uint32))
idxs = [_unravel_index(c, axis_env)] + [zero] * (len(dims) - 1)
return xops.Reshape(xops.DynamicSlice(x, idxs, [1] + dims[1:]), dims[1:])
else:
@ -1169,16 +1169,16 @@ def _xla_unshard(c, aval, axis_env, x, backend):
return x
elif isinstance(aval, ShapedArray):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
convert_bool = (onp.issubdtype(aval.dtype, onp.bool_)
convert_bool = (np.issubdtype(aval.dtype, np.bool_)
and xb.get_backend(backend).platform in ('cpu', 'gpu'))
if convert_bool:
x = xops.ConvertElementType(x, xb.dtype_to_etype(onp.float32))
x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
xla_shape = c.get_shape(x)
dims = list(xla_shape.dimensions())
padded = xops.Broadcast(xb.constant(c, onp.array(0, xla_shape.numpy_dtype())),
padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
[axis_env.sizes[-1]] + dims)
zero = xb.constant(c, onp.zeros((), dtype=onp.uint32))
zero = xb.constant(c, np.zeros((), dtype=np.uint32))
idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
replica_groups_protos = xc.make_replica_groups(
@ -1187,15 +1187,15 @@ def _xla_unshard(c, aval, axis_env, x, backend):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = xops.Ne(out, xb.constant(c, onp.array(0, dtype=onp.float32)))
out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(onp.bool_))
nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
return out
else:
raise TypeError((aval, c.get_shape(x)))
def _unravel_index(c, axis_env):
div = xb.constant(c, onp.array(axis_env.nreps // prod(axis_env.sizes), onp.uint32))
mod = xb.constant(c, onp.array(axis_env.sizes[-1], onp.uint32))
div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
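
The div/mod constants implement a mixed-radix unravel of the replica id; the same arithmetic in plain NumPy, with hypothetical sizes:

    import numpy as np

    nreps, sizes = 8, (2, 4)               # hypothetical: 8 replicas, mapped axes 2 and 4
    replica_ids = np.arange(nreps, dtype=np.uint32)
    div = np.array(nreps // np.prod(sizes), np.uint32)
    mod = np.array(sizes[-1], np.uint32)
    print((replica_ids // div) % mod)      # [0 1 2 3 0 1 2 3]: index along the last axis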
@ -1278,7 +1278,7 @@ class SplitAxisTrace(core.Trace):
if primitive is axis_index_p:
dummy, = vals_in
hard_idx = primitive.bind(dummy, **params)
val_out = hard_idx * params['soft_size'] + onp.arange(params['soft_size'])
val_out = hard_idx * params['soft_size'] + np.arange(params['soft_size'])
return SplitAxisTracer(self, params['axis_name'], val_out)
elif all(axis_name is not_mapped for axis_name in names_in):
return primitive.bind(*vals_in, **params)


@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Tup
from warnings import warn
from absl import logging
import numpy as onp
import numpy as np
from ..config import flags, bool_env
from .. import core
@ -74,11 +74,11 @@ def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
# unit representation
def _make_unit(c): return xb.constant(c, onp.zeros((), dtype=onp.dtype('bool')))
def _make_abstract_unit(_): return xc.Shape.array_shape(onp.dtype('bool'), ())
def _make_unit(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
def _make_abstract_unit(_): return xc.Shape.array_shape(np.dtype('bool'), ())
def _device_put_unit(_, device):
backend = xb.get_device_backend(device)
return backend.buffer_from_pyval(onp.zeros((), dtype=onp.dtype('bool')),
return backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
device)
def _make_array_shape(a):
return xc.Shape.array_shape(a.dtype, a.shape)
@ -143,10 +143,10 @@ def canonicalize_dtype(x):
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
def _canonicalize_ndarray_dtype(x):
return onp.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
def _canonicalize_python_scalar_dtype(typ, x):
return onp.asarray(
return np.asarray(
x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))
canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
@ -342,8 +342,8 @@ def check_nans(prim, bufs):
def _check_nans(name, xla_shape, buf):
assert not xla_shape.is_tuple()
if dtypes.issubdtype(xla_shape.element_type(), onp.inexact):
if onp.any(onp.isnan(buf.to_py())):
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
if np.any(np.isnan(buf.to_py())):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
### compiling jaxprs
@ -477,10 +477,10 @@ def _axis_groups(nrep, mesh_spec, mesh_axes):
trailing_size, ragged = divmod(nrep, prod(mesh_spec))
assert not ragged
full_spec = list(mesh_spec) + [trailing_size]
iota = onp.arange(prod(full_spec)).reshape(full_spec)
groups = onp.reshape(
onp.moveaxis(iota, mesh_axes, onp.arange(len(mesh_axes))),
(prod(onp.take(full_spec, mesh_axes)), -1))
iota = np.arange(prod(full_spec)).reshape(full_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(prod(np.take(full_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))
def jaxpr_replicas(jaxpr):
@ -862,7 +862,7 @@ call_translations[xla_call_p] = _xla_call_translation_rule
def zeros_like_translation_rule(c, x):
shape = c.get_shape(x)
assert not shape.is_tuple()
zero = xb.constant(c, onp.array(0, shape.element_type()))
zero = xb.constant(c, np.array(0, shape.element_type()))
return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule
@ -1018,7 +1018,7 @@ class DeviceArray(DeviceValue):
def copy(self):
"""Returns an ndarray (backed by host memory, not device memory)."""
return onp.asarray(self)
return np.asarray(self)
def copy_to_host_async(self):
"""Requests a copy of the buffer to the host."""
@ -1042,10 +1042,10 @@ class DeviceArray(DeviceValue):
self._npy_value = None
def __repr__(self):
line_width = onp.get_printoptions()['linewidth']
line_width = np.get_printoptions()['linewidth']
prefix = '{}('.format(self.__class__.__name__)
s = onp.array2string(self._value, prefix=prefix, suffix=',',
separator=', ', max_line_width=line_width)
s = np.array2string(self._value, prefix=prefix, suffix=',',
separator=', ', max_line_width=line_width)
dtype_str = 'dtype={})'.format(self.dtype.name)
last_line_len = len(s) - s.rfind('\n') + 1
sep = ' '
@ -1054,13 +1054,13 @@ class DeviceArray(DeviceValue):
return "{}{},{}{}".format(prefix, s, sep, dtype_str)
def item(self):
if dtypes.issubdtype(self.dtype, onp.complexfloating):
if dtypes.issubdtype(self.dtype, np.complexfloating):
return complex(self)
elif dtypes.issubdtype(self.dtype, onp.floating):
elif dtypes.issubdtype(self.dtype, np.floating):
return float(self)
elif dtypes.issubdtype(self.dtype, onp.integer):
elif dtypes.issubdtype(self.dtype, np.integer):
return int(self)
elif dtypes.issubdtype(self.dtype, onp.bool_):
elif dtypes.issubdtype(self.dtype, np.bool_):
return bool(self)
else:
raise TypeError(self.dtype)
@ -1091,7 +1091,7 @@ class DeviceArray(DeviceValue):
return format(self._value, format_spec)
def __array__(self, dtype=None, context=None):
return onp.asarray(self._value, dtype=dtype)
return np.asarray(self._value, dtype=dtype)
@property
def __cuda_array_interface__(self):
@ -1251,10 +1251,10 @@ def _remat_translation_rule(c, axis_env, in_nodes,
Conditional."""
del device, concrete # Unused.
# Fake condition which always selects True branch.
rng = xops.RngUniform(xb.constant(c, onp.array(0, dtype=onp.float32)),
xb.constant(c, onp.array(1, dtype=onp.float32)),
rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
xb.constant(c, np.array(1, dtype=np.float32)),
xc.Shape.array_shape(xc.PrimitiveType.F32, []))
pred = xops.Lt(rng, xb.constant(c, onp.array(2, dtype=onp.float32)))
pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))
true_op = xops.Tuple(c, in_nodes)
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
@ -1272,7 +1272,7 @@ def _remat_translation_rule(c, axis_env, in_nodes,
def zeros(xla_shape):
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = xb.constant(dummy_subc, onp.array(0, dtype=dtype))
zero = xb.constant(dummy_subc, np.array(0, dtype=dtype))
return xops.Broadcast(zero, shape)
out_nodes = [zeros(s) for s in out_node_shapes]
dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

File diff suppressed because it is too large.


@ -24,7 +24,7 @@ import itertools
import operator
from typing import Callable, Sequence
import numpy as onp
import numpy as np
import jax
from jax import core
@ -273,7 +273,7 @@ def while_loop(cond_fun, body_fun, init_val):
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), onp.bool_):
if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
@ -313,9 +313,9 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
cond_jaxpr.literals),
extend_name_stack(name_stack, 'cond'), *(x + z))
if batched:
scalar = ShapedArray((), onp.bool_)
scalar = ShapedArray((), np.bool_)
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_,
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
list(range(cond_jaxpr.out_avals[0].ndim)))
body_c = xb.make_computation_builder("body_computation")
@ -560,10 +560,10 @@ def switch(index, branches: Sequence[Callable], operand):
branches: Sequence of functions (A -> B) to be applied based on `index`.
operand: Operand (A) input to whichever branch is applied.
"""
if len(onp.shape(index)) != 0:
if len(np.shape(index)) != 0:
raise TypeError(
f"Branch index must be scalar, "
f"got {index} of shape {onp.shape(index)}.")
f"got {index} of shape {np.shape(index)}.")
try:
index_dtype = dtypes.result_type(index)
@ -582,9 +582,9 @@ def switch(index, branches: Sequence[Callable], operand):
elif len(branches) == 1:
return branches[0](operand)
index = lax.convert_element_type(index, onp.int32)
lo = onp.array(0, onp.int32)
hi = onp.array(len(branches) - 1, onp.int32)
index = lax.convert_element_type(index, np.int32)
lo = np.array(0, np.int32)
hi = np.array(len(branches) - 1, np.int32)
index = lax.clamp(lo, index, hi)
if (jax.api._jit_is_disabled() and
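
The convert/clamp sequence above keeps out-of-range indices in bounds rather than erroring; a sketch using lax.switch as defined in this file:

    import jax.numpy as jnp
    from jax import lax

    branches = [lambda x: x - 1., lambda x: x, lambda x: x + 1.]
    print(lax.switch(5, branches, jnp.float32(7.)))    # 8.0: index 5 clamps to 2
    print(lax.switch(-3, branches, jnp.float32(7.)))   # 6.0: index -3 clamps to 0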
@ -651,9 +651,9 @@ def cond(*args, **kwargs):
return _cond(*args, **kwargs)
def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
if len(onp.shape(pred)) != 0:
if len(np.shape(pred)) != 0:
raise TypeError(
f"Pred must be a scalar, got {pred} of shape {onp.shape(pred)}.")
f"Pred must be a scalar, got {pred} of shape {np.shape(pred)}.")
try:
pred_dtype = dtypes.result_type(pred)
@ -686,7 +686,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
out_tree, true_jaxpr.out_avals,
false_out_tree, false_jaxpr.out_avals)
index = lax.convert_element_type(pred, onp.int32)
index = lax.convert_element_type(pred, np.int32)
linear = (False,) * (len(consts) + len(ops))
out = cond_p.bind(
@ -742,7 +742,7 @@ def _select_tree(indices, branch_vals):
if len(branch_vals) == 1:
return branch_vals[0]
mid = len(branch_vals) // 2
mid = onp.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
return lax.select(lax.lt(indices, mid),
_select_tree(indices, branch_vals[:mid]),
_select_tree(indices - mid, branch_vals[mid:]))
@ -752,7 +752,7 @@ def _cond_index_bcast_and_select_tree(indices, branch_vals):
return branch_vals[0]
else:
bcast_indices = lax.broadcast_in_dim(
indices, onp.shape(branch_vals[0]), list(range(onp.ndim(indices))))
indices, np.shape(branch_vals[0]), list(range(np.ndim(indices))))
return _select_tree(bcast_indices, branch_vals)
def _cond_batching_rule(args, dims, branches, linear):
@ -1066,7 +1066,7 @@ def _cond_typecheck(*avals, branches, linear):
index_aval, *op_avals = avals
core.typecheck_assert(
index_aval.dtype == onp.int32,
index_aval.dtype == np.int32,
f'cond called with index of type {index_aval.dtype} instead of int32')
core.typecheck_assert(
all(_map(core.typecompat, jaxpr0.in_avals, op_avals)),
@ -1859,8 +1859,8 @@ def _flatten(args):
def _check_shapes(func_name, expected_name, actual, expected, tree):
actual_shapes = _map(onp.shape, actual)
expected_shapes = _map(onp.shape, expected)
actual_shapes = _map(np.shape, actual)
expected_shapes = _map(np.shape, expected)
if actual_shapes != expected_shapes:
raise ValueError('{}() output shapes must match {}, got {} and {}'
.format(func_name, expected_name,
@ -2102,18 +2102,18 @@ batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
def _interleave(a, b):
"""Given two Tensors of static shape, interleave them along the first axis."""
# TODO(mattjj)
import jax.numpy as np
import jax.numpy as jnp
# [a b c ...] [d e f ...] -> [a d b e c f ...]
half_num_elems = b.shape[0]
if a.shape[0] > b.shape[0]:
return np.concatenate(
[np.reshape(np.stack([a[: -1], b], axis=1),
(2 * half_num_elems,) + a.shape[1:]),
return jnp.concatenate(
[jnp.reshape(jnp.stack([a[: -1], b], axis=1),
(2 * half_num_elems,) + a.shape[1:]),
a[-1:]], axis=0)
else:
return np.reshape(np.stack([a, b], axis=1),
(2 * half_num_elems,) + a.shape[1:])
return jnp.reshape(jnp.stack([a, b], axis=1),
(2 * half_num_elems,) + a.shape[1:])
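
A trace of the odd-length branch of _interleave with concrete inputs:

    import jax.numpy as jnp

    a = jnp.array([0, 2, 4])               # a.shape[0] > b.shape[0] takes the first branch
    b = jnp.array([1, 3])
    half_num_elems = b.shape[0]
    out = jnp.concatenate(
        [jnp.reshape(jnp.stack([a[:-1], b], axis=1),
                     (2 * half_num_elems,) + a.shape[1:]),
         a[-1:]], axis=0)
    print(out)                             # [0 1 2 3 4]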
def associative_scan(fn, elems):
"""Perform a scan with an associative binary operation, in parallel.


@ -15,7 +15,7 @@
from functools import partial
import numpy as onp
import numpy as np
from jax.abstract_arrays import ShapedArray
from jax.api import jit, vjp
@ -36,24 +36,24 @@ __all__ = [
]
def _promote_to_complex(arg):
dtype = dtypes.result_type(arg, onp.complex64)
dtype = dtypes.result_type(arg, np.complex64)
# XLA's FFT op only supports C64 in jaxlib versions 0.1.47 and earlier.
# TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
if lib.version <= (0, 1, 47) and dtype == onp.complex128:
dtype = onp.complex64
if lib.version <= (0, 1, 47) and dtype == np.complex128:
dtype = np.complex64
return lax.convert_element_type(arg, dtype)
def _promote_to_real(arg):
dtype = dtypes.result_type(arg, onp.float32)
dtype = dtypes.result_type(arg, np.float32)
# XLA's FFT op only supports F32.
# TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
if lib.version <= (0, 1, 47) and dtype == onp.float64:
dtype = onp.float32
if lib.version <= (0, 1, 47) and dtype == np.float64:
dtype = np.float32
return lax.convert_element_type(arg, dtype)
def fft(x, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:
if onp.iscomplexobj(x):
if np.iscomplexobj(x):
raise ValueError("only real valued inputs supported for rfft")
x = _promote_to_real(x)
else:
@ -67,8 +67,8 @@ def fft(x, fft_type, fft_lengths):
def fft_impl(x, fft_type, fft_lengths):
return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)
_complex_dtype = lambda dtype: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
_real_dtype = lambda dtype: onp.zeros((), dtype).real.dtype
_complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
_real_dtype = lambda dtype: np.zeros((), dtype).real.dtype
_is_even = lambda x: x % 2 == 0
def fft_abstract_eval(x, fft_type, fft_lengths):
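
For reference, the promotion arithmetic behind _complex_dtype and _real_dtype can be checked directly in plain NumPy:

    import numpy as np

    _complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
    _real_dtype = lambda dtype: np.zeros((), dtype).real.dtype
    print(_complex_dtype(np.float32))      # complex64
    print(_complex_dtype(np.float64))      # complex128
    print(_real_dtype(np.complex128))      # float64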


@ -17,7 +17,7 @@ Parallelization primitives.
import collections
import numpy as onp
import numpy as np
from jax import core
from jax import ad_util
@ -72,8 +72,8 @@ def psum(x, axis_name, *, axis_index_groups=None):
"""
_validate_axis_index_groups(axis_index_groups)
leaves, treedef = tree_util.tree_flatten(x)
leaves = [lax.convert_element_type(l, onp.int32)
if dtypes.dtype(l) == onp.bool_ else l for l in leaves]
leaves = [lax.convert_element_type(l, np.int32)
if dtypes.dtype(l) == np.bool_ else l for l in leaves]
out_flat = psum_p.bind(*leaves, axis_name=axis_name,
axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, out_flat)
@ -327,7 +327,7 @@ def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
out = [None] * len(args)
replica_groups_protos = xc.make_replica_groups(replica_groups)
for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
is_complex = dtypes.issubdtype(dtype, np.complexfloating)
n = len(dtype_args)
if is_complex:
dtype_args = ([xops.Real(x) for x in dtype_args] +
@ -355,7 +355,7 @@ def _notuple_psum_translation_rule(c, *args, replica_groups):
psum = partial(_allreduce_translation_rule, lax.add_p, c,
replica_groups=replica_groups)
dtype = c.get_shape(val).numpy_dtype()
if dtypes.issubdtype(dtype, onp.complexfloating):
if dtypes.issubdtype(dtype, np.complexfloating):
return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
else:
return psum(val)
@ -507,14 +507,14 @@ def _broadcasting_papply(prim, name, size, vals, axes, **params):
if xdim is None:
if x.shape:
if x.shape[ydim] == 1:
x = x.reshape(onp.delete(x.shape, ydim))
x = x.reshape(np.delete(x.shape, ydim))
else:
x = _drop(x, ydim, name)
return prim.bind(x, y, **params), ydim
elif ydim is None:
if y.shape:
if y.shape[xdim] == 1:
y = y.reshape(onp.delete(y.shape, xdim))
y = y.reshape(np.delete(y.shape, xdim))
else:
y = _drop(y, xdim, name)
return prim.bind(x, y, **params), xdim
@ -525,11 +525,11 @@ def _broadcasting_papply(prim, name, size, vals, axes, **params):
y_tosplit = xdim - int(ydim <= xdim)
if y.shape[y_tosplit] == 1:
y = _allgather(y, ydim, size, name)
y = y.reshape(onp.delete(y.shape, xdim))
y = y.reshape(np.delete(y.shape, xdim))
return prim.bind(x, y, **params), ydim
elif x.shape[x_tosplit] == 1:
x = _allgather(x, xdim, size, name)
x = x.reshape(onp.delete(x.shape, ydim))
x = x.reshape(np.delete(x.shape, ydim))
return prim.bind(x, y, **params), ydim
else:
x = all_to_all(x, name, x_tosplit, xdim)
@ -565,7 +565,7 @@ def _reducer_papply(prim, collective, name, size, vals, papply_axes, axes, **kwa
if not axes or papply_axis in axes:
return collective(result, axis_name=name), None
else:
new_papply_axis = papply_axis - onp.sum(onp.less(other_axes, papply_axis))
new_papply_axis = papply_axis - np.sum(np.less(other_axes, papply_axis))
return result, new_papply_axis
def _defreducer(prim, collective_prim):
@ -754,13 +754,13 @@ def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers,
def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions):
operand, = vals
axis, = axes
old_sizes = tuple(onp.insert(operand.shape, axis, size))
old_sizes = tuple(np.insert(operand.shape, axis, size))
def filter_ones(xs):
return filter(lambda x: x != 1, xs)
def find_new_axis(old_axis, old_sizes, new_sizes):
left = onp.prod(old_sizes[:old_axis])
left = np.prod(old_sizes[:old_axis])
size = old_sizes[old_axis]
prod = 1
for i, cur_sz in enumerate(new_sizes):
@ -829,7 +829,7 @@ def _conv_general_dilated_papply_rule(
lhs_dim, rhs_dim = dims
lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
if rhs_dim is None and lhs_dim == lhs_spec_batch_dim:
lhs = lax.reshape(lhs, tuple(onp.insert(lhs.shape, lhs_dim, 1)))
lhs = lax.reshape(lhs, tuple(np.insert(lhs.shape, lhs_dim, 1)))
out = lax.conv_general_dilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, precision)
@ -848,8 +848,8 @@ def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
raise ValueError(
"broadcast_in_dim changes hidden dimension size: {} to {}".format(
shape[dim], shape[out_dim]))
sub_bdims = tuple(onp.delete(broadcast_dimensions, dim))
sub_shape = tuple(onp.delete(shape, out_dim))
sub_bdims = tuple(np.delete(broadcast_dimensions, dim))
sub_shape = tuple(np.delete(shape, out_dim))
return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim
@ -906,8 +906,8 @@ def _gather_papply_rule(
start_index_map=dimension_numbers.start_index_map)
out = lax.gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes)
out_dim = start_indices_dim + onp.sum(
onp.less_equal(offset_dims, start_indices_dim))
out_dim = start_indices_dim + np.sum(
np.less_equal(offset_dims, start_indices_dim))
return out, out_dim
else:
raise NotImplementedError


@ -30,7 +30,7 @@ from absl import logging
from ..config import flags
from .. import util
from .. import dtypes
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
import numpy as np
import threading
try:
@ -86,7 +86,7 @@ def get_compile_options(num_replicas, num_partitions, device_assignment=None,
2,
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
num_replicas, num_partitions, device_assignment)
device_assignment = onp.array(device_assignment)
device_assignment = np.array(device_assignment)
# Allow 1D device assignment if num_partitions is 1.
if (device_assignment.ndim == 1) and (num_partitions == 1):
@ -288,9 +288,8 @@ def supported_numpy_dtypes():
# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
if hasattr(val, '__array__') or onp.isscalar(val):
return onp.asarray(val,
dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
if hasattr(val, '__array__') or np.isscalar(val):
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
elif isinstance(val, (tuple, list)):
return tuple(normalize_to_xla_dtypes(x) for x in val)
raise TypeError('Can\'t convert to XLA: {}'.format(val))
@ -361,7 +360,7 @@ def _sharding_to_proto(sharding: SpatialSharding):
else:
proto.type = xla_client.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_devices = list(range(onp.product(sharding)))
proto.tile_assignment_devices = list(range(np.product(sharding)))
return proto
def set_sharding(builder, op, sharding: SpatialSharding):
@ -395,7 +394,7 @@ def _ndarray_constant_handler(c, val, canonicalize_types=True):
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses onp.broadcast_to which in turn
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
@ -407,28 +406,28 @@ def _ndarray_constant_handler(c, val, canonicalize_types=True):
staged into the XLA Computation.
"""
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
if onp.any(onp.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = onp.where(onp.equal(0, val.strides))
other_axes, = onp.where(onp.not_equal(0, val.strides))
if np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim))]
xla_val = xops.Broadcast(
_numpy_array_constant(c, collapsed_val, canonicalize_types),
onp.take(val.shape, zero_stride_axes))
permutation = onp.argsort(tuple(zero_stride_axes) + tuple(other_axes))
np.take(val.shape, zero_stride_axes))
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
return xops.Transpose(xla_val, permutation)
else:
return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(onp.ndarray, _ndarray_constant_handler)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
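
The zero-stride path stages only the collapsed data and re-broadcasts on device; the host-side index arithmetic, in isolation:

    import numpy as np

    val = np.broadcast_to(np.arange(3.), (4, 3))    # stride 0 along axis 0
    zero_stride_axes, = np.where(np.equal(0, val.strides))
    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                              for ax in range(val.ndim))]
    print(zero_stride_axes, collapsed_val)          # [0] [0. 1. 2.]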
def _scalar_constant_handler(c, val, canonicalize_types=True):
return _numpy_array_constant(c, val, canonicalize_types)
for scalar_type in [onp.int8, onp.int16, onp.int32, onp.int64,
onp.uint8, onp.uint16, onp.uint32, onp.uint64,
onp.float16, onp.float32, onp.float64, onp.float128,
onp.bool_, onp.longlong]:
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.float128,
np.bool_, np.longlong]:
register_constant_handler(scalar_type, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):


@ -17,7 +17,7 @@ import functools
import itertools as it
import types
import numpy as onp
import numpy as np
def safe_zip(*args):
@ -233,7 +233,7 @@ def get_module_functions(module):
continue
attr = getattr(module, key)
if isinstance(
attr, (types.BuiltinFunctionType, types.FunctionType, onp.ufunc)):
attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)):
module_fns[key] = attr
return module_fns
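
A condensed, hypothetical restatement of get_module_functions, showing why np.ufunc appears in the isinstance check (NumPy functions like np.sin are ufunc instances, not plain Python functions):

    import types
    import numpy as np

    fns = {name: attr for name, attr in vars(np).items()
           if not name.startswith('_') and
           isinstance(attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc))}
    print(isinstance(np.sin, np.ufunc), 'sin' in fns)   # True True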