mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
make jvp(asarray, (1.,), (2.,)) produce Arrays
fixes #15676 Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
fa5915b34d
commit
65751bb328
@ -542,7 +542,10 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N
|
||||
operand = np.asarray(operand).astype(new_dtype)
|
||||
old_weak_type = False
|
||||
|
||||
if (old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array):
|
||||
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
|
||||
isinstance(operand, Array) and
|
||||
not (isinstance(operand, core.Tracer) and
|
||||
isinstance(core.get_aval(operand), core.ConcreteArray))):
|
||||
return type_cast(Array, operand)
|
||||
else:
|
||||
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
|
||||
|
@ -4358,6 +4358,12 @@ class APITest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [jnp.ones((), dtype=dtype)]
|
||||
self._CompileAndCheck(f, args_maker)
|
||||
|
||||
def test_jvp_asarray_returns_array(self):
|
||||
# https://github.com/google/jax/issues/15676
|
||||
p, t = jax.jvp(jax.numpy.asarray, (1.,), (2.,))
|
||||
_check_instance(self, p)
|
||||
_check_instance(self, t)
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user