Remove _ prefix from functions in jax._src.dtypes.

to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
This commit is contained in:
Peter Hawkins 2022-08-12 12:51:09 +00:00
parent 68d61e8db7
commit 29d03160e3
15 changed files with 39 additions and 38 deletions

View File

@ -70,14 +70,14 @@ _dtype_to_inexact = {
}
def _to_inexact_dtype(dtype):
def to_inexact_dtype(dtype):
"""Promotes a dtype into an inexact dtype, if it is not already one."""
dtype = np.dtype(dtype)
return _dtype_to_inexact.get(dtype, dtype)
def _to_complex_dtype(dtype):
ftype = _to_inexact_dtype(dtype)
def to_complex_dtype(dtype):
ftype = to_inexact_dtype(dtype)
if ftype in [np.dtype('float64'), np.dtype('complex128')]:
return np.dtype('complex128')
return np.dtype('complex64')

View File

@ -254,7 +254,7 @@ def gelu(x: Array, approximate: bool = True) -> Array:
"""
# Promote to nearest float-like dtype.
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)

View File

@ -184,7 +184,7 @@ def _complex_uniform(key: KeyArray,
"""
key_r, key_theta = random.split(key)
real_dtype = np.array(0, dtype).real.dtype
dtype = dtypes._to_complex_dtype(real_dtype)
dtype = dtypes.to_complex_dtype(real_dtype)
r = jnp.sqrt(2 * random.uniform(key_r, shape, real_dtype)).astype(dtype)
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)
@ -199,7 +199,7 @@ def _complex_truncated_normal(key: KeyArray, upper: Array,
"""
key_r, key_theta = random.split(key)
real_dtype = np.array(0, dtype).real.dtype
dtype = dtypes._to_complex_dtype(real_dtype)
dtype = dtypes.to_complex_dtype(real_dtype)
t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype)))
* random.uniform(key_r, shape, real_dtype).astype(dtype))
r = jnp.sqrt(-jnp.log(1 - t))

View File

@ -413,7 +413,7 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None):
raise NotImplementedError("string values for `bins` not implemented.")
_check_arraylike("histogram_bin_edges", a, bins)
a = ravel(a)
dtype = dtypes._to_inexact_dtype(_dtype(a))
dtype = dtypes.to_inexact_dtype(_dtype(a))
if _ndim(bins) == 1:
return asarray(bins, dtype=dtype)
bins = core.concrete_or_error(operator.index, bins,
@ -2178,9 +2178,9 @@ def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
_check_arraylike("linspace", start, stop)
if dtype is None:
dtype = dtypes._to_inexact_dtype(result_type(start, stop))
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
dtype = _jnp_dtype(dtype)
computation_dtype = dtypes._to_inexact_dtype(dtype)
computation_dtype = dtypes.to_inexact_dtype(dtype)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
@ -2237,9 +2237,9 @@ def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
"""Implementation of logspace differentiable in start and stop args."""
lax_internal._check_user_dtype_supported(dtype, "logspace")
if dtype is None:
dtype = dtypes._to_inexact_dtype(result_type(start, stop))
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
dtype = _jnp_dtype(dtype)
computation_dtype = dtypes._to_inexact_dtype(dtype)
computation_dtype = dtypes.to_inexact_dtype(dtype)
_check_arraylike("logspace", start, stop)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
@ -2259,9 +2259,9 @@ def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
"""Implementation of geomspace differentiable in start and stop args."""
lax_internal._check_user_dtype_supported(dtype, "geomspace")
if dtype is None:
dtype = dtypes._to_inexact_dtype(result_type(start, stop))
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
dtype = _jnp_dtype(dtype)
computation_dtype = dtypes._to_inexact_dtype(dtype)
computation_dtype = dtypes.to_inexact_dtype(dtype)
_check_arraylike("geomspace", start, stop)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
@ -3271,7 +3271,7 @@ def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
def sort_complex(a):
_check_arraylike("sort_complex", a)
a = lax.sort(a, dimension=0)
return lax.convert_element_type(a, dtypes._to_complex_dtype(a.dtype))
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
@_wraps(np.lexsort)
@partial(jit, static_argnames=('axis',))

View File

@ -33,7 +33,7 @@ import numpy as np
def _roots_no_zeros(p):
# build companion matrix and find its eigenvalues (the roots)
if p.size < 2:
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
A = diag(ones((p.size - 2,), p.dtype), -1)
A = A.at[0, :].set(-p[1:] / p[0])
return linalg.eigvals(A)
@ -83,7 +83,7 @@ def roots(p, *, strip_zeros=True):
if p.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p.size < 2:
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
if strip_zeros:

View File

@ -277,7 +277,7 @@ def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
if dtype is None:
dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
dtype = dtypes.canonicalize_dtype(dtype)
return lax.div(
@ -384,7 +384,7 @@ def _var_promote_types(a_dtype, dtype):
computation_dtype = dtype
else:
if not dtypes.issubdtype(a_dtype, np.inexact):
dtype = dtypes._to_inexact_dtype(a_dtype)
dtype = dtypes.to_inexact_dtype(a_dtype)
computation_dtype = dtype
else:
dtype = _complex_elem_type(a_dtype)

View File

@ -458,7 +458,7 @@ def ldexp(x1, x2):
x1, x2 = _promote_shapes("ldexp", x1, x2)
dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
info = dtypes.finfo(dtype)
int_type = _INT_DTYPES[info.bits]

View File

@ -282,7 +282,7 @@ def _promote_dtypes_inexact(*args):
Promotes arguments to an inexact type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_inexact = dtypes._to_inexact_dtype(to_dtype)
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]
@ -293,7 +293,7 @@ def _promote_dtypes_complex(*args):
Promotes arguments to a complex type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_complex = dtypes._to_complex_dtype(to_dtype)
to_dtype_complex = dtypes.to_complex_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
for x in args]

View File

@ -120,7 +120,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
@partial(jit, static_argnames=('output',))
def _schur(a, output):
if output == "complex":
a = a.astype(dtypes._to_complex_dtype(a.dtype))
a = a.astype(dtypes.to_complex_dtype(a.dtype))
return lax_linalg.schur(a)
@_wraps(scipy.linalg.schur)

View File

@ -277,7 +277,7 @@ def _spectral_helper(x, y,
except ValueError as err:
raise ValueError('x and y cannot be broadcast together.') from err
result_dtype = dtypes._to_complex_dtype(x.dtype)
result_dtype = dtypes.to_complex_dtype(x.dtype)
freq_dtype = np.finfo(result_dtype).dtype
if nperseg is not None: # if specified by user

View File

@ -109,7 +109,8 @@ class FftTest(jtu.JaxTestCase):
def testLaxIrfftDoesNotMutateInputs(self, dtype):
if dtype == np.float64 and not config.x64_enabled:
raise self.skipTest("float64 requires jax_enable_x64=true")
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes._to_complex_dtype(dtype))
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]],
dtype=dtypes.to_complex_dtype(dtype))
y = np.asarray(jnp.fft.irfft2(x))
z = np.asarray(jnp.fft.irfft2(x))
self.assertAllClose(y, z)

View File

@ -842,7 +842,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
return np_op(x_cast, axis, dtype=t, keepdims=keepdims)
@ -884,7 +884,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
@ -919,7 +919,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
@ -955,7 +955,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
@ -999,7 +999,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
@ -1042,7 +1042,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, where=where)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
@ -3145,7 +3145,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jtu.ignore_warning(category=RuntimeWarning, message="overflow.*")
def np_fun(x1, x2):
out_dtype = dtypes._to_inexact_dtype(x1.dtype)
out_dtype = dtypes.to_inexact_dtype(x1.dtype)
return np.ldexp(x1.astype(out_dtype), x2)
jnp_fun = jnp.ldexp

View File

@ -87,7 +87,7 @@ class TestPolynomial(jtu.JaxTestCase):
jnp_fun = jnp.roots
def np_fun(arg):
return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
return np.roots(arg).astype(dtypes.to_complex_dtype(arg.dtype))
# Note: outputs have no defined order, so we need to use a special comparator.
args = args_maker()
@ -116,7 +116,7 @@ class TestPolynomial(jtu.JaxTestCase):
jnp_fun = partial(jnp.roots, strip_zeros=False)
def np_fun(arg):
roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
roots = np.roots(arg).astype(dtypes.to_complex_dtype(arg.dtype))
if len(roots) < len(arg) - 1:
roots = np.pad(roots, (0, len(arg) - len(roots) - 1),
constant_values=complex(np.nan, np.nan))

View File

@ -639,7 +639,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
# This is the function API that we test against (note that self.rng().choice differs)
np_choice = np.random.default_rng(0).choice
p_dtype = dtypes._to_inexact_dtype(dtype)
p_dtype = dtypes.to_inexact_dtype(dtype)
key = self.seed_prng(0)
is_range = type(input_range_or_shape) is int

View File

@ -61,10 +61,10 @@ default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
_TPU_FFT_TOL = 0.15
def _real_dtype(dtype):
return jnp.finfo(dtypes._to_inexact_dtype(dtype)).dtype
return jnp.finfo(dtypes.to_inexact_dtype(dtype)).dtype
def _complex_dtype(dtype):
return dtypes._to_complex_dtype(dtype)
return dtypes.to_complex_dtype(dtype)
class LaxBackedScipySignalTests(jtu.JaxTestCase):
@ -133,7 +133,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
kwds = dict(axis=axis, type=type, bp=bp)
def osp_fun(x):
return osp_signal.detrend(x, **kwds).astype(dtypes._to_inexact_dtype(x.dtype))
return osp_signal.detrend(x, **kwds).astype(dtypes.to_inexact_dtype(x.dtype))
jsp_fun = partial(jsp_signal.detrend, **kwds)
if jtu.device_under_test() == 'tpu':
@ -365,7 +365,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
kwargs['nperseg'] = nperseg
else:
kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg),
dtype=dtypes._to_complex_dtype(dtype))
dtype=dtypes.to_complex_dtype(dtype))
if use_noverlap:
kwargs['noverlap'] = noverlap