compat.tensor __*__ methods no longer do implicit broadcasting

This change makes it simpler to lower binary operations to Triton IR,
bypassing the Triton Python bindings.

PiperOrigin-RevId: 601796719
Author: Sergei Lebedev, 2024-01-26 10:06:22 -08:00 (committed by jax authors)
parent 2a8ce9ae9c
commit 273cb27047
2 changed files with 39 additions and 18 deletions
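
To illustrate the contract this change enforces, here is a minimal runnable
sketch; `tensor`, `_to_tensor`, and `broadcast_to` mirror names from the diff
below, but their bodies are hypothetical stand-ins, not the actual
jaxlib/Triton implementation.

class tensor:
  """Shape-only stand-in: binary ops assert equal shapes."""

  def __init__(self, shape):
    self.shape = tuple(shape)

  def __add__(self, other):
    other = _to_tensor(other)
    assert self.shape == other.shape  # no implicit broadcasting
    return tensor(self.shape)


def _to_tensor(v):
  # Scalars are promoted to 0-d tensors, never to the peer's shape.
  return v if isinstance(v, tensor) else tensor(())


def broadcast_to(t, shape):
  # Stand-in for tc.broadcast_to: the explicit, caller-side broadcast.
  return tensor(shape)


x = tensor((8, 128))
y = broadcast_to(_to_tensor(1.0), x.shape)  # broadcast explicitly first
z = x + y                                   # OK: shapes already match
# x + 1.0 would now fail the shape assert instead of broadcasting.

With every operand pre-broadcast to a known shape, each binary op can
presumably be emitted as a single shape-preserving Triton IR instruction,
with no broadcast inference needed at lowering time.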


@@ -577,8 +577,8 @@ triton_lowering_rules[lax.max_p] = _max_lowering_rule
 def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
   if a.dtype.is_floating() or b.dtype.is_floating():
-    return a / b
-  return a // b
+    return tc.semantic.truediv(a, b)
+  return tc.semantic.floordiv(a, b)

 triton_lowering_rules[lax.div_p] = _div_lowering_rule
@@ -721,7 +721,8 @@ def _compute_pointers_from_indices(
       if isinstance(index.start, int):
         ptr_dim_offset = tc.arange(index.start, index.start + index.size)
       else:
-        ptr_dim_offset = index.start + tc.arange(0, index.size)
+        ptr_dim_offset = tc.broadcast_to(index.start, [index.size])
+        ptr_dim_offset += tc.arange(0, index.size)
       # We need to add broadcastable dimensions for the advanced int indexing
       # and for previous slices
       num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
@@ -746,7 +747,7 @@ def _compute_pointers_from_indices(
     if not ptr_dim_offset.type.is_block() and indexer_shape:
       ptr_dim_offset = tc.broadcast_to(
           ptr_dim_offset,
-          [tc.constexpr(1)] * len(indexer_shape),
+          [1] * len(indexer_shape),
       )
     else:
       for _ in range(num_left_expand_dims):
@@ -755,20 +756,20 @@ def _compute_pointers_from_indices(
         ndim = len(ptr_dim_offset.shape)
         ptr_dim_offset = tc.expand_dims(ptr_dim_offset, ndim)
     if start_offset is not None:
-      ptr_dim_offset += start_offset
-    stride_size = tc._to_tensor(int(dim_stride))
+      ptr_dim_offset += tc.broadcast_to(start_offset, ptr_dim_offset.shape)
+    stride_size = tc.broadcast_to(dim_stride, ptr_dim_offset.shape)
     bcast_indices.append(ptr_dim_offset * stride_size)
   block_shapes = [
       () if not index.type.is_block() else tuple(index.type.get_block_shapes())
       for index in bcast_indices
   ]
   bcast_indices = [
-      tc.broadcast_to(index, map(tc.constexpr, indexer_shape))
+      tc.broadcast_to(index, indexer_shape)
       if indexer_shape != block_shape
       else index
       for index, block_shape in zip(bcast_indices, block_shapes)
   ]
-  return sum(bcast_indices, root_ptr)
+  return sum(bcast_indices, tc.broadcast_to(root_ptr, indexer_shape))

 def _pack_indices(non_slice_idx, indexed_dims):
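
As a caller-side illustration of the pattern adopted above, here is a NumPy
stand-in for the per-dimension offset arithmetic; `start`, `size`, and
`stride` are illustrative parameters, and `np.broadcast_to`/`np.arange` play
the roles of `tc.broadcast_to`/`tc.arange`.

import numpy as np

def dim_offsets(start, size, stride):
  # Mirrors: ptr_dim_offset = tc.broadcast_to(index.start, [index.size])
  #          ptr_dim_offset += tc.arange(0, index.size)
  offs = np.broadcast_to(start, [size]) + np.arange(size)
  # Mirrors: stride_size = tc.broadcast_to(dim_stride, ptr_dim_offset.shape)
  return offs * np.broadcast_to(stride, offs.shape)

print(dim_offsets(start=4, size=8, stride=2))
# [ 8 10 12 14 16 18 20 22]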


@@ -811,40 +811,58 @@ def _to_tensor(v) -> "tensor":
 class tensor(tl.core.tensor):

   def __add__(self, other):
-    return semantic.add(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.add(self, other)

   def __radd__(self, other):
     return self + other

   def __sub__(self, other):
-    return semantic.sub(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.sub(self, other)

   def __rsub__(self, other):
-    return semantic.sub(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.sub(other, self)

   def __mul__(self, other):
-    return semantic.mul(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.mul(self, other)

   def __rmul__(self, other):
     return self * other

   def __truediv__(self, other):
-    return semantic.truediv(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.truediv(self, other)

   def __rtruediv__(self, other):
-    return semantic.truediv(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.truediv(other, self)

   def __floordiv__(self, other):
-    return semantic.floordiv(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.floordiv(self, other)

   def __rfloordiv__(self, other):
-    return semantic.floordiv(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.floordiv(other, self)

   def __mod__(self, other):
-    return semantic.mod(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.mod(self, other)

   def __rmod__(self, other):
-    return semantic.mod(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.mod(other, self)

   def __neg__(self):
     return semantic.minus(self)
@@ -854,7 +872,9 @@ class tensor(tl.core.tensor):
   # TODO(slebedev): Override other comparison methods.
   def __eq__(self, other):
-    return semantic.equal(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.equal(self, other)

   def __getitem__(self, slices) -> tensor:
     if isinstance(slices, (slice, constexpr)):