Merge pull request #12497 from mattjj:djax-dag-fix1

PiperOrigin-RevId: 477038279
This commit is contained in:
jax authors 2022-09-26 18:14:56 -07:00
commit 1bcf8d646d
3 changed files with 29 additions and 10 deletions

View File

@ -2095,7 +2095,9 @@ def infer_lambda_input_type(
idxs, implicit_types = _collect_implicit(args, specs)
implicit_sig = [(ty, False) for ty in implicit_types]
explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
return (*implicit_sig, *explicit_sig)
input_type = (*implicit_sig, *explicit_sig)
lu._check_input_type(input_type)
return input_type
def _canonicalize_specs(
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
@ -2143,6 +2145,7 @@ def _complete_specs(
for x, spec in zip(args, specs))
return specs
def _collect_implicit(
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]:
@ -2153,24 +2156,23 @@ def _collect_implicit(
idxs: Dict[AbstractedAxisName, DBIdx] = {}
implicit_types: List[AbstractValue] = []
explicit_tracers: Dict[TracerId, int] = {}
counter = (DBIdx(i) for i in it.count())
counter = it.count()
# Add implicit arguments to idxs.
for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
for i, name in spec.items():
if name not in idxs and id(x.shape[i]) not in explicit_tracers:
idxs[name] = next(counter)
idxs[name] = DBIdx(next(counter))
implicit_types.append(raise_to_shaped(get_aval(x.shape[i])))
if isinstance(x, Tracer):
explicit_tracers[id(x)] = explicit_idx
explicit_tracers.setdefault(id(x), explicit_idx) # use the first
# Now that we know the implicit args, add explicit args to idxs.
offset = len(implicit_types)
for x, spec in zip(args, specs):
for i, name in spec.items():
if id(x.shape[i]) in explicit_tracers:
idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])])
idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])]))
return idxs, implicit_types

View File

@ -236,17 +236,29 @@ def wrap_init(f, params=None) -> WrappedFun:
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None)
def annotate(f: WrappedFun,
in_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]]
) -> WrappedFun:
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
_check_input_type(in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and
all(isinstance(a, core.AbstractValue) and type(b) is bool
and not isinstance(a, core.ConcreteArray) for a, b in in_type) and
all(isinstance(d, (int, core.BInt, core.DBIdx)) for a, _ in in_type
if type(a) is core.DShapedArray for d in a.shape))
# Check that all DBIdx point to positions to the left of the input on which
# they appear.
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
if isinstance(aval, core.DShapedArray) for d in aval.shape
if isinstance(d, core.DBIdx))
# Check that all implicit arguments have at least one DBIdx pointing to them.
provided = [e for _, e in in_type]
for aval, _ in in_type:
if type(aval) is core.DShapedArray:
@ -254,7 +266,6 @@ def annotate(f: WrappedFun,
if isinstance(d, core.DBIdx):
provided[d.val] = True
assert all(provided)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
class _CacheLocalContext(threading.local):

View File

@ -1354,6 +1354,12 @@ class DynamicShapeTest(jtu.JaxTestCase):
d, = jaxpr.eqns[0].outvars
self.assertEqual(d.aval.shape, (a, a))
def test_inferring_valid_subjaxpr_type_add(self):
def f(x):
return x + x.shape[0]
jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())