Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 21:06:06 +00:00)
Tighten up dtypes across the package
commit 40d6f5ed90 (parent 853fca2245)
@@ -7141,7 +7141,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
   raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))


 def _dynamic_slice_indices(operand, start_indices):
   # Normalize the start_indices w.r.t. operand.shape
   if len(start_indices) != operand.ndim:
     msg = ("Length of slice indices must match number of operand dimensions ({} "
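This hunk touches the helper that normalizes dynamic-slice start indices against the operand's shape. For context (not part of the diff), lax.dynamic_slice already adjusts out-of-bounds start indices so the slice stays in bounds, which is the behavior this helper feeds into:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(8)
    # Start index 6 with slice size 4 would run past the end, so the
    # start is clamped to 4 and the last four elements are returned.
    print(lax.dynamic_slice(x, (6,), (4,)))  # [4 5 6 7]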
@@ -314,7 +314,7 @@ def normalize(x: Array,
   return (x - mean) * lax.rsqrt(variance + epsilon)


 def one_hot(x: Array, num_classes: int, *,
-            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
+            dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
   """One-hot encodes the given indices.

   Each index in the input ``x`` is encoded as a vector of zeros of length
@@ -334,8 +334,7 @@ def one_hot(x: Array, num_classes: int, *,
   Args:
     x: A tensor of indices.
     num_classes: Number of classes in the one-hot dimension.
-    dtype: optional, a float dtype for the returned values (default float64 if
-      jax_enable_x64 is true, otherwise float32).
+    dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
     axis: the axis or axes along which the function should be
       computed.
   """
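The new default follows the jax_enable_x64 configuration instead of hard-coding a 64-bit type. A quick illustration (not part of the diff) of what callers now see:

    import jax.numpy as jnp
    from jax import nn

    # Under the default configuration (jax_enable_x64 off), jnp.float_
    # resolves to float32, so one_hot no longer requests float64.
    v = nn.one_hot(jnp.array([0, 2]), 3)
    print(v.dtype)  # float32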
@@ -1784,7 +1784,7 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
   else:
     raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

-  result = 0
+  result = array(0, dtype=dtypes.canonicalize_dtype(int_))
   for i, s in zip(multi_index, strides):
     result = result + i * s
   return result
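Seeding the accumulator with an explicitly typed zero keeps the raveled index in the canonical integer dtype rather than letting a Python int drive promotion. For example (not part of the diff):

    import jax.numpy as jnp

    # With x64 disabled the canonical integer dtype is int32.
    idx = jnp.ravel_multi_index((jnp.array([1]), jnp.array([2])), (3, 4))
    print(idx, idx.dtype)  # [6] int32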
@@ -1798,8 +1798,7 @@ and out-of-bounds indices are clipped.
 @_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
 def unravel_index(indices, shape):
   _check_arraylike("unravel_index", indices)
   shape = core.concrete_or_error(tuple, shape, context="shape argument of unravel_index")
-  sizes = array(tuple(shape) + (1,))
+  sizes = append(array(shape), 1)
   cumulative_sizes = cumprod(sizes[::-1])[::-1]
   total_size = cumulative_sizes[0]
   # Clip so raveling and unraveling an oob index will not change the behavior
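As the _UNRAVEL_INDEX_DOC note says, JAX departs from NumPy here: out-of-bounds indices are clipped instead of raising. A small illustration (not part of the diff):

    import jax.numpy as jnp

    print(jnp.unravel_index(5, (2, 3)))   # (1, 2): the usual NumPy result
    print(jnp.unravel_index(99, (2, 3)))  # clipped to the last valid index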
@@ -5200,7 +5199,7 @@ def _argmax(a, axis: Optional[int] = None, out=None):
     axis = 0
   if a.shape[axis] == 0:
     raise ValueError("attempt to get argmax of an empty sequence")
-  return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64)
+  return lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_))

 @_wraps(np.argmin, skip_params=['out'])
 def argmin(a, axis: Optional[int] = None, out=None):
@@ -5216,7 +5215,7 @@ def _argmin(a, axis: Optional[int] = None, out=None):
     axis = 0
   if a.shape[axis] == 0:
     raise ValueError("attempt to get argmin of an empty sequence")
-  return lax.argmin(a, _canonicalize_axis(axis, a.ndim), int64)
+  return lax.argmin(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_))


 _NANARG_DOC = """\
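Both reductions now return the canonical integer dtype instead of unconditionally asking lax for int64, which matters when jax_enable_x64 is off. Observable effect (not part of the diff):

    import jax.numpy as jnp

    i = jnp.argmax(jnp.array([3.0, 1.0, 7.0]))
    print(i, i.dtype)  # 2 int32 under the default configuration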
@@ -174,15 +174,16 @@ def _minimize_lbfgs(

   converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

+  # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
   state = state._replace(
       converged=converged,
       failed=(status > 0) & (~converged),
       k=state.k + 1,
       nfev=state.nfev + ls_results.nfev,
       ngev=state.ngev + ls_results.ngev,
-      x_k=x_kp1,
-      f_k=f_kp1,
-      g_k=g_kp1,
+      x_k=x_kp1.astype(state.x_k.dtype),
+      f_k=f_kp1.astype(state.f_k.dtype),
+      g_k=g_kp1.astype(state.g_k.dtype),
       s_history=_update_history_vectors(history=state.s_history, new=s_k),
       y_history=_update_history_vectors(history=state.y_history, new=y_k),
       rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
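The casts keep the dtypes of the loop carry fixed, which lax.while_loop requires: a body function must return a carry with exactly the same shapes and dtypes it received. A minimal sketch of that constraint (hypothetical names, not part of the diff):

    import jax
    import jax.numpy as jnp

    def cond(carry):
        x, k = carry
        return k < 10

    def body(carry):
        x, k = carry
        return x * 0.5, k + 1  # must stay (float32, int32)

    x, k = jax.lax.while_loop(cond, body, (jnp.float32(1.0), jnp.int32(0)))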
@@ -96,7 +96,7 @@ def minimize_bfgs(

   d = x0.shape[0]

-  initial_H = jnp.eye(d)
+  initial_H = jnp.eye(d, dtype=x0.dtype)
   f_0, g_0 = jax.value_and_grad(fun)(x0)
   state = _BFGSResults(
       converged=jnp.linalg.norm(g_0, ord=norm) < gtol,
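Without the explicit dtype, jnp.eye defaults to the canonical float type, which could silently change the precision of the BFGS state relative to x0. End-to-end effect (not part of the diff; assumes the public wrapper jax.scipy.optimize.minimize):

    import jax.numpy as jnp
    from jax.scipy.optimize import minimize

    fun = lambda x: jnp.sum((x - 1.0) ** 2)
    res = minimize(fun, jnp.zeros(3, dtype=jnp.float32), method="BFGS")
    print(res.x.dtype)  # float32: the solve stays in the dtype of x0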
@@ -131,7 +131,11 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
     a_j = jnp.where(use_quad, a_j_quad, a_j)
     a_j = jnp.where(use_bisection, a_j_bisection, a_j)

+    # TODO(jakevdp): should we use some sort of fixed-point approach here instead?
     phi_j, dphi_j, g_j = restricted_func_and_grad(a_j)
+    phi_j = phi_j.astype(state.phi_lo.dtype)
+    dphi_j = dphi_j.astype(state.dphi_lo.dtype)
+    g_j = g_j.astype(state.g_star.dtype)
     state = state._replace(nfev=state.nfev + 1,
                            ngev=state.ngev + 1)
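Same pattern as in the L-BFGS hunk: fresh evaluations are cast back to the dtypes already recorded in the loop state so that _replace never changes the carry's types. A stripped-down sketch (hypothetical, simplified state; not part of the diff):

    from typing import NamedTuple
    import jax.numpy as jnp

    class ZoomState(NamedTuple):  # hypothetical stand-in for the real state
        phi_lo: jnp.ndarray
        dphi_lo: jnp.ndarray

    state = ZoomState(jnp.float32(0.0), jnp.float32(-1.0))
    phi_j = jnp.asarray(0.5)  # may arrive with a different dtype
    state = state._replace(phi_lo=phi_j.astype(state.phi_lo.dtype))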
@@ -912,6 +912,8 @@ def _gen_associated_legendre(l_max: int,
     p_val = p_val + h
     return p_val

+  # TODO(jakevdp): use some sort of fixed-point procedure here instead?
+  p = p.astype(jnp.result_type(p, x, d0_mask_3d))
   if l_max > 1:
     p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
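Casting p to jnp.result_type(p, x, d0_mask_3d) up front means the fori_loop carry already has the dtype the loop body will produce, so the loop's shape/dtype invariant holds. How result_type behaves (not part of the diff):

    import jax.numpy as jnp

    # result_type applies JAX's promotion rules to all of its arguments.
    print(jnp.result_type(jnp.int32, jnp.float32))   # float32
    print(jnp.result_type(jnp.bool_, jnp.float32))   # float32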