mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Added special case for tf.pad. (#3462)
Fixed lax_reference.pad to handle lax.pad with negative edge padding.
This commit is contained in:
parent
ce782e610d
commit
4f21b9351c
@ -677,6 +677,10 @@ def _pad_shape(operand, padding_value, padding_config):
|
||||
|
||||
def _pad(operand, padding_value, padding_config):
|
||||
low, high, interior = util.unzip3(padding_config)
|
||||
if all(lo >= 0 and hi >= 0 and i == 0 for lo, hi, i in padding_config):
|
||||
return tf.pad(operand, util.safe_zip(low, high),
|
||||
mode="CONSTANT", constant_values=padding_value)
|
||||
# TODO(necula): implement shape inference for XlaPad
|
||||
out_shape = _pad_shape(operand, padding_value, padding_config)
|
||||
out = tfxla.pad(operand, padding_value, low, high, interior)
|
||||
out.set_shape(out_shape)
|
||||
|
@ -150,7 +150,7 @@ lax_pad = jtu.cases_from_list(
|
||||
for dtype in default_dtypes
|
||||
for pads in [
|
||||
[(0, 0, 0), (0, 0, 0)], # no padding
|
||||
[(1, 1, 0), (2, 2, 0)], # edge padding
|
||||
[(1, 1, 0), (2, 2, 0)], # only positive edge padding
|
||||
[(1, 2, 1), (0, 1, 0)], # edge padding and interior padding
|
||||
[(0, 0, 0), (-1, -1, 0)], # negative padding
|
||||
[(0, 0, 0), (-2, -2, 4)], # add big dilation then remove from edges
|
||||
|
@ -154,13 +154,14 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_pad)
|
||||
def test_pad(self, harness: primitive_harness.Harness):
|
||||
# TODO: figure out the bfloat16 story
|
||||
if harness.params["dtype"] is dtypes.bfloat16:
|
||||
raise unittest.SkipTest("bfloat16 not implemented")
|
||||
# TODO: implement (or decide not to) pads with negative edge padding
|
||||
# TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
|
||||
if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
|
||||
raise unittest.SkipTest("pad with negative pad not supported")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
with_function=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
|
@ -232,16 +232,19 @@ def reshape(operand, new_sizes, dimensions=None):
|
||||
return np.reshape(np.transpose(operand, dimensions), new_sizes)
|
||||
|
||||
def pad(operand, padding_value, padding_config):
|
||||
# https://www.tensorflow.org/xla/operation_semantics#pad
|
||||
lo, hi, interior = zip(*padding_config)
|
||||
outshape = np.add(np.add(np.add(lo, hi), operand.shape),
|
||||
# Handle first the positive edge padding and interior
|
||||
lo_pos, hi_pos = np.clip(lo, 0, None), np.clip(hi, 0, None)
|
||||
outshape = np.add(np.add(np.add(lo_pos, hi_pos), operand.shape),
|
||||
np.multiply(interior, np.subtract(operand.shape, 1)))
|
||||
out = np.full(outshape, padding_value, operand.dtype)
|
||||
lhs_slices = tuple(_slice(l if l > 0 else 0, -h if h > 0 else None, step)
|
||||
for l, h, step in zip(lo, hi, np.add(1, interior)))
|
||||
rhs_slices = tuple(_slice(l if l < 0 else 0, -h if h < 0 else None)
|
||||
for l, h, step in zip(lo_pos, hi_pos, np.add(1, interior)))
|
||||
out[lhs_slices] = operand
|
||||
trim_slices = tuple(_slice(-l if l < 0 else 0, h if h < 0 else None)
|
||||
for l, h in zip(lo, hi))
|
||||
out[lhs_slices] = operand[rhs_slices]
|
||||
return out
|
||||
return out[trim_slices]
|
||||
|
||||
def rev(operand, dimensions):
|
||||
dimensions = frozenset(dimensions)
|
||||
|
@ -976,7 +976,14 @@ class LaxTest(jtu.JaxTestCase):
|
||||
"shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
|
||||
for shape in [(2, 3)]
|
||||
for dtype in default_dtypes
|
||||
for pads in [[(1, 2, 1), (0, 1, 0)]]))
|
||||
for pads in [
|
||||
[(0, 0, 0), (0, 0, 0)], # no padding
|
||||
[(1, 1, 0), (2, 2, 0)], # only positive edge padding
|
||||
[(1, 2, 1), (0, 1, 0)], # edge padding and interior padding
|
||||
[(0, 0, 0), (-1, -1, 0)], # negative padding
|
||||
[(0, 0, 0), (-2, -2, 4)], # add big dilation then remove from edges
|
||||
[(0, 0, 0), (-2, -3, 1)], # remove everything in one dimension
|
||||
]))
|
||||
def testPadAgainstNumpy(self, shape, dtype, pads, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user