[JAX] Remove uses of the deprecated jax.experimental.vectorize.

jax.numpy.vectorize should be used instead.
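
A minimal migration sketch (illustrative, not part of the original commit
message; note that jax.numpy.vectorize has no equivalent of the experimental
module's extra `axis` keyword):

    import functools
    import jax.numpy as jnp

    @functools.partial(jnp.vectorize, signature='(n)->(),(n)')
    def center(array):
      bias = jnp.mean(array)
      return bias, array - bias

    # center(jnp.arange(12.0).reshape(3, 4)) -> bias of shape (3,), debiased of shape (3, 4)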

PiperOrigin-RevId: 341836454
Peter Hawkins 2020-11-11 08:34:22 -08:00 committed by jax authors
parent f1b14aa22d
commit 2b8d840cc3
4 changed files with 0 additions and 442 deletions


@@ -78,13 +78,6 @@ pytype_library(
    deps = [":jax"],
)

pytype_library(
    name = "vectorize",
    srcs = ["experimental/vectorize.py"],
    srcs_version = "PY3",
    deps = [":jax"],
)

pytype_library(
    name = "loops",
    srcs = ["experimental/loops.py"],


@@ -1,289 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extending JAX's vmap to work like NumPy's gufuncs.
By `Stephan Hoyer <https://github.com/shoyer>`_
What is a gufunc?
=================
`Generalized universal functions
<https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html>`_
("gufuncs") are one of my favorite abstractions from NumPy. They generalize
NumPy's `broadcasting rules
<https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html>`_ to
handle non-scalar operations. When a gufuncs is applied to arrays, there are:
* "core dimensions" over which an operation is defined.
* "broadcast dimensions" over which operations can be automatically vectorized.
A string `signature <https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature>`_
associated with each gufunc controls how this happens by indicating how core
dimensions are mapped between inputs and outputs. The syntax is easiest to
understand by looking at a few examples:
* Addition: `(),()->()`
* 1D inner product: `(i),(i)->()`
* 1D sum: `(i)->()`
* Matrix multiplication: `(m,n),(n,k)->(m,k)`
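
For instance, the matrix-multiplication signature above works unchanged with
``jax.numpy.vectorize``, the supported replacement for this module (an
illustrative sketch, not part of the original docstring)::

  import jax.numpy as jnp

  matmul = jnp.vectorize(jnp.dot, signature='(m,n),(n,k)->(m,k)')
  matmul(jnp.ones((5, 2, 3)), jnp.ones((3, 4))).shape  # (5, 2, 4)
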
Why write gufuncs?
==================

From a user perspective, gufuncs are nice because they're guaranteed to
vectorize in a consistent and general fashion. For example, by default gufuncs
use the last dimensions of arrays as core dimensions, but you can control that
explicitly with the ``axis`` or ``axes`` arguments.

From a developer perspective, gufuncs are nice because they simplify your work:
you only need to think about the core logic of your function, not how it
handles arbitrary dimensional input. You can just write that down in a simple,
declarative way.

JAX makes it easy to write high-level performant code
=====================================================

Unfortunately, writing NumPy gufuncs today is somewhat non-trivial. Your
options today are:

1. Write the inner loops yourself in C.
2. `np.vectorize <https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html>`_
   creates something kind of like a gufunc, but it's painfully slow: the outer
   loop is performed in Python.
3. `numba.guvectorize <https://numba.pydata.org/numba-doc/dev/user/vectorize.html>`_
   can work well, if you don't need further code transformations like automatic
   differentiation.

JAX's ``vmap`` contains all the core functionality we need to write functions
that work like gufuncs. JAX gufuncs play nicely with other transformations like
``grad`` and ``jit``.

A simple example
================

Consider a simple example from data preprocessing, centering an array.

Here's how we might write a vectorized version using NumPy::

  def center(array, axis=-1):
    # array can have any number of dimensions
    bias = np.mean(array, axis=axis)
    debiased = array - np.expand_dims(bias, axis)
    return bias, debiased

And here's how we could write a vectorized version using JAX gufuncs::

  @vectorize('(n)->(),(n)')
  def center(array):
    # array is always a 1D vector
    bias = np.mean(array)
    debiased = array - bias
    return bias, debiased

See the difference?

* Instead of needing to think about broadcasting while writing the entire
  function, we can write the function assuming the input is always a vector.
* We get the ``axis`` argument automatically, without needing to write it
  ourselves.
* As a bonus, the decorator makes the function self-documenting: a reader
  immediately knows that it handles higher dimensional input and output
  correctly.
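
For example (an illustrative sketch mirroring the deleted tests)::

  X = np.arange(12.0).reshape((3, 4))
  bias, debiased = center(X, axis=1)  # bias.shape == (3,), debiased.shape == (3, 4)
  bias, debiased = center(X, axis=0)  # bias.shape == (4,), debiased.shape == (3, 4)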
"""
import functools
import re
import warnings

from jax import vmap
import jax.numpy as jnp
import numpy as np

warnings.warn(
    "jax.experimental.vectorize is deprecated and will be removed soon. Use "
    "jax.numpy.vectorize instead.",
    FutureWarning,
)

# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
_DIMENSION_NAME = r'\w+'
_CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME)
_ARGUMENT = r'\({}\)'.format(_CORE_DIMENSION_LIST)
_ARGUMENT_LIST = '{0:}(?:,{0:})*'.format(_ARGUMENT)
_SIGNATURE = '^{0:}->{0:}$'.format(_ARGUMENT_LIST)


def _parse_gufunc_signature(signature):
  """Parse string signatures for a generalized universal function.

  Args:
    signature : string
      Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
      for ``np.matmul``.

  Returns:
    Tuple of input and output core dimensions parsed from the signature, each
    of the form List[Tuple[str, ...]].
  """
  if not re.match(_SIGNATURE, signature):
    raise ValueError(
        'not a valid gufunc signature: {}'.format(signature))
  return tuple([tuple(re.findall(_DIMENSION_NAME, arg))
                for arg in re.findall(_ARGUMENT, arg_list)]
               for arg_list in signature.split('->'))
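
# Example (illustrative, not part of the original module): for a matmul-style
# signature,
#   _parse_gufunc_signature('(m,n),(n,k)->(m,k)')
# returns ([('m', 'n'), ('n', 'k')], [('m', 'k')]).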
def _update_dim_sizes(dim_sizes, arg, core_dims):
  """Incrementally check and update core dimension sizes for a single argument.

  Args:
    dim_sizes : Dict[str, int]
      Sizes of existing core dimensions. Will be updated in-place.
    arg : ndarray
      Argument to examine.
    core_dims : Tuple[str, ...]
      Core dimensions for this argument.
  """
  if not core_dims:
    return
  num_core_dims = len(core_dims)
  if arg.ndim < num_core_dims:
    raise ValueError(
        '%d-dimensional argument does not have enough '
        'dimensions for all core dimensions %r'
        % (arg.ndim, core_dims))
  core_shape = arg.shape[-num_core_dims:]
  for dim, size in zip(core_dims, core_shape):
    if dim in dim_sizes:
      if size != dim_sizes[dim]:
        raise ValueError(
            'inconsistent size for core dimension %r: %r vs %r'
            % (dim, size, dim_sizes[dim]))
    else:
      dim_sizes[dim] = size


def _parse_input_dimensions(args, input_core_dims):
  """Parse broadcast and core dimensions for vectorize with a signature.

  Args:
    args : Tuple[ndarray, ...]
      Tuple of input arguments to examine.
    input_core_dims : List[Tuple[str, ...]]
      List of core dimensions corresponding to each input.

  Returns:
    broadcast_shape : Tuple[int, ...]
      Common shape to broadcast all non-core dimensions to.
    dim_sizes : Dict[str, int]
      Common sizes for named core dimensions.
  """
  broadcast_args = []
  dim_sizes = {}
  for arg, core_dims in zip(args, input_core_dims):
    _update_dim_sizes(dim_sizes, arg, core_dims)
    ndim = arg.ndim - len(core_dims)
    dummy_array = np.lib.stride_tricks.as_strided(0, arg.shape[:ndim])
    broadcast_args.append(dummy_array)
  broadcast_shape = np.lib.stride_tricks._broadcast_shape(*broadcast_args)
  return broadcast_shape, dim_sizes


def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
  """Helper for calculating broadcast shapes with core dimensions."""
  return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
          for core_dims in list_of_core_dims]


# adapted from np.vectorize (again authored by shoyer@)
def broadcast_with_core_dims(args, input_core_dims, output_core_dims):
  if len(args) != len(input_core_dims):
    raise TypeError('wrong number of positional arguments: '
                    'expected %r, got %r'
                    % (len(input_core_dims), len(args)))
  broadcast_shape, dim_sizes = _parse_input_dimensions(
      args, input_core_dims)
  input_shapes = _calculate_shapes(broadcast_shape, dim_sizes,
                                   input_core_dims)
  args = [jnp.broadcast_to(arg, shape)
          for arg, shape in zip(args, input_shapes)]
  return args
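
# Example (illustrative): for the signature '(n,m),(m,k)->(n,k)' with inputs of
# shapes (5, 2, 3) and (3, 4), the batch shapes are (5,) and (), so the inputs
# are broadcast to (5, 2, 3) and (5, 3, 4) before vmap is applied in vectorize
# below.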
def verify_axis_is_supported(input_core_dims, output_core_dims):
  all_core_dims = set()
  for input_or_output_core_dims in [input_core_dims, output_core_dims]:
    for core_dims in input_or_output_core_dims:
      all_core_dims.update(core_dims)
      if len(core_dims) > 1:
        raise ValueError('only gufuncs with one core dim support axis')


def reorder_inputs(args, axis, input_core_dims):
  return tuple(jnp.moveaxis(arg, axis, -1) if core_dims else arg
               for arg, core_dims in zip(args, input_core_dims))


def reorder_outputs(result, axis, output_core_dims):
  if not isinstance(result, tuple):
    result = (result,)
  result = tuple(jnp.moveaxis(res, -1, axis) if core_dims else res
                 for res, core_dims in zip(result, output_core_dims))
  if len(result) == 1:
    (result,) = result
  return result


def vectorize(signature):
  """Vectorize a function using JAX.

  Turns an arbitrary function into a numpy style "gufunc". Once you specify the
  behavior of the core axis, the rest will be broadcast naturally.

  Args:
    signature: an einsum style signature that defines how the core dimensions
      are mapped between inputs and outputs.

  Returns:
    The vectorized 'gufunc' that will automatically broadcast while maintaining
    the specified core logic. The returned function also has a new ``axis``
    parameter for specifying which axis should be treated as the core one.
  """
  input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      axis = kwargs.get('axis')  # for python2 compat.
      if axis is not None:
        verify_axis_is_supported(input_core_dims, output_core_dims)
        args = reorder_inputs(args, axis, input_core_dims)
      broadcast_args = broadcast_with_core_dims(
          args, input_core_dims, output_core_dims)
      num_batch_dims = len(broadcast_args[0].shape) - len(input_core_dims[0])
      vectorized_func = func
      for _ in range(num_batch_dims):
        vectorized_func = vmap(vectorized_func)
      result = vectorized_func(*broadcast_args)
      if axis is not None:
        result = reorder_outputs(result, axis, output_core_dims)
      return result
    return wrapper
  return decorator
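
# Illustrative use of the ``axis`` handling above (mirrors the deleted tests):
#   mean = vectorize('(n)->()')(jnp.mean)
#   mean(jnp.zeros((2, 3)), axis=0).shape  # (3,): axis 0 is moved to the core position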


@@ -3,7 +3,6 @@ filterwarnings =
    error
    ignore:No GPU/TPU found, falling back to CPU.:UserWarning
    ignore:Explicitly requested dtype.*is not available.*:UserWarning
    ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
    ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
    ignore:jax.experimental.optix is deprecated.*:DeprecationWarning
    # The rest are for experimental/jax_to_tf


@@ -1,145 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Vectorize library."""
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental.vectorize import vectorize
from jax.config import config
config.parse_flags_with_absl()
matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot)
matvec = vectorize('(n,m),(m)->(n)')(jnp.dot)
vecmat = vectorize('(m),(m,k)->(k)')(jnp.dot)
vecvec = vectorize('(m),(m)->()')(jnp.dot)
@vectorize('(n)->()')
def magnitude(x):
return jnp.dot(x, x)
mean = vectorize('(n)->()')(jnp.mean)
@vectorize('()->(n)')
def stack_plus_minus(x):
return jnp.stack([x, -x])
@vectorize('(n)->(),(n)')
def center(array):
bias = jnp.mean(array)
debiased = array - bias
return bias, debiased
class VectorizeTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape,
       "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3, 4), (2, 4)),
          ((2, 3), (1, 3, 4), (1, 2, 4)),
          ((5, 2, 3), (1, 3, 4), (5, 2, 4)),
          ((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
      ]))
  def test_matmat(self, left_shape, right_shape, result_shape):
    self.assertEqual(matmat(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape,
       "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3,), (2,)),
          ((2, 3), (1, 3), (1, 2)),
          ((4, 2, 3), (1, 3), (4, 2)),
          ((5, 4, 2, 3), (1, 3), (5, 4, 2)),
      ]))
  def test_matvec(self, left_shape, right_shape, result_shape):
    self.assertEqual(matvec(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape,
       "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((3,), (3,), ()),
          ((2, 3), (3,), (2,)),
          ((4, 2, 3), (3,), (4, 2)),
      ]))
  def test_vecvec(self, left_shape, right_shape, result_shape):
    self.assertEqual(vecvec(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3), (1, 2)),
      ]))
  def test_magnitude(self, shape, result_shape):
    size = 1
    for x in shape:
      size *= x
    self.assertEqual(magnitude(jnp.arange(size).reshape(shape)).shape,
                     result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3, 4), (1, 2, 3)),
      ]))
  def test_mean(self, shape, result_shape):
    self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

  def test_mean_axis(self):
    self.assertEqual(mean(jnp.zeros((2, 3)), axis=0).shape, (3,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((), (2,)),
          ((3,), (3, 2)),
      ]))
  def test_stack_plus_minus(self, shape, result_shape):
    self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

  def test_center(self):
    b, a = center(jnp.arange(3))
    self.assertEqual(a.shape, (3,))
    self.assertEqual(b.shape, ())
    self.assertAllClose(1.0, b, check_dtypes=False)

    X = jnp.arange(12).reshape((3, 4))
    b, a = center(X, axis=1)
    self.assertEqual(a.shape, (3, 4))
    self.assertEqual(b.shape, (3,))
    self.assertAllClose(jnp.mean(X, axis=1), b)

    b, a = center(X, axis=0)
    self.assertEqual(a.shape, (3, 4))
    self.assertEqual(b.shape, (4,))
    self.assertAllClose(jnp.mean(X, axis=0), b)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())