address reviewer comments

This commit is contained in:
Matthew Johnson 2019-08-24 12:34:44 -07:00
parent e90457d737
commit afe21bafa4
2 changed files with 3 additions and 2 deletions

View File

@ -1475,8 +1475,8 @@ def full_like(a, fill_value, dtype=None):
def zeros(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
dtype = onp.dtype("float64") if dtype is None else dtype
lax._check_user_dtype_supported(dtype, "zeros")
dtype = onp.dtype("float64") if dtype is None else dtype
shape = (shape,) if onp.isscalar(shape) else shape
return lax.full(shape, 0, dtype)

View File

@ -977,7 +977,8 @@ class APITest(jtu.JaxTestCase):
for x, y in zip(xs, ys):
self.assertAllClose(x * 2 - 3., y, check_dtypes=True)
def test_issue_1230(self):
def test_dtype_warning(self):
# cf. issue #1230
if FLAGS.jax_enable_x64:
return # test only applies when x64 is disabled