make lax.full require concrete shapes

improves error message for #204
This commit is contained in:
Matthew Johnson 2019-01-07 12:28:52 -08:00
parent 5031016465
commit df87d5ce43

View File

@ -418,6 +418,13 @@ def tie_in(x, y):
return tie_in_p.bind(x, y)
def full(shape, fill_value, dtype):
try:
shape = tuple(map(int, shape))
except TypeError:
msg = ("`full` requires shapes to be concrete. If using `jit`, try using "
"`static_argnums` or applying `jit` to smaller subfunctions instead.")
raise TypeError(msg)
if onp.shape(fill_value):
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
raise TypeError(msg.format(onp.shape(fill_value)))
@ -2532,7 +2539,9 @@ def _check_shapelike(fun_name, arg_name, obj):
def _dynamic_slice_indices(operand, start_indices):
if isinstance(start_indices, (tuple, list)):
start_indices = concatenate([reshape(i, [1]) for i in start_indices], 0)
return rem(start_indices, onp.array(operand.shape, start_indices.dtype))
# map int over operand.shape to raise any dynamic-shape errors
shape = onp.asarray(map(int, operand.shape), start_indices.dtype)
return rem(start_indices, shape)
def _const(example, val):