mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
address reviewer comments
This commit is contained in:
parent
e90457d737
commit
afe21bafa4
@ -1475,8 +1475,8 @@ def full_like(a, fill_value, dtype=None):
|
||||
def zeros(shape, dtype=None):
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
dtype = onp.dtype("float64") if dtype is None else dtype
|
||||
lax._check_user_dtype_supported(dtype, "zeros")
|
||||
dtype = onp.dtype("float64") if dtype is None else dtype
|
||||
shape = (shape,) if onp.isscalar(shape) else shape
|
||||
return lax.full(shape, 0, dtype)
|
||||
|
||||
|
@ -977,7 +977,8 @@ class APITest(jtu.JaxTestCase):
|
||||
for x, y in zip(xs, ys):
|
||||
self.assertAllClose(x * 2 - 3., y, check_dtypes=True)
|
||||
|
||||
def test_issue_1230(self):
|
||||
def test_dtype_warning(self):
|
||||
# cf. issue #1230
|
||||
if FLAGS.jax_enable_x64:
|
||||
return # test only applies when x64 is disabled
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user