Split dtype argument from other arguments in special functions.

This helps pytype to determine that the arguments are of different kinds, preventing type errors.

PiperOrigin-RevId: 519401250
This commit is contained in:
Peter Hawkins 2023-03-25 11:40:32 -07:00 committed by jax authors
parent a5d308542e
commit ec427f2c95

View File

@ -165,7 +165,7 @@ def lentz_thompson_barnett_algorithm(*,num_iterations, small, threshold, nth_par
return values[kHIdx]
def regularized_incomplete_beta_impl(a, b, x, dtype):
def regularized_incomplete_beta_impl(a, b, x, *, dtype):
shape = a.shape
def nth_partial_betainc_numerator(iteration, a, b, x):
@ -287,7 +287,7 @@ def _igamma_series(ax, x, a, enabled, dtype, mode):
else:
raise ValueError("Invalid IgammaMode")
def igamma_impl(a, x, dtype):
def igamma_impl(a, x, *, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, _const(x, 0))
x_is_infinity = eq(x, _const(x, float('inf')))
@ -418,7 +418,7 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
else:
raise ValueError(f"Invalid mode: {mode}")
def igammac_impl(a, x, dtype):
def igammac_impl(a, x, *, dtype):
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
ax = a * log(x) - x - lgamma(a)
@ -436,7 +436,7 @@ def igammac_impl(a, x, dtype):
result = select(x_is_infinity, full_like(result, 0), result)
return select(out_of_range, full_like(a, 1), result)
def igamma_grad_a_impl(a, x, dtype):
def igamma_grad_a_impl(a, x, *, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, full_like(x,0))
domain_error = bitwise_or(lt(x, full_like(x, 0)), le(a, full_like(a, 0)))
@ -456,7 +456,7 @@ def igamma_grad_a_impl(a, x, dtype):
full_like(a, float('nan')), output)
return output
def random_gamma_grad_impl(a, x, dtype):
def random_gamma_grad_impl(a, x, *, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, full_like(x,0))
domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0)))
@ -488,7 +488,7 @@ def _up_and_broadcast(doit):
a_x_type = np.float32
else:
a_x_type = a_dtype
result = doit(*args, a_x_type)
result = doit(*args, dtype=a_x_type)
if needs_upcast:
result = convert_element_type(result, a_dtype)
return result