mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #16123 from mattjj:refine-vmap-frame-getting
PiperOrigin-RevId: 537047149
This commit is contained in:
commit
adca0fa9b8
@ -333,18 +333,13 @@ class BatchTrace(Trace):
|
||||
frame.size, frame.name, frame.main_trace.trace_type)
|
||||
|
||||
def get_frame(self, vals, dims) -> core.AxisEnvFrame:
|
||||
frame = core.axis_frame(self.axis_name, self.main)
|
||||
assert frame.main_trace is self.main
|
||||
if any(d is not not_mapped for d in dims):
|
||||
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
|
||||
for x, d in zip(vals, dims) if d is not not_mapped)
|
||||
axis_size, = core.dedup_referents(sizes)
|
||||
else:
|
||||
axis_size = None # can't be inferred from data
|
||||
if self.axis_name is core.no_axis_name:
|
||||
assert axis_size is not None # must be inferrable from data
|
||||
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
|
||||
frame = core.axis_frame(self.axis_name, self.main)
|
||||
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
|
||||
assert frame.main_trace is self.main
|
||||
data_axis_size, = core.dedup_referents(sizes)
|
||||
assert data_axis_size == frame.size
|
||||
return frame
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
|
Loading…
x
Reference in New Issue
Block a user