mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #12497 from mattjj:djax-dag-fix1
PiperOrigin-RevId: 477038279
This commit is contained in:
commit
1bcf8d646d
@ -2095,7 +2095,9 @@ def infer_lambda_input_type(
|
||||
idxs, implicit_types = _collect_implicit(args, specs)
|
||||
implicit_sig = [(ty, False) for ty in implicit_types]
|
||||
explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
|
||||
return (*implicit_sig, *explicit_sig)
|
||||
input_type = (*implicit_sig, *explicit_sig)
|
||||
lu._check_input_type(input_type)
|
||||
return input_type
|
||||
|
||||
def _canonicalize_specs(
|
||||
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
|
||||
@ -2143,6 +2145,7 @@ def _complete_specs(
|
||||
for x, spec in zip(args, specs))
|
||||
return specs
|
||||
|
||||
|
||||
def _collect_implicit(
|
||||
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
|
||||
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]:
|
||||
@ -2153,24 +2156,23 @@ def _collect_implicit(
|
||||
idxs: Dict[AbstractedAxisName, DBIdx] = {}
|
||||
implicit_types: List[AbstractValue] = []
|
||||
explicit_tracers: Dict[TracerId, int] = {}
|
||||
counter = (DBIdx(i) for i in it.count())
|
||||
counter = it.count()
|
||||
|
||||
# Add implicit arguments to idxs.
|
||||
|
||||
for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
|
||||
for i, name in spec.items():
|
||||
if name not in idxs and id(x.shape[i]) not in explicit_tracers:
|
||||
idxs[name] = next(counter)
|
||||
idxs[name] = DBIdx(next(counter))
|
||||
implicit_types.append(raise_to_shaped(get_aval(x.shape[i])))
|
||||
if isinstance(x, Tracer):
|
||||
explicit_tracers[id(x)] = explicit_idx
|
||||
explicit_tracers.setdefault(id(x), explicit_idx) # use the first
|
||||
|
||||
# Now that we know the implicit args, add explicit args to idxs.
|
||||
offset = len(implicit_types)
|
||||
for x, spec in zip(args, specs):
|
||||
for i, name in spec.items():
|
||||
if id(x.shape[i]) in explicit_tracers:
|
||||
idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])])
|
||||
idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])]))
|
||||
|
||||
return idxs, implicit_types
|
||||
|
||||
|
@ -236,17 +236,29 @@ def wrap_init(f, params=None) -> WrappedFun:
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
return WrappedFun(f, (), (), params, None)
|
||||
|
||||
def annotate(f: WrappedFun,
|
||||
in_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]]
|
||||
) -> WrappedFun:
|
||||
|
||||
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
|
||||
assert f.in_type is None
|
||||
if in_type is None:
|
||||
return f
|
||||
_check_input_type(in_type)
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
|
||||
|
||||
def _check_input_type(in_type: core.InputType) -> None:
|
||||
# Check that in_type is syntactically well-formed
|
||||
assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and
|
||||
all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
and not isinstance(a, core.ConcreteArray) for a, b in in_type) and
|
||||
all(isinstance(d, (int, core.BInt, core.DBIdx)) for a, _ in in_type
|
||||
if type(a) is core.DShapedArray for d in a.shape))
|
||||
|
||||
# Check that all DBIdx point to positions to the left of the input on which
|
||||
# they appear.
|
||||
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
|
||||
if isinstance(aval, core.DShapedArray) for d in aval.shape
|
||||
if isinstance(d, core.DBIdx))
|
||||
|
||||
# Check that all implicit arguments have at least one DBIdx pointing to them.
|
||||
provided = [e for _, e in in_type]
|
||||
for aval, _ in in_type:
|
||||
if type(aval) is core.DShapedArray:
|
||||
@ -254,7 +266,6 @@ def annotate(f: WrappedFun,
|
||||
if isinstance(d, core.DBIdx):
|
||||
provided[d.val] = True
|
||||
assert all(provided)
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
|
||||
|
||||
|
||||
class _CacheLocalContext(threading.local):
|
||||
|
@ -1354,6 +1354,12 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
d, = jaxpr.eqns[0].outvars
|
||||
self.assertEqual(d.aval.shape, (a, a))
|
||||
|
||||
def test_inferring_valid_subjaxpr_type_add(self):
|
||||
def f(x):
|
||||
return x + x.shape[0]
|
||||
|
||||
jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user