vmap: preserve weak_type in batching tracer

This commit is contained in:
Jake VanderPlas 2022-03-30 11:06:56 -07:00
parent a04b777c54
commit 34f116c0e0
3 changed files with 15 additions and 4 deletions

View File

@ -1854,8 +1854,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
out_extensive = [next(out_extensive_iter) if i is None
else _maybe_device_put(tracers[i].pval[1]) if tracers[i].is_known()
else tracers[i] for i in fwd_extensive]
assert all(a.strip_named_shape() == core.raise_to_shaped(
core.get_aval(out)).strip_named_shape()
assert all(core.typematch(a, core.get_aval(out))
for a, out in zip(extensive_avals, out_extensive))
out_flat = out_carry + out_extensive

View File

@ -1880,7 +1880,7 @@ def _map_shaped_array(size: int, axis: Optional[int], aval: ShapedArray
# TODO: Extend the named shape
if axis is None: return aval
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
named_shape=aval.named_shape)
named_shape=aval.named_shape, weak_type=aval.weak_type)
def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
aval: ShapedArray) -> ShapedArray:
@ -1889,7 +1889,7 @@ def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
named_shape.pop(axis_name, None)
if axis is None: return aval.replace(named_shape=named_shape)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape)
named_shape=named_shape, weak_type=aval.weak_type)
AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import lax
from jax._src.lax import parallel
@ -1280,6 +1281,17 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmapped_gradients_fn(vector) # doesn't crash
self.assertAllClose(ans, jnp.ones(2), check_dtypes=False)
def testBatchingPreservesWeakType(self):
# Regression test for https://github.com/google/jax/issues/10025
x = jnp.ravel(1)
self.assertTrue(dtypes.is_weakly_typed(x))
@vmap
def f(x):
self.assertTrue(dtypes.is_weakly_typed(x), f"{x} is not weakly-typed")
return x
y = f(x)
self.assertTrue(dtypes.is_weakly_typed(y))
Array = Any
ArrayElt = Any