mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base, so they shouldn't have _ prefixes.
This commit is contained in:
parent
68d61e8db7
commit
29d03160e3
@ -70,14 +70,14 @@ _dtype_to_inexact = {
|
||||
}
|
||||
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
def to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
dtype = np.dtype(dtype)
|
||||
return _dtype_to_inexact.get(dtype, dtype)
|
||||
|
||||
|
||||
def _to_complex_dtype(dtype):
|
||||
ftype = _to_inexact_dtype(dtype)
|
||||
def to_complex_dtype(dtype):
|
||||
ftype = to_inexact_dtype(dtype)
|
||||
if ftype in [np.dtype('float64'), np.dtype('complex128')]:
|
||||
return np.dtype('complex128')
|
||||
return np.dtype('complex64')
|
||||
|
@ -254,7 +254,7 @@ def gelu(x: Array, approximate: bool = True) -> Array:
|
||||
"""
|
||||
|
||||
# Promote to nearest float-like dtype.
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
|
||||
if approximate:
|
||||
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
|
||||
|
@ -184,7 +184,7 @@ def _complex_uniform(key: KeyArray,
|
||||
"""
|
||||
key_r, key_theta = random.split(key)
|
||||
real_dtype = np.array(0, dtype).real.dtype
|
||||
dtype = dtypes._to_complex_dtype(real_dtype)
|
||||
dtype = dtypes.to_complex_dtype(real_dtype)
|
||||
r = jnp.sqrt(2 * random.uniform(key_r, shape, real_dtype)).astype(dtype)
|
||||
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
|
||||
return r * jnp.exp(1j * theta)
|
||||
@ -199,7 +199,7 @@ def _complex_truncated_normal(key: KeyArray, upper: Array,
|
||||
"""
|
||||
key_r, key_theta = random.split(key)
|
||||
real_dtype = np.array(0, dtype).real.dtype
|
||||
dtype = dtypes._to_complex_dtype(real_dtype)
|
||||
dtype = dtypes.to_complex_dtype(real_dtype)
|
||||
t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype)))
|
||||
* random.uniform(key_r, shape, real_dtype).astype(dtype))
|
||||
r = jnp.sqrt(-jnp.log(1 - t))
|
||||
|
@ -413,7 +413,7 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None):
|
||||
raise NotImplementedError("string values for `bins` not implemented.")
|
||||
_check_arraylike("histogram_bin_edges", a, bins)
|
||||
a = ravel(a)
|
||||
dtype = dtypes._to_inexact_dtype(_dtype(a))
|
||||
dtype = dtypes.to_inexact_dtype(_dtype(a))
|
||||
if _ndim(bins) == 1:
|
||||
return asarray(bins, dtype=dtype)
|
||||
bins = core.concrete_or_error(operator.index, bins,
|
||||
@ -2178,9 +2178,9 @@ def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
|
||||
_check_arraylike("linspace", start, stop)
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes._to_inexact_dtype(result_type(start, stop))
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
computation_dtype = dtypes._to_inexact_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
start = asarray(start, dtype=computation_dtype)
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
|
||||
@ -2237,9 +2237,9 @@ def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
|
||||
"""Implementation of logspace differentiable in start and stop args."""
|
||||
lax_internal._check_user_dtype_supported(dtype, "logspace")
|
||||
if dtype is None:
|
||||
dtype = dtypes._to_inexact_dtype(result_type(start, stop))
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
computation_dtype = dtypes._to_inexact_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
_check_arraylike("logspace", start, stop)
|
||||
start = asarray(start, dtype=computation_dtype)
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
@ -2259,9 +2259,9 @@ def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
|
||||
"""Implementation of geomspace differentiable in start and stop args."""
|
||||
lax_internal._check_user_dtype_supported(dtype, "geomspace")
|
||||
if dtype is None:
|
||||
dtype = dtypes._to_inexact_dtype(result_type(start, stop))
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
computation_dtype = dtypes._to_inexact_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
_check_arraylike("geomspace", start, stop)
|
||||
start = asarray(start, dtype=computation_dtype)
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
@ -3271,7 +3271,7 @@ def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
|
||||
def sort_complex(a):
|
||||
_check_arraylike("sort_complex", a)
|
||||
a = lax.sort(a, dimension=0)
|
||||
return lax.convert_element_type(a, dtypes._to_complex_dtype(a.dtype))
|
||||
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
|
||||
|
||||
@_wraps(np.lexsort)
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
|
@ -33,7 +33,7 @@ import numpy as np
|
||||
def _roots_no_zeros(p):
|
||||
# build companion matrix and find its eigenvalues (the roots)
|
||||
if p.size < 2:
|
||||
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
|
||||
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
|
||||
A = diag(ones((p.size - 2,), p.dtype), -1)
|
||||
A = A.at[0, :].set(-p[1:] / p[0])
|
||||
return linalg.eigvals(A)
|
||||
@ -83,7 +83,7 @@ def roots(p, *, strip_zeros=True):
|
||||
if p.ndim != 1:
|
||||
raise ValueError("Input must be a rank-1 array.")
|
||||
if p.size < 2:
|
||||
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
|
||||
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
|
||||
num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
|
||||
|
||||
if strip_zeros:
|
||||
|
@ -277,7 +277,7 @@ def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
|
||||
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
return lax.div(
|
||||
@ -384,7 +384,7 @@ def _var_promote_types(a_dtype, dtype):
|
||||
computation_dtype = dtype
|
||||
else:
|
||||
if not dtypes.issubdtype(a_dtype, np.inexact):
|
||||
dtype = dtypes._to_inexact_dtype(a_dtype)
|
||||
dtype = dtypes.to_inexact_dtype(a_dtype)
|
||||
computation_dtype = dtype
|
||||
else:
|
||||
dtype = _complex_elem_type(a_dtype)
|
||||
|
@ -458,7 +458,7 @@ def ldexp(x1, x2):
|
||||
|
||||
x1, x2 = _promote_shapes("ldexp", x1, x2)
|
||||
|
||||
dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
|
||||
dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
|
||||
info = dtypes.finfo(dtype)
|
||||
int_type = _INT_DTYPES[info.bits]
|
||||
|
||||
|
@ -282,7 +282,7 @@ def _promote_dtypes_inexact(*args):
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = dtypes._to_inexact_dtype(to_dtype)
|
||||
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
for x in args]
|
||||
|
||||
@ -293,7 +293,7 @@ def _promote_dtypes_complex(*args):
|
||||
Promotes arguments to a complex type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_complex = dtypes._to_complex_dtype(to_dtype)
|
||||
to_dtype_complex = dtypes.to_complex_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
@ -120,7 +120,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
||||
@partial(jit, static_argnames=('output',))
|
||||
def _schur(a, output):
|
||||
if output == "complex":
|
||||
a = a.astype(dtypes._to_complex_dtype(a.dtype))
|
||||
a = a.astype(dtypes.to_complex_dtype(a.dtype))
|
||||
return lax_linalg.schur(a)
|
||||
|
||||
@_wraps(scipy.linalg.schur)
|
||||
|
@ -277,7 +277,7 @@ def _spectral_helper(x, y,
|
||||
except ValueError as err:
|
||||
raise ValueError('x and y cannot be broadcast together.') from err
|
||||
|
||||
result_dtype = dtypes._to_complex_dtype(x.dtype)
|
||||
result_dtype = dtypes.to_complex_dtype(x.dtype)
|
||||
freq_dtype = np.finfo(result_dtype).dtype
|
||||
|
||||
if nperseg is not None: # if specified by user
|
||||
|
@ -109,7 +109,8 @@ class FftTest(jtu.JaxTestCase):
|
||||
def testLaxIrfftDoesNotMutateInputs(self, dtype):
|
||||
if dtype == np.float64 and not config.x64_enabled:
|
||||
raise self.skipTest("float64 requires jax_enable_x64=true")
|
||||
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes._to_complex_dtype(dtype))
|
||||
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]],
|
||||
dtype=dtypes.to_complex_dtype(dtype))
|
||||
y = np.asarray(jnp.fft.irfft2(x))
|
||||
z = np.asarray(jnp.fft.irfft2(x))
|
||||
self.assertAllClose(y, z)
|
||||
|
@ -842,7 +842,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
|
||||
t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
|
||||
return np_op(x_cast, axis, dtype=t, keepdims=keepdims)
|
||||
@ -884,7 +884,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
@ -919,7 +919,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
@ -955,7 +955,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
@ -999,7 +999,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
@ -1042,7 +1042,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
x = x.astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims, where=where)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
@ -3145,7 +3145,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.ignore_warning(category=RuntimeWarning, message="overflow.*")
|
||||
def np_fun(x1, x2):
|
||||
out_dtype = dtypes._to_inexact_dtype(x1.dtype)
|
||||
out_dtype = dtypes.to_inexact_dtype(x1.dtype)
|
||||
return np.ldexp(x1.astype(out_dtype), x2)
|
||||
|
||||
jnp_fun = jnp.ldexp
|
||||
|
@ -87,7 +87,7 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
|
||||
jnp_fun = jnp.roots
|
||||
def np_fun(arg):
|
||||
return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
|
||||
return np.roots(arg).astype(dtypes.to_complex_dtype(arg.dtype))
|
||||
|
||||
# Note: outputs have no defined order, so we need to use a special comparator.
|
||||
args = args_maker()
|
||||
@ -116,7 +116,7 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
|
||||
jnp_fun = partial(jnp.roots, strip_zeros=False)
|
||||
def np_fun(arg):
|
||||
roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
|
||||
roots = np.roots(arg).astype(dtypes.to_complex_dtype(arg.dtype))
|
||||
if len(roots) < len(arg) - 1:
|
||||
roots = np.pad(roots, (0, len(arg) - len(roots) - 1),
|
||||
constant_values=complex(np.nan, np.nan))
|
||||
|
@ -639,7 +639,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
|
||||
# This is the function API that we test against (note that self.rng().choice differs)
|
||||
np_choice = np.random.default_rng(0).choice
|
||||
p_dtype = dtypes._to_inexact_dtype(dtype)
|
||||
p_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
|
||||
key = self.seed_prng(0)
|
||||
is_range = type(input_range_or_shape) is int
|
||||
|
@ -61,10 +61,10 @@ default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
_TPU_FFT_TOL = 0.15
|
||||
|
||||
def _real_dtype(dtype):
|
||||
return jnp.finfo(dtypes._to_inexact_dtype(dtype)).dtype
|
||||
return jnp.finfo(dtypes.to_inexact_dtype(dtype)).dtype
|
||||
|
||||
def _complex_dtype(dtype):
|
||||
return dtypes._to_complex_dtype(dtype)
|
||||
return dtypes.to_complex_dtype(dtype)
|
||||
|
||||
|
||||
class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
@ -133,7 +133,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
kwds = dict(axis=axis, type=type, bp=bp)
|
||||
|
||||
def osp_fun(x):
|
||||
return osp_signal.detrend(x, **kwds).astype(dtypes._to_inexact_dtype(x.dtype))
|
||||
return osp_signal.detrend(x, **kwds).astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
jsp_fun = partial(jsp_signal.detrend, **kwds)
|
||||
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
@ -365,7 +365,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
kwargs['nperseg'] = nperseg
|
||||
else:
|
||||
kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg),
|
||||
dtype=dtypes._to_complex_dtype(dtype))
|
||||
dtype=dtypes.to_complex_dtype(dtype))
|
||||
if use_noverlap:
|
||||
kwargs['noverlap'] = noverlap
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user