mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #6428 from NeilGirdhar:weibull
PiperOrigin-RevId: 368313322
This commit is contained in:
commit
d17d9a8081
@ -1543,7 +1543,7 @@ def double_sided_maxwell(key: jnp.ndarray,
|
||||
return _double_sided_maxwell(key, loc, scale, shape, dtype)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 3, 4))
|
||||
@partial(jit, static_argnums=(3, 4))
|
||||
def _double_sided_maxwell(key, loc, scale, shape, dtype):
|
||||
params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
|
||||
if not shape:
|
||||
@ -1587,7 +1587,7 @@ def weibull_min(key: jnp.ndarray,
|
||||
return _weibull_min(key, scale, concentration, shape, dtype)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 3, 4))
|
||||
@partial(jit, static_argnums=(3, 4))
|
||||
def _weibull_min(key, scale, concentration, shape, dtype):
|
||||
random_uniform = uniform(
|
||||
key=key, shape=shape, minval=0, maxval=1, dtype=dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user