mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
vmap: preserve weak_type in batching tracer
This commit is contained in:
parent
a04b777c54
commit
34f116c0e0
@ -1854,8 +1854,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
out_extensive = [next(out_extensive_iter) if i is None
|
||||
else _maybe_device_put(tracers[i].pval[1]) if tracers[i].is_known()
|
||||
else tracers[i] for i in fwd_extensive]
|
||||
assert all(a.strip_named_shape() == core.raise_to_shaped(
|
||||
core.get_aval(out)).strip_named_shape()
|
||||
assert all(core.typematch(a, core.get_aval(out))
|
||||
for a, out in zip(extensive_avals, out_extensive))
|
||||
out_flat = out_carry + out_extensive
|
||||
|
||||
|
@ -1880,7 +1880,7 @@ def _map_shaped_array(size: int, axis: Optional[int], aval: ShapedArray
|
||||
# TODO: Extend the named shape
|
||||
if axis is None: return aval
|
||||
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
|
||||
named_shape=aval.named_shape)
|
||||
named_shape=aval.named_shape, weak_type=aval.weak_type)
|
||||
|
||||
def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
|
||||
aval: ShapedArray) -> ShapedArray:
|
||||
@ -1889,7 +1889,7 @@ def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
|
||||
named_shape.pop(axis_name, None)
|
||||
if axis is None: return aval.replace(named_shape=named_shape)
|
||||
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
|
||||
named_shape=named_shape)
|
||||
named_shape=named_shape, weak_type=aval.weak_type)
|
||||
|
||||
AvalMapHandlerPair = Tuple[Callable, Callable]
|
||||
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
|
||||
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax import lax
|
||||
from jax._src.lax import parallel
|
||||
@ -1280,6 +1281,17 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
ans = vmapped_gradients_fn(vector) # doesn't crash
|
||||
self.assertAllClose(ans, jnp.ones(2), check_dtypes=False)
|
||||
|
||||
def testBatchingPreservesWeakType(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10025
|
||||
x = jnp.ravel(1)
|
||||
self.assertTrue(dtypes.is_weakly_typed(x))
|
||||
@vmap
|
||||
def f(x):
|
||||
self.assertTrue(dtypes.is_weakly_typed(x), f"{x} is not weakly-typed")
|
||||
return x
|
||||
y = f(x)
|
||||
self.assertTrue(dtypes.is_weakly_typed(y))
|
||||
|
||||
|
||||
Array = Any
|
||||
ArrayElt = Any
|
||||
|
Loading…
x
Reference in New Issue
Block a user