mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Split dtype argument from other arguments in special functions.
This helps pytype to determine that the arguments are of different kinds, preventing type errors. PiperOrigin-RevId: 519401250
This commit is contained in:
parent
a5d308542e
commit
ec427f2c95
@ -165,7 +165,7 @@ def lentz_thompson_barnett_algorithm(*,num_iterations, small, threshold, nth_par
|
||||
return values[kHIdx]
|
||||
|
||||
|
||||
def regularized_incomplete_beta_impl(a, b, x, dtype):
|
||||
def regularized_incomplete_beta_impl(a, b, x, *, dtype):
|
||||
shape = a.shape
|
||||
|
||||
def nth_partial_betainc_numerator(iteration, a, b, x):
|
||||
@ -287,7 +287,7 @@ def _igamma_series(ax, x, a, enabled, dtype, mode):
|
||||
else:
|
||||
raise ValueError("Invalid IgammaMode")
|
||||
|
||||
def igamma_impl(a, x, dtype):
|
||||
def igamma_impl(a, x, *, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, _const(x, 0))
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
@ -418,7 +418,7 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
def igammac_impl(a, x, dtype):
|
||||
def igammac_impl(a, x, *, dtype):
|
||||
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
@ -436,7 +436,7 @@ def igammac_impl(a, x, dtype):
|
||||
result = select(x_is_infinity, full_like(result, 0), result)
|
||||
return select(out_of_range, full_like(a, 1), result)
|
||||
|
||||
def igamma_grad_a_impl(a, x, dtype):
|
||||
def igamma_grad_a_impl(a, x, *, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, full_like(x,0))
|
||||
domain_error = bitwise_or(lt(x, full_like(x, 0)), le(a, full_like(a, 0)))
|
||||
@ -456,7 +456,7 @@ def igamma_grad_a_impl(a, x, dtype):
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
def random_gamma_grad_impl(a, x, dtype):
|
||||
def random_gamma_grad_impl(a, x, *, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, full_like(x,0))
|
||||
domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0)))
|
||||
@ -488,7 +488,7 @@ def _up_and_broadcast(doit):
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a_dtype
|
||||
result = doit(*args, a_x_type)
|
||||
result = doit(*args, dtype=a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
|
Loading…
x
Reference in New Issue
Block a user