[jax2tf] Added special case for tf.pad. (#3462)

Fixed lax_reference.pad to handle lax.pad with negative edge padding.
George Necula authored on 2020-06-17 11:57:21 +03:00, committed by GitHub
parent ce782e610d
commit 4f21b9351c
5 changed files with 24 additions and 9 deletions
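
For context, a minimal sketch (not part of the commit) of the lax.pad semantics involved: positive edge padding adds elements, negative edge padding trims them, and interior padding dilates the operand.

import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3.])
lax.pad(x, 0., [(1, 2, 0)])   # positive edges only  -> [0. 1. 2. 3. 0. 0.]
lax.pad(x, 0., [(-1, 0, 0)])  # negative low edge trims -> [2. 3.]
lax.pad(x, 0., [(0, 0, 1)])   # interior padding dilates -> [1. 0. 2. 0. 3.]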


@@ -677,6 +677,10 @@ def _pad_shape(operand, padding_value, padding_config):

 def _pad(operand, padding_value, padding_config):
   low, high, interior = util.unzip3(padding_config)
+  if all(lo >= 0 and hi >= 0 and i == 0 for lo, hi, i in padding_config):
+    return tf.pad(operand, util.safe_zip(low, high),
+                  mode="CONSTANT", constant_values=padding_value)
+  # TODO(necula): implement shape inference for XlaPad
   out_shape = _pad_shape(operand, padding_value, padding_config)
   out = tfxla.pad(operand, padding_value, low, high, interior)
   out.set_shape(out_shape)
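
The special case above applies only when all edge padding is non-negative and there is no interior padding, which is exactly what tf.pad in CONSTANT mode supports; every other configuration still goes through tfxla.pad. An illustrative check (not part of the commit) of that tf.pad behavior:

import tensorflow as tf

x = tf.constant([[1., 2.], [3., 4.]])
tf.pad(x, paddings=[(1, 0), (0, 2)], mode="CONSTANT", constant_values=5.)
# [[5. 5. 5. 5.]
#  [1. 2. 5. 5.]
#  [3. 4. 5. 5.]]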


@@ -150,7 +150,7 @@ lax_pad = jtu.cases_from_list(
   for dtype in default_dtypes
   for pads in [
     [(0, 0, 0), (0, 0, 0)],  # no padding
-    [(1, 1, 0), (2, 2, 0)],  # edge padding
+    [(1, 1, 0), (2, 2, 0)],  # only positive edge padding
     [(1, 2, 1), (0, 1, 0)],  # edge padding and interior padding
     [(0, 0, 0), (-1, -1, 0)],  # negative padding
     [(0, 0, 0), (-2, -2, 4)],  # add big dilation then remove from edges
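
A worked example (not from the commit) of the trickiest case above, [(0, 0, 0), (-2, -2, 4)]: interior padding first dilates each row of a (2, 3) operand from size 3 to 3 + 4*2 = 11, then the negative edges trim two elements from each end, leaving 7.

import numpy as np
from jax import lax

x = np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32)
y = lax.pad(x, np.float32(0), [(0, 0, 0), (-2, -2, 4)])
print(y.shape)  # (2, 7)
print(y[0])     # [0. 0. 0. 2. 0. 0. 0.]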


@@ -154,13 +154,14 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   @primitive_harness.parameterized(primitive_harness.lax_pad)
   def test_pad(self, harness: primitive_harness.Harness):
     # TODO: figure out the bfloat16 story
     if harness.params["dtype"] is dtypes.bfloat16:
       raise unittest.SkipTest("bfloat16 not implemented")
-    # TODO: implement (or decide not to) pads with negative edge padding
+    # TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
     if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
       raise unittest.SkipTest("pad with negative pad not supported")
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
-                           with_function=True)
+                           with_function=False)

   @parameterized.named_parameters(jtu.cases_from_list(
     dict(testcase_name=f"_{f_jax.__name__}",


@@ -232,16 +232,19 @@ def reshape(operand, new_sizes, dimensions=None):
   return np.reshape(np.transpose(operand, dimensions), new_sizes)

 def pad(operand, padding_value, padding_config):
   # https://www.tensorflow.org/xla/operation_semantics#pad
   lo, hi, interior = zip(*padding_config)
-  outshape = np.add(np.add(np.add(lo, hi), operand.shape),
+  # Handle first the positive edge padding and interior
+  lo_pos, hi_pos = np.clip(lo, 0, None), np.clip(hi, 0, None)
+  outshape = np.add(np.add(np.add(lo_pos, hi_pos), operand.shape),
                     np.multiply(interior, np.subtract(operand.shape, 1)))
   out = np.full(outshape, padding_value, operand.dtype)
   lhs_slices = tuple(_slice(l if l > 0 else 0, -h if h > 0 else None, step)
-                     for l, h, step in zip(lo, hi, np.add(1, interior)))
-  rhs_slices = tuple(_slice(l if l < 0 else 0, -h if h < 0 else None)
+                     for l, h, step in zip(lo_pos, hi_pos, np.add(1, interior)))
+  out[lhs_slices] = operand
+  trim_slices = tuple(_slice(-l if l < 0 else 0, h if h < 0 else None)
                       for l, h in zip(lo, hi))
-  out[lhs_slices] = operand[rhs_slices]
-  return out
+  return out[trim_slices]

 def rev(operand, dimensions):
   dimensions = frozenset(dimensions)
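
The rewrite above splits the reference implementation into two phases: first materialize only the positive edge padding plus the interior padding into a freshly filled output, then apply the negative amounts as a final trimming slice. A hand trace (not from the commit) for a single dimension with padding config (-1, 2, 1):

import numpy as np

operand = np.array([1., 2., 3.])
# Phase 1: lo_pos=0, hi_pos=2, interior=1 -> output size 0 + 2 + 3 + 1*(3-1) = 7
out = np.full(7, 0.)
out[0:-2:2] = operand  # lhs_slices -> [1. 0. 2. 0. 3. 0. 0.]
# Phase 2: trim_slices turns lo=-1 into a start offset of 1
print(out[1:])         # [0. 2. 0. 3. 0. 0.]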


@@ -976,7 +976,14 @@ class LaxTest(jtu.JaxTestCase):
        "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
       for shape in [(2, 3)]
       for dtype in default_dtypes
-      for pads in [[(1, 2, 1), (0, 1, 0)]]))
+      for pads in [
+        [(0, 0, 0), (0, 0, 0)],  # no padding
+        [(1, 1, 0), (2, 2, 0)],  # only positive edge padding
+        [(1, 2, 1), (0, 1, 0)],  # edge padding and interior padding
+        [(0, 0, 0), (-1, -1, 0)],  # negative padding
+        [(0, 0, 0), (-2, -2, 4)],  # add big dilation then remove from edges
+        [(0, 0, 0), (-2, -3, 1)],  # remove everything in one dimension
+      ]))
   def testPadAgainstNumpy(self, shape, dtype, pads, rng_factory):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
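
In spirit, testPadAgainstNumpy checks that lax.pad agrees with the NumPy reference on every case above, including the one that empties out a dimension. A standalone sketch of that property (the lax_reference import path is an assumption, not taken from the commit):

import numpy as np
from jax import lax
from jax import lax_reference  # assumed import path for the reference module

x = np.ones((2, 3), np.float32)
for pads in [[(1, 2, 1), (0, 1, 0)], [(0, 0, 0), (-2, -3, 1)]]:
  np.testing.assert_allclose(lax.pad(x, np.float32(0), pads),
                             lax_reference.pad(x, np.float32(0), pads))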