Tighten up dtypes across the package

This commit is contained in:
Jake VanderPlas 2021-10-29 11:02:01 -07:00
parent 853fca2245
commit 40d6f5ed90
7 changed files with 18 additions and 13 deletions

View File

@ -7141,7 +7141,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
def _dynamic_slice_indices(operand, start_indices):
def _dynamic_slice_indices(operand, start_indices):
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "

View File

@ -314,7 +314,7 @@ def normalize(x: Array,
return (x - mean) * lax.rsqrt(variance + epsilon)
def one_hot(x: Array, num_classes: int, *,
dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
"""One-hot encodes the given indicies.
Each index in the input ``x`` is encoded as a vector of zeros of length
@ -334,8 +334,7 @@ def one_hot(x: Array, num_classes: int, *,
Args:
x: A tensor of indices.
num_classes: Number of classes in the one-hot dimension.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
axis: the axis or axes along which the function should be
computed.
"""

View File

@ -1784,7 +1784,7 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
else:
raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")
result = 0
result = array(0, dtype=dtypes.canonicalize_dtype(int_))
for i, s in zip(multi_index, strides):
result = result + i * s
return result
@ -1798,8 +1798,7 @@ and out-of-bounds indices are clipped.
@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices, shape):
_check_arraylike("unravel_index", indices)
shape = core.concrete_or_error(tuple, shape, context="shape argument of unravel_index")
sizes = array(tuple(shape) + (1,))
sizes = append(array(shape), 1)
cumulative_sizes = cumprod(sizes[::-1])[::-1]
total_size = cumulative_sizes[0]
# Clip so raveling and unraveling an oob index will not change the behavior
@ -5200,7 +5199,7 @@ def _argmax(a, axis: Optional[int] = None, out=None):
axis = 0
if a.shape[axis] == 0:
raise ValueError("attempt to get argmax of an empty sequence")
return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64)
return lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_))
@_wraps(np.argmin, skip_params=['out'])
def argmin(a, axis: Optional[int] = None, out=None):
@ -5216,7 +5215,7 @@ def _argmin(a, axis: Optional[int] = None, out=None):
axis = 0
if a.shape[axis] == 0:
raise ValueError("attempt to get argmin of an empty sequence")
return lax.argmin(a, _canonicalize_axis(axis, a.ndim), int64)
return lax.argmin(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_))
_NANARG_DOC = """\

View File

@ -174,15 +174,16 @@ def _minimize_lbfgs(
converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol
# TODO(jakevdp): use a fixed-point procedure rather than type-casting?
state = state._replace(
converged=converged,
failed=(status > 0) & (~converged),
k=state.k + 1,
nfev=state.nfev + ls_results.nfev,
ngev=state.ngev + ls_results.ngev,
x_k=x_kp1,
f_k=f_kp1,
g_k=g_kp1,
x_k=x_kp1.astype(state.x_k.dtype),
f_k=f_kp1.astype(state.f_k.dtype),
g_k=g_kp1.astype(state.g_k.dtype),
s_history=_update_history_vectors(history=state.s_history, new=s_k),
y_history=_update_history_vectors(history=state.y_history, new=y_k),
rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),

View File

@ -96,7 +96,7 @@ def minimize_bfgs(
d = x0.shape[0]
initial_H = jnp.eye(d)
initial_H = jnp.eye(d, dtype=x0.dtype)
f_0, g_0 = jax.value_and_grad(fun)(x0)
state = _BFGSResults(
converged=jnp.linalg.norm(g_0, ord=norm) < gtol,

View File

@ -131,7 +131,11 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
a_j = jnp.where(use_quad, a_j_quad, a_j)
a_j = jnp.where(use_bisection, a_j_bisection, a_j)
# TODO(jakevdp): should we use some sort of fixed-point approach here instead?
phi_j, dphi_j, g_j = restricted_func_and_grad(a_j)
phi_j = phi_j.astype(state.phi_lo.dtype)
dphi_j = dphi_j.astype(state.dphi_lo.dtype)
g_j = g_j.astype(state.g_star.dtype)
state = state._replace(nfev=state.nfev + 1,
ngev=state.ngev + 1)

View File

@ -912,6 +912,8 @@ def _gen_associated_legendre(l_max: int,
p_val = p_val + h
return p_val
# TODO(jakevdp): use some sort of fixed-point procedure here instead?
p = p.astype(jnp.result_type(p, x, d0_mask_3d))
if l_max > 1:
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)