make jvp(asarray, (1.,), (2.,)) produce Arrays

fixes #15676

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Jake VanderPlas 2023-07-20 09:21:55 -07:00
parent fa5915b34d
commit 65751bb328
2 changed files with 10 additions and 1 deletions

View File

@ -542,7 +542,10 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N
operand = np.asarray(operand).astype(new_dtype)
old_weak_type = False
if (old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array):
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
isinstance(operand, Array) and
not (isinstance(operand, core.Tracer) and
isinstance(core.get_aval(operand), core.ConcreteArray))):
return type_cast(Array, operand)
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,

View File

@ -4358,6 +4358,12 @@ class APITest(jtu.JaxTestCase):
args_maker = lambda: [jnp.ones((), dtype=dtype)]
self._CompileAndCheck(f, args_maker)
def test_jvp_asarray_returns_array(self):
# https://github.com/google/jax/issues/15676
p, t = jax.jvp(jax.numpy.asarray, (1.,), (2.,))
_check_instance(self, p)
_check_instance(self, t)
class RematTest(jtu.JaxTestCase):