mirror of https://github.com/ROCm/jax.git
compat.tensor __*__ methods no longer do implicit broadcasting

This change makes it simpler to lower binary operations to Triton IR, bypassing the Triton Python bindings.

PiperOrigin-RevId: 601796719
This commit is contained in:
parent 2a8ce9ae9c
commit 273cb27047
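Note (not part of the commit): a minimal sketch of what the new contract means for callers of the compat tensor wrapper. The alias tc for the compat module follows the usage in the diff below, while the helper name add_blocks and its shapes are purely illustrative assumptions. After this change the __*__ methods assert matching shapes instead of broadcasting implicitly, so any broadcasting has to be spelled out with tc.broadcast_to.

# Hypothetical illustration only; not code from this commit.
def add_blocks(x, y, out_shape):
  # Previously `x + y` would broadcast mismatched shapes implicitly.
  # Now __add__ asserts `self.shape == other.shape`, so callers broadcast
  # both operands to the common shape explicitly before the element-wise add.
  x = tc.broadcast_to(x, out_shape)
  y = tc.broadcast_to(y, out_shape)
  return x + y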
@@ -577,8 +577,8 @@ triton_lowering_rules[lax.max_p] = _max_lowering_rule
 
 def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
   if a.dtype.is_floating() or b.dtype.is_floating():
-    return a / b
-  return a // b
+    return tc.semantic.truediv(a, b)
+  return tc.semantic.floordiv(a, b)
 
 
 triton_lowering_rules[lax.div_p] = _div_lowering_rule
@@ -721,7 +721,8 @@ def _compute_pointers_from_indices(
     if isinstance(index.start, int):
       ptr_dim_offset = tc.arange(index.start, index.start + index.size)
     else:
-      ptr_dim_offset = index.start + tc.arange(0, index.size)
+      ptr_dim_offset = tc.broadcast_to(index.start, [index.size])
+      ptr_dim_offset += tc.arange(0, index.size)
     # We need to add broadcastable dimensions for the advanced int indexing
     # and for previous slices
     num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
@@ -746,7 +747,7 @@ def _compute_pointers_from_indices(
     if not ptr_dim_offset.type.is_block() and indexer_shape:
       ptr_dim_offset = tc.broadcast_to(
           ptr_dim_offset,
-          [tc.constexpr(1)] * len(indexer_shape),
+          [1] * len(indexer_shape),
       )
     else:
       for _ in range(num_left_expand_dims):
@@ -755,20 +756,20 @@ def _compute_pointers_from_indices(
         ndim = len(ptr_dim_offset.shape)
         ptr_dim_offset = tc.expand_dims(ptr_dim_offset, ndim)
     if start_offset is not None:
-      ptr_dim_offset += start_offset
-    stride_size = tc._to_tensor(int(dim_stride))
+      ptr_dim_offset += tc.broadcast_to(start_offset, ptr_dim_offset.shape)
+    stride_size = tc.broadcast_to(dim_stride, ptr_dim_offset.shape)
     bcast_indices.append(ptr_dim_offset * stride_size)
   block_shapes = [
       () if not index.type.is_block() else tuple(index.type.get_block_shapes())
       for index in bcast_indices
   ]
   bcast_indices = [
-      tc.broadcast_to(index, map(tc.constexpr, indexer_shape))
+      tc.broadcast_to(index, indexer_shape)
       if indexer_shape != block_shape
       else index
       for index, block_shape in zip(bcast_indices, block_shapes)
   ]
-  return sum(bcast_indices, root_ptr)
+  return sum(bcast_indices, tc.broadcast_to(root_ptr, indexer_shape))
 
 
 def _pack_indices(non_slice_idx, indexed_dims):
@@ -811,40 +811,58 @@ def _to_tensor(v) -> "tensor":
 class tensor(tl.core.tensor):
 
   def __add__(self, other):
-    return semantic.add(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.add(self, other)
 
   def __radd__(self, other):
     return self + other
 
   def __sub__(self, other):
-    return semantic.sub(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.sub(self, other)
 
   def __rsub__(self, other):
     return semantic.sub(_to_tensor(other), self)
 
   def __mul__(self, other):
-    return semantic.mul(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.mul(self, other)
 
   def __rmul__(self, other):
    return self * other
 
   def __truediv__(self, other):
-    return semantic.truediv(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.truediv(self, other)
 
   def __rtruediv__(self, other):
-    return semantic.truediv(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.truediv(other, self)
 
   def __floordiv__(self, other):
-    return semantic.floordiv(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.floordiv(self, other)
 
   def __rfloordiv__(self, other):
-    return semantic.floordiv(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.floordiv(other, self)
 
   def __mod__(self, other):
-    return semantic.mod(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.mod(self, other)
 
   def __rmod__(self, other):
-    return semantic.mod(_to_tensor(other), self)
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.mod(other, self)
 
   def __neg__(self):
     return semantic.minus(self)
@@ -854,7 +872,9 @@ class tensor(tl.core.tensor):
 
   # TODO(slebedev): Override other comparison methods.
   def __eq__(self, other):
-    return semantic.equal(self, _to_tensor(other))
+    other = _to_tensor(other)
+    assert self.shape == other.shape
+    return semantic.equal(self, other)
 
   def __getitem__(self, slices) -> tensor:
     if isinstance(slices, (slice, constexpr)):