Merge pull request #6428 from NeilGirdhar:weibull

PiperOrigin-RevId: 368313322
This commit is contained in:
jax authors 2021-04-13 15:59:56 -07:00
commit d17d9a8081

View File

@ -1543,7 +1543,7 @@ def double_sided_maxwell(key: jnp.ndarray,
return _double_sided_maxwell(key, loc, scale, shape, dtype)
@partial(jit, static_argnums=(1, 2, 3, 4))
@partial(jit, static_argnums=(3, 4))
def _double_sided_maxwell(key, loc, scale, shape, dtype):
params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
if not shape:
@ -1587,7 +1587,7 @@ def weibull_min(key: jnp.ndarray,
return _weibull_min(key, scale, concentration, shape, dtype)
@partial(jit, static_argnums=(1, 2, 3, 4))
@partial(jit, static_argnums=(3, 4))
def _weibull_min(key, scale, concentration, shape, dtype):
random_uniform = uniform(
key=key, shape=shape, minval=0, maxval=1, dtype=dtype)