diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index f3e915daa67..35d36315464 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -211,6 +211,24 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], dtype=dtype)) + def testPadNegative(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[0, -1], + padding_high=[1, -2], + padding_interior=[1, 2]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(6, dtype=np.int32).astype(dtype).reshape([2, 3]),), + expected=np.array( + [[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]], + dtype=dtype)) + @test_util.disable_mlir_bridge('Not supported yet') def testReduce(self): for dtype in set(self.numeric_types).intersection( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc index a3c2eef993c..d35101a771a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -64,14 +64,6 @@ class XlaPadOp : public XlaOpKernel { padding_interior.size(), " vs. ", rank, ")")); auto non_negative = [](int64 x) { return x >= 0; }; - OP_REQUIRES( - context, absl::c_all_of(padding_low, non_negative), - errors::InvalidArgument("padding_low must be non-negative, got [", - absl::StrJoin(padding_low, ","), "]")); - OP_REQUIRES( - context, absl::c_all_of(padding_high, non_negative), - errors::InvalidArgument("padding_high must be non-negative, got [", - absl::StrJoin(padding_high, ","), "]")); OP_REQUIRES( context, absl::c_all_of(padding_interior, non_negative), errors::InvalidArgument("padding_interior must be non-negative, got [",