Merge pull request #16123 from mattjj:refine-vmap-frame-getting

PiperOrigin-RevId: 537047149
This commit is contained in:
jax authors 2023-06-01 09:42:43 -07:00
commit adca0fa9b8

View File

@ -333,18 +333,13 @@ class BatchTrace(Trace):
frame.size, frame.name, frame.main_trace.trace_type)
def get_frame(self, vals, dims) -> core.AxisEnvFrame:
frame = core.axis_frame(self.axis_name, self.main)
assert frame.main_trace is self.main
if any(d is not not_mapped for d in dims):
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
for x, d in zip(vals, dims) if d is not not_mapped)
axis_size, = core.dedup_referents(sizes)
else:
axis_size = None # can't be inferred from data
if self.axis_name is core.no_axis_name:
assert axis_size is not None # must be inferrable from data
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
frame = core.axis_frame(self.axis_name, self.main)
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
assert frame.main_trace is self.main
data_axis_size, = core.dedup_referents(sizes)
assert data_axis_size == frame.size
return frame
def process_primitive(self, primitive, tracers, params):