Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Add precision to jax.numpy functions that use lax.dot_general (#1728)
* Add precision to jax.numpy functions that use lax.dot_general
* Test precision argument
* Check default precision
* Test with jaxprs
* Document precision
Commit: 27aa76e6a6
Parent: eff7b45dba
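Before the diff itself, here is a minimal usage sketch of the keyword this commit adds. The `jnp`/`onp` import aliases and the example shapes are illustrative, not part of the commit; the precision values come from `jax.lax.Precision`.

import numpy as onp
import jax.numpy as jnp
from jax import lax

x = onp.ones((4, 8), onp.float32)
y = onp.ones((8, 4), onp.float32)

# Default behaviour is unchanged: precision=None lets the backend choose.
z0 = jnp.dot(x, y)

# The new keyword requests a specific matmul precision on supported devices
# (e.g. float32 accumulation instead of a faster lower-precision path).
z1 = jnp.matmul(x, y, precision=lax.Precision.HIGHEST)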
@@ -2096,26 +2096,33 @@ def append(arr, values, axis=None):
 ### Tensor contraction operations
 
 
-@_wraps(onp.dot)
-def dot(a, b):  # pylint: disable=missing-docstring
+_PRECISION_DOC = """\
+In addition to the original NumPy arguments listed below, also supports
+``precision`` for extra control over matrix-multiplication precision
+on supported devices. See :py:func:`jax.lax.dot` for details.
+"""
+
+
+@_wraps(onp.dot, lax_description=_PRECISION_DOC)
+def dot(a, b, precision=None):  # pylint: disable=missing-docstring
   _check_arraylike("dot", a, b)
   a, b = _promote_dtypes(a, b)
   a_ndim, b_ndim = ndim(a), ndim(b)
   if a_ndim == 0 or b_ndim == 0:
     return lax.mul(a, b)
   if _max(a_ndim, b_ndim) <= 2:
-    return lax.dot(a, b)
+    return lax.dot(a, b, precision=precision)
 
   if b_ndim == 1:
     contract_dims = ((a_ndim - 1,), (0,))
   else:
     contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
   batch_dims = ((), ())
-  return lax.dot_general(a, b, (contract_dims, batch_dims))
+  return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
 
 
-@_wraps(onp.matmul)
-def matmul(a, b):  # pylint: disable=missing-docstring
+@_wraps(onp.matmul, lax_description=_PRECISION_DOC)
+def matmul(a, b, precision=None):  # pylint: disable=missing-docstring
   _check_arraylike("matmul", a, b)
   a_is_vec, b_is_vec = (ndim(a) == 1), (ndim(b) == 1)
   a = lax.reshape(a, (1,) + shape(a)) if a_is_vec else a
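For readers unfamiliar with `lax.dot_general`, here is a small sketch of the kind of call the updated `dot` builds for inputs with more than two dimensions; the arrays and dimension numbers are illustrative. `dimension_numbers` is a pair of pairs (contracting dimensions for each operand, then batch dimensions), and `precision` is the extra argument threaded through by this commit.

import numpy as onp
from jax import lax

a = onp.ones((2, 3, 4))
b = onp.ones((4, 5))

# Contract the last axis of `a` with axis 0 of `b`; no batch dimensions.
contract_dims = ((a.ndim - 1,), (0,))
batch_dims = ((), ())
out = lax.dot_general(a, b, (contract_dims, batch_dims),
                      precision=lax.Precision.HIGHEST)
print(out.shape)  # (2, 3, 5)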
@@ -2126,8 +2133,8 @@ def matmul(a, b):  # pylint: disable=missing-docstring
   a = broadcast_to(a, batch_shape + shape(a)[-2:])
   b = broadcast_to(b, batch_shape + shape(b)[-2:])
   batch_dims = tuple(range(len(batch_shape)))
-  result = lax.dot_general(a, b, (((ndim(a) - 1,), (ndim(b) - 2,)),
-                                  (batch_dims, batch_dims)))
+  dim_numbers = (((ndim(a) - 1,), (ndim(b) - 2,)), (batch_dims, batch_dims))
+  result = lax.dot_general(a, b, dim_numbers, precision)
 
   if a_is_vec or b_is_vec:
     m, n = shape(result)[-2:]
@@ -2138,15 +2145,15 @@ def matmul(a, b):  # pylint: disable=missing-docstring
   return result
 
 
-@_wraps(onp.vdot)
-def vdot(a, b):
+@_wraps(onp.vdot, lax_description=_PRECISION_DOC)
+def vdot(a, b, precision=None):
   if issubdtype(_dtype(a), onp.complexfloating):
     a = conj(a)
-  return dot(a.ravel(), b.ravel())
+  return dot(a.ravel(), b.ravel(), precision=precision)
 
 
-@_wraps(onp.tensordot)
-def tensordot(a, b, axes=2):
+@_wraps(onp.tensordot, lax_description=_PRECISION_DOC)
+def tensordot(a, b, axes=2, precision=None):
   _check_arraylike("tensordot", a, b)
   if not (ndim(a) >= 1 and ndim(b) >= 1):
     msg = "tensordot requires a.ndim and b.dim to be at least 1, got {} and {}."
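A short sketch of the `vdot` behaviour the hunk above touches: the first argument is conjugated for complex inputs, both inputs are flattened, and the new keyword is simply forwarded to `dot`. The arrays are illustrative.

import numpy as onp
import jax.numpy as jnp
from jax import lax

a = onp.array([1.0 + 2.0j, 3.0 - 1.0j])
b = onp.array([2.0 + 0.0j, 1.0 + 1.0j])

# vdot conjugates `a`, flattens both inputs, then calls dot with the
# requested precision.
print(jnp.vdot(a, b, precision=lax.Precision.HIGHEST))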
@@ -2161,14 +2168,14 @@ def tensordot(a, b, axes=2):
       a, b = _promote_dtypes(a, b)
       a_reshape = lax.reshape(a, (_prod(a.shape[:-axes]), _prod(a.shape[-axes:])))
       b_reshape = lax.reshape(b, (_prod(b.shape[:axes]), _prod(b.shape[axes:])))
-      out_reshape = lax.dot(a_reshape, b_reshape)
+      out_reshape = lax.dot(a_reshape, b_reshape, precision=precision)
       return lax.reshape(out_reshape, a.shape[:-axes] + b.shape[axes:])
   elif type(axes) in (list, tuple) and len(axes) == 2:
     ax1, ax2 = axes
     if type(ax1) == type(ax2) == int:
       a_transposed = moveaxis(a, ax1, -1) if ax1 != a.ndim - 1 else a
       b_transposed = moveaxis(b, ax2, 0) if ax2 != 0 else b
-      return tensordot(a_transposed, b_transposed, 1)
+      return tensordot(a_transposed, b_transposed, 1, precision)
     elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
       if len(ax1) != len(ax2):
         msg = "tensordot requires axes lists to have equal length, got {} and {}."
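A small illustrative sketch of `tensordot` with the new keyword, covering the integer-axes path changed above (contract the last `axes` dimensions of `a` against the first `axes` dimensions of `b`); the shapes are made up for the example.

import numpy as onp
import jax.numpy as jnp
from jax import lax

a = onp.ones((2, 3, 4))
b = onp.ones((3, 4, 5))

# axes=2 contracts the last two axes of `a` with the first two of `b`.
out = jnp.tensordot(a, b, axes=2, precision=lax.Precision.HIGHEST)
print(out.shape)  # (2, 5)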
@@ -2176,16 +2183,17 @@ def tensordot(a, b, axes=2):
       num_axes = len(ax1)
       a_transposed = moveaxis(a, ax1, tuple(range(a.ndim - num_axes, a.ndim)))
       b_transposed = moveaxis(b, ax2, tuple(range(num_axes)))
-      return tensordot(a_transposed, b_transposed, num_axes)
+      return tensordot(a_transposed, b_transposed, num_axes, precision)
   msg = ("tensordot axes argument must be an int, a pair of ints, or a pair of "
         "lists/tuples of ints.")
   raise TypeError(msg)
 
 
-@_wraps(onp.einsum)
+@_wraps(onp.einsum, lax_description=_PRECISION_DOC)
 def einsum(*operands, **kwargs):
   optimize = kwargs.pop('optimize', 'auto')
   optimize = 'greedy' if optimize is True else optimize
+  precision = kwargs.pop('precision', None)
   if kwargs:
     msg = 'invalid keyword arguments for einsum: {}'
     raise TypeError(msg.format(', '.join(kwargs)))
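Because `einsum` takes `*operands, **kwargs`, the new keyword is read with `kwargs.pop('precision', None)` rather than appearing in the signature; from the caller's side it is still an ordinary keyword argument. A minimal sketch with illustrative arrays:

import numpy as onp
import jax.numpy as jnp
from jax import lax

x = onp.ones((4, 8))
y = onp.ones((8, 16))

# The precision keyword is popped before the remaining kwargs are validated,
# so it does not trigger the "invalid keyword arguments" error.
out = jnp.einsum('ij,jk->ik', x, y, precision=lax.Precision.HIGHEST)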
@@ -2193,7 +2201,7 @@ def einsum(*operands, **kwargs):
   operands, contractions = opt_einsum.contract_path(
       *operands, einsum_call=True, use_blas=True, optimize=optimize)
   contractions = tuple(data[:3] for data in contractions)
-  return _einsum(operands, contractions)
+  return _einsum(operands, contractions, precision)
 
 @_wraps(onp.einsum_path)
 def einsum_path(subscripts, *operands, **kwargs):
@@ -2201,8 +2209,8 @@ def einsum_path(subscripts, *operands, **kwargs):
   # using einsum_call=True here is an internal api for opt_einsum
   return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)
 
-@partial(jit, static_argnums=(1,))
-def _einsum(operands, contractions):
+@partial(jit, static_argnums=(1, 2))
+def _einsum(operands, contractions, precision):
   operands = list(_promote_dtypes(*operands))
   sum = lambda x, axes: lax.reduce(x, onp.array(0, x.dtype), lax.add, axes)
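`_einsum` is jitted with `static_argnums=(1, 2)`, so `precision`, like the contraction list, is treated as a compile-time constant: it must be hashable (the `lax.Precision` enum is), and calling with a different value re-traces rather than flowing through as a traced argument. A standalone sketch of the same pattern, with a made-up function name:

from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

# `precision` is marked static, so it is baked into the compiled computation.
@partial(jax.jit, static_argnums=(1,))
def contract(x, precision):
  return jnp.dot(x, x, precision=precision)

x = jnp.ones((8, 8))
contract(x, lax.Precision.HIGHEST)  # compiles a HIGHEST-precision version
contract(x, lax.Precision.DEFAULT)  # re-traces for DEFAULT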
@@ -2292,7 +2300,8 @@ def _einsum(operands, contractions):
         # contract using lax.dot_general
         lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
                                     for n in contracted_names)
-        operand = _dot_general(lhs, rhs, lhs_cont, rhs_cont, len(batch_dims))
+        operand = _dot_general(lhs, rhs, lhs_cont, rhs_cont, len(batch_dims),
+                               precision)
         deleted_names = batch_names + ''.join(contracted_names)
         names = (batch_names + removechars(lhs_names, deleted_names)
                  + removechars(rhs_names, deleted_names))
@@ -2320,7 +2329,7 @@ def _einsum(operands, contractions):
   return operands[0]
 
 
-def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch):
+def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch, precision):
   """Helper for einsum contractions."""
   # lax.dot_general has some tight constraints on dimension_numbers that this
   # wrapper loosens via transposes and reshapes
@@ -2332,7 +2341,7 @@ def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch):
 
   if ncont == 1 and 0 <= lhs_ntensor <= 1 and 0 <= rhs_ntensor <= 1:
     dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
-    return lax.dot_general(lhs, rhs, dimension_numbers)
+    return lax.dot_general(lhs, rhs, dimension_numbers, precision)
   else:
     # move contracting dimensions to the end. lax.dot_general only allows one
     # contracting dimension, so if there's more than one we collapse them.
@@ -2360,7 +2369,7 @@ def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch):
 
     lhs_cont, rhs_cont = [lhs.ndim - 1], [rhs.ndim - 1]
     dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
-    result = lax.dot_general(lhs, rhs, dimension_numbers)
+    result = lax.dot_general(lhs, rhs, dimension_numbers, precision)
     return lax.reshape(result, result_shape)
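The fallback branch above implements the idea of collapsing several contracting dimensions into one via reshapes so that a single `dot_general` call suffices. Here is a small standalone sketch of that idea (not the helper itself), with illustrative shapes:

import numpy as onp
from jax import lax

lhs = onp.ones((2, 3, 4))   # contract over its last two axes
rhs = onp.ones((3, 4, 5))   # contract over its first two axes

# Merge the two contracting axes (3, 4) into one axis of size 12 on each side,
# then issue a single dot_general over that axis.
lhs2 = lhs.reshape(2, 12)
rhs2 = rhs.reshape(12, 5)
out = lax.dot_general(lhs2, rhs2, (((1,), (0,)), ((), ())),
                      precision=lax.Precision.HIGHEST)
print(out.shape)  # (2, 5)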
@@ -2372,11 +2381,11 @@ def _movechars(s, src, dst):
   return ''.join(chars)
 
 
-@_wraps(onp.inner)
-def inner(a, b):
+@_wraps(onp.inner, lax_description=_PRECISION_DOC)
+def inner(a, b, precision=None):
   if ndim(a) == 0 or ndim(b) == 0:
     return a * b
-  return tensordot(a, b, (-1, -1))
+  return tensordot(a, b, (-1, -1), precision=precision)
 
 
 @_wraps(onp.outer)
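The accompanying test checks the plumbing by tracing each function to a jaxpr and reading the `precision` parameter recorded on the `dot_general` equation. Outside the test harness, the same thing can be seen by printing a jaxpr directly; a minimal sketch (array and function are illustrative):

import numpy as onp
import jax
import jax.numpy as jnp
from jax import lax

x = onp.ones((2, 2))
f = lambda a, b: jnp.matmul(a, b, precision=lax.Precision.HIGHEST)

# The printed jaxpr contains a dot_general equation whose params include the
# requested precision.
print(jax.make_jaxpr(f)(x, x))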
@@ -2390,6 +2390,68 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
                         check_dtypes=False)
 
+  def testPrecision(self):
+
+    def iter_eqns(jaxpr):
+      for eqn in jaxpr.eqns:
+        yield eqn
+        for subjaxpr, _, _ in eqn.bound_subjaxprs:
+          for sub_eqn in iter_eqns(subjaxpr):
+            yield sub_eqn
+
+    def assert_precision(expected, fun, *args):
+      jaxpr = jax.make_jaxpr(fun)(*args)
+      precision, = [eqn.params['precision'] for eqn in iter_eqns(jaxpr)
+                    if eqn.primitive == lax.dot_general_p]
+      self.assertEqual(precision, expected)
+
+    ones_1d = onp.ones((2,))
+    ones_2d = onp.ones((2, 2))
+    ones_3d = onp.ones((2, 2, 2))
+    HIGHEST = lax.Precision.HIGHEST
+
+    assert_precision(None, lnp.dot, ones_1d, ones_1d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.dot, precision=HIGHEST),
+        ones_1d, ones_1d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.dot, precision=HIGHEST),
+        ones_3d, ones_3d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.matmul, precision=HIGHEST),
+        ones_2d, ones_2d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.vdot, precision=HIGHEST),
+        ones_1d, ones_1d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.tensordot, axes=2, precision=HIGHEST),
+        ones_2d, ones_2d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.tensordot, axes=(0, 0), precision=HIGHEST),
+        ones_1d, ones_1d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.tensordot, axes=((0,), (0,)), precision=HIGHEST),
+        ones_1d, ones_1d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.einsum, 'i,i', precision=HIGHEST),
+        ones_1d, ones_1d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.einsum, 'ij,ij', precision=HIGHEST),
+        ones_2d, ones_2d)
+    assert_precision(
+        HIGHEST,
+        partial(lnp.inner, precision=HIGHEST),
+        ones_1d, ones_1d)
+
 # Most grad tests are at the lax level (see lax_test.py), but we add some here
 # as needed for e.g. particular compound ops of interest.