Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[JAX] Remove uses of the deprecated jax.experimental.vectorize.

jax.numpy.vectorize should be used instead.

PiperOrigin-RevId: 341836454

This commit is contained in:
parent f1b14aa22d
commit 2b8d840cc3
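As a migration hint, here is a minimal, illustrative sketch of the replacement; ``jax.numpy.vectorize`` takes the function plus a ``signature`` keyword, and the ``center`` example mirrors the docstring of the removed module below::

    import jax.numpy as jnp

    def center(array):
      # array is a 1D vector here; batching over leading dimensions is
      # supplied by jnp.vectorize.
      bias = jnp.mean(array)
      return bias, array - bias

    center = jnp.vectorize(center, signature='(n)->(),(n)')
    bias, debiased = center(jnp.arange(12.0).reshape(3, 4))  # bias.shape == (3,)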
@@ -78,13 +78,6 @@ pytype_library(
     deps = [":jax"],
 )
 
-pytype_library(
-    name = "vectorize",
-    srcs = ["experimental/vectorize.py"],
-    srcs_version = "PY3",
-    deps = [":jax"],
-)
-
 pytype_library(
     name = "loops",
     srcs = ["experimental/loops.py"],
@@ -1,289 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extending JAX's vmap to work like NumPy's gufuncs.

By `Stephan Hoyer <https://github.com/shoyer>`_

What is a gufunc?
=================

`Generalized universal functions
<https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html>`_
("gufuncs") are one of my favorite abstractions from NumPy. They generalize
NumPy's `broadcasting rules
<https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html>`_ to
handle non-scalar operations. When a gufunc is applied to arrays, there are:

* "core dimensions" over which an operation is defined.
* "broadcast dimensions" over which operations can be automatically vectorized.

A string `signature <https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature>`_
associated with each gufunc controls how this happens by indicating how core
dimensions are mapped between inputs and outputs. The syntax is easiest to
understand by looking at a few examples:

* Addition: `(),()->()`
* 1D inner product: `(i),(i)->()`
* 1D sum: `(i)->()`
* Matrix multiplication: `(m,n),(n,k)->(m,k)`

Why write gufuncs?
==================

From a user perspective, gufuncs are nice because they're guaranteed to
vectorize in a consistent and general fashion. For example, by default gufuncs
use the last dimensions of arrays as core dimensions, but you can control that
explicitly with the ``axis`` or ``axes`` arguments.

From a developer perspective, gufuncs are nice because they simplify your work:
you only need to think about the core logic of your function, not how it
handles arbitrary dimensional input. You can just write that down in a simple,
declarative way.

JAX makes it easy to write high-level performant code
======================================================

Unfortunately, writing NumPy gufuncs today is somewhat non-trivial. Your
options are:

1. Write the inner loops yourself in C.
2. `np.vectorize <https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html>`_
   creates something kind of like a gufunc, but it's painfully slow: the outer
   loop is performed in Python.
3. `numba.guvectorize <https://numba.pydata.org/numba-doc/dev/user/vectorize.html>`_
   can work well, if you don't need further code transformations like automatic
   differentiation.

JAX's ``vmap`` contains all the core functionality we need to write functions
that work like gufuncs. JAX gufuncs play nicely with other transformations like
``grad`` and ``jit``.

A simple example
================

Consider a simple example from data preprocessing: centering an array.

Here's how we might write a vectorized version using NumPy::

    def center(array, axis=-1):
      # array can have any number of dimensions
      bias = np.mean(array, axis=axis)
      debiased = array - np.expand_dims(bias, axis)
      return bias, debiased

And here's how we could write a vectorized version using JAX gufuncs::

    @vectorize('(n)->(),(n)')
    def center(array):
      # array is always a 1D vector
      bias = np.mean(array)
      debiased = array - bias
      return bias, debiased

See the difference?

* Instead of needing to think about broadcasting while writing the entire
  function, we can write the function assuming the input is always a vector.
* We get the ``axis`` argument automatically, without needing to write it
  ourselves.
* As a bonus, the decorator makes the function self-documenting: a reader
  immediately knows that it handles higher dimensional input and output
  correctly.
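For example (the shapes below match the accompanying tests), the decorated
function batches over any leading dimensions automatically, and the extra
``axis`` argument selects which axis is treated as the core dimension::

    X = np.arange(12.0).reshape((3, 4))
    bias, debiased = center(X)          # bias.shape == (3,), debiased.shape == (3, 4)
    bias, debiased = center(X, axis=0)  # bias.shape == (4,), debiased.shape == (3, 4)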
"""
import functools
import re
import warnings

from jax import vmap
import jax.numpy as jnp
import numpy as np


warnings.warn(
    "jax.experimental.vectorize is deprecated and will be removed soon. Use "
    "jax.numpy.vectorize instead.",
    FutureWarning,
)


# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
_DIMENSION_NAME = r'\w+'
_CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME)
_ARGUMENT = r'\({}\)'.format(_CORE_DIMENSION_LIST)
_ARGUMENT_LIST = '{0:}(?:,{0:})*'.format(_ARGUMENT)
_SIGNATURE = '^{0:}->{0:}$'.format(_ARGUMENT_LIST)


def _parse_gufunc_signature(signature):
  """Parse string signatures for a generalized universal function.

  Args:
    signature : string
      Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
      for ``np.matmul``.

  Returns:
    Tuple of input and output core dimensions parsed from the signature, each
    of the form List[Tuple[str, ...]].
  """
  if not re.match(_SIGNATURE, signature):
    raise ValueError(
        'not a valid gufunc signature: {}'.format(signature))
  return tuple([tuple(re.findall(_DIMENSION_NAME, arg))
                for arg in re.findall(_ARGUMENT, arg_list)]
               for arg_list in signature.split('->'))
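# For illustration, parsing a matmul-style signature yields a pair of
# per-argument core-dimension tuples:
#   _parse_gufunc_signature('(m,n),(n,p)->(m,p)')
#   == ([('m', 'n'), ('n', 'p')], [('m', 'p')])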
def _update_dim_sizes(dim_sizes, arg, core_dims):
  """Incrementally check and update core dimension sizes for a single argument.

  Args:
    dim_sizes : Dict[str, int]
      Sizes of existing core dimensions. Will be updated in-place.
    arg : ndarray
      Argument to examine.
    core_dims : Tuple[str, ...]
      Core dimensions for this argument.
  """
  if not core_dims:
    return

  num_core_dims = len(core_dims)
  if arg.ndim < num_core_dims:
    raise ValueError(
        '%d-dimensional argument does not have enough '
        'dimensions for all core dimensions %r'
        % (arg.ndim, core_dims))

  core_shape = arg.shape[-num_core_dims:]
  for dim, size in zip(core_dims, core_shape):
    if dim in dim_sizes:
      if size != dim_sizes[dim]:
        raise ValueError(
            'inconsistent size for core dimension %r: %r vs %r'
            % (dim, size, dim_sizes[dim]))
    else:
      dim_sizes[dim] = size
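# Illustrative trace for the matmul-style signature '(n,m),(m,k)->(n,k)':
# a first argument of shape (6, 5, 2, 3) with core dims ('n', 'm') sets
# dim_sizes to {'n': 2, 'm': 3}; a second argument of shape (3, 4) with core
# dims ('m', 'k') then checks that m == 3 is consistent and records 'k': 4.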
def _parse_input_dimensions(args, input_core_dims):
  """Parse broadcast and core dimensions for vectorize with a signature.

  Args:
    args : Tuple[ndarray, ...]
      Tuple of input arguments to examine.
    input_core_dims : List[Tuple[str, ...]]
      List of core dimensions corresponding to each input.

  Returns:
    broadcast_shape : Tuple[int, ...]
      Common shape to broadcast all non-core dimensions to.
    dim_sizes : Dict[str, int]
      Common sizes for named core dimensions.
  """
  broadcast_args = []
  dim_sizes = {}
  for arg, core_dims in zip(args, input_core_dims):
    _update_dim_sizes(dim_sizes, arg, core_dims)
    ndim = arg.ndim - len(core_dims)
    dummy_array = np.lib.stride_tricks.as_strided(0, arg.shape[:ndim])
    broadcast_args.append(dummy_array)
  broadcast_shape = np.lib.stride_tricks._broadcast_shape(*broadcast_args)
  return broadcast_shape, dim_sizes
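# Continuing the illustration above: for arguments of shapes (6, 5, 2, 3) and
# (3, 4) under '(n,m),(m,k)->(n,k)', the leading (non-core) shapes (6, 5) and
# () broadcast to broadcast_shape == (6, 5), with
# dim_sizes == {'n': 2, 'm': 3, 'k': 4}.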
def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
  """Helper for calculating broadcast shapes with core dimensions."""
  return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
          for core_dims in list_of_core_dims]
# adapted from np.vectorize (again authored by shoyer@)
def broadcast_with_core_dims(args, input_core_dims, output_core_dims):
  if len(args) != len(input_core_dims):
    raise TypeError('wrong number of positional arguments: '
                    'expected %r, got %r'
                    % (len(input_core_dims), len(args)))

  broadcast_shape, dim_sizes = _parse_input_dimensions(
      args, input_core_dims)
  input_shapes = _calculate_shapes(broadcast_shape, dim_sizes,
                                   input_core_dims)
  args = [jnp.broadcast_to(arg, shape)
          for arg, shape in zip(args, input_shapes)]
  return args
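# Continuing the illustration: the two arguments are broadcast to shapes
# (6, 5, 2, 3) and (6, 5, 3, 4), i.e. broadcast_shape followed by each
# argument's core dimensions.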
def verify_axis_is_supported(input_core_dims, output_core_dims):
  all_core_dims = set()
  for input_or_output_core_dims in [input_core_dims, output_core_dims]:
    for core_dims in input_or_output_core_dims:
      all_core_dims.update(core_dims)
      if len(core_dims) > 1:
        raise ValueError('only one gufuncs with one core dim support axis')


def reorder_inputs(args, axis, input_core_dims):
  return tuple(jnp.moveaxis(arg, axis, -1) if core_dims else arg
               for arg, core_dims in zip(args, input_core_dims))


def reorder_outputs(result, axis, output_core_dims):
  if not isinstance(result, tuple):
    result = (result,)
  result = tuple(jnp.moveaxis(res, -1, axis) if core_dims else res
                 for res, core_dims in zip(result, output_core_dims))
  if len(result) == 1:
    (result,) = result
  return result
def vectorize(signature):
  """Vectorize a function using JAX.

  Turns an arbitrary function into a NumPy-style "gufunc". Once you specify
  the behavior of the core axes, the rest will be broadcast naturally.

  Args:
    signature: an einsum-style signature that defines how the core dimensions
      are mapped between inputs and outputs.

  Returns:
    The vectorized 'gufunc' that will automatically broadcast while
    maintaining the specified core logic. The returned function also has a
    new ``axis`` parameter for specifying which axis should be treated as
    the core one.
  """
  input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      axis = kwargs.get('axis')  # for python2 compat.

      if axis is not None:
        verify_axis_is_supported(input_core_dims, output_core_dims)
        args = reorder_inputs(args, axis, input_core_dims)

      broadcast_args = broadcast_with_core_dims(
          args, input_core_dims, output_core_dims)
      num_batch_dims = len(broadcast_args[0].shape) - len(input_core_dims[0])

      vectorized_func = func
      for _ in range(num_batch_dims):
        vectorized_func = vmap(vectorized_func)
      result = vectorized_func(*broadcast_args)

      if axis is not None:
        result = reorder_outputs(result, axis, output_core_dims)

      return result
    return wrapper
  return decorator
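The wrapper above applies ``vmap`` once per leading batch dimension. A brief
illustrative sketch using the ``matmat`` gufunc defined in the tests below
(shapes taken from those tests)::

    matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot)
    out = matmat(jnp.zeros((6, 5, 2, 3)), jnp.zeros((3, 4)))
    # Both inputs are broadcast to two batch dimensions (6, 5), which are then
    # peeled off by two nested vmap calls, so out.shape == (6, 5, 2, 4).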
@@ -3,7 +3,6 @@ filterwarnings =
     error
     ignore:No GPU/TPU found, falling back to CPU.:UserWarning
     ignore:Explicitly requested dtype.*is not available.*:UserWarning
-    ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
     ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
     ignore:jax.experimental.optix is deprecated.*:DeprecationWarning
     # The rest are for experimental/jax_to_tf
@@ -1,145 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Vectorize library."""

from absl.testing import absltest
from absl.testing import parameterized

from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental.vectorize import vectorize

from jax.config import config
config.parse_flags_with_absl()


matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot)
matvec = vectorize('(n,m),(m)->(n)')(jnp.dot)
vecmat = vectorize('(m),(m,k)->(k)')(jnp.dot)
vecvec = vectorize('(m),(m)->()')(jnp.dot)


@vectorize('(n)->()')
def magnitude(x):
  return jnp.dot(x, x)


mean = vectorize('(n)->()')(jnp.mean)


@vectorize('()->(n)')
def stack_plus_minus(x):
  return jnp.stack([x, -x])


@vectorize('(n)->(),(n)')
def center(array):
  bias = jnp.mean(array)
  debiased = array - bias
  return bias, debiased


class VectorizeTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3, 4), (2, 4)),
          ((2, 3), (1, 3, 4), (1, 2, 4)),
          ((5, 2, 3), (1, 3, 4), (5, 2, 4)),
          ((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
      ]))
  def test_matmat(self, left_shape, right_shape, result_shape):
    self.assertEqual(matmat(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3,), (2,)),
          ((2, 3), (1, 3), (1, 2)),
          ((4, 2, 3), (1, 3), (4, 2)),
          ((5, 4, 2, 3), (1, 3), (5, 4, 2)),
      ]))
  def test_matvec(self, left_shape, right_shape, result_shape):
    self.assertEqual(matvec(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((3,), (3,), ()),
          ((2, 3), (3,), (2,)),
          ((4, 2, 3), (3,), (4, 2)),
      ]))
  def test_vecvec(self, left_shape, right_shape, result_shape):
    self.assertEqual(vecvec(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3), (1, 2)),
      ]))
  def test_magnitude(self, shape, result_shape):
    size = 1
    for x in shape:
      size *= x
    self.assertEqual(magnitude(jnp.arange(size).reshape(shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3, 4), (1, 2, 3)),
      ]))
  def test_mean(self, shape, result_shape):
    self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

  def test_mean_axis(self):
    self.assertEqual(mean(jnp.zeros((2, 3)), axis=0).shape, (3,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((), (2,)),
          ((3,), (3, 2)),
      ]))
  def test_stack_plus_minus(self, shape, result_shape):
    self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

  def test_center(self):
    b, a = center(jnp.arange(3))
    self.assertEqual(a.shape, (3,))
    self.assertEqual(b.shape, ())
    self.assertAllClose(1.0, b, check_dtypes=False)

    X = jnp.arange(12).reshape((3, 4))
    b, a = center(X, axis=1)
    self.assertEqual(a.shape, (3, 4))
    self.assertEqual(b.shape, (3,))
    self.assertAllClose(jnp.mean(X, axis=1), b)

    b, a = center(X, axis=0)
    self.assertEqual(a.shape, (3, 4))
    self.assertEqual(b.shape, (4,))
    self.assertAllClose(jnp.mean(X, axis=0), b)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())