mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
5031016465
commit
df87d5ce43
11
jax/lax.py
11
jax/lax.py
@ -418,6 +418,13 @@ def tie_in(x, y):
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
def full(shape, fill_value, dtype):
|
||||
try:
|
||||
shape = tuple(map(int, shape))
|
||||
except TypeError:
|
||||
msg = ("`full` requires shapes to be concrete. If using `jit`, try using "
|
||||
"`static_argnums` or applying `jit` to smaller subfunctions instead.")
|
||||
raise TypeError(msg)
|
||||
|
||||
if onp.shape(fill_value):
|
||||
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
|
||||
raise TypeError(msg.format(onp.shape(fill_value)))
|
||||
@ -2532,7 +2539,9 @@ def _check_shapelike(fun_name, arg_name, obj):
|
||||
def _dynamic_slice_indices(operand, start_indices):
|
||||
if isinstance(start_indices, (tuple, list)):
|
||||
start_indices = concatenate([reshape(i, [1]) for i in start_indices], 0)
|
||||
return rem(start_indices, onp.array(operand.shape, start_indices.dtype))
|
||||
# map int over operand.shape to raise any dynamic-shape errors
|
||||
shape = onp.asarray(map(int, operand.shape), start_indices.dtype)
|
||||
return rem(start_indices, shape)
|
||||
|
||||
|
||||
def _const(example, val):
|
||||
|
Loading…
x
Reference in New Issue
Block a user