diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index b80b6263992..8c1a67f0e87 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -251,6 +251,92 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]], dtype=dtype)) + def testPadShapeInference(self): + a = array_ops.placeholder(np.float32, shape=(2, 3)) + + c = xla.pad( + a, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 4]) + + self.assertEqual(c.shape, tensor_shape.TensorShape([6, 14])) + + c = xla.pad( + a, + padding_value=7, + padding_low=[2, -2], + padding_high=[1, -2], + padding_interior=[1, 2]) + + self.assertEqual(c.shape, tensor_shape.TensorShape([6, 3])) + + # 0-sized input dimension and interior padding + c = xla.pad( + array_ops.placeholder(np.float32, shape=(2, 0)), + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 1], + padding_interior=[1, 2]) + + self.assertEqual(c.shape, tensor_shape.TensorShape([6, 2])) + + with self.assertRaisesRegex( + ValueError, 'padding_value input must be scalar, found rank 1 '): + xla.pad( + a, + padding_value=[0, 1], + padding_low=[0, 0], + padding_high=[0, 0], + padding_interior=[0, 0]) + + with self.assertRaisesRegex(ValueError, + 'padding_low must be a 1D tensor of size 2 '): + xla.pad( + a, + padding_value=7, + padding_low=[0, 0, 0], + padding_high=[0, 0], + padding_interior=[0, 0]) + + with self.assertRaisesRegex(ValueError, + 'padding_high must be a 1D tensor of size 2 '): + xla.pad( + a, + padding_value=7, + padding_low=[0, 0], + padding_high=[0, 0, 0], + padding_interior=[0, 0]) + + with self.assertRaisesRegex( + ValueError, 'padding_interior must be a 1D tensor of size 2 '): + xla.pad( + a, + padding_value=7, + padding_low=[0, 0], + padding_high=[0, 0], + padding_interior=[0]) + + with self.assertRaisesRegex( + ValueError, + 
'padding_interior must contain only non-negative values, found -2 '): + xla.pad( + a, + padding_value=7, + padding_low=[0, 0], + padding_high=[0, 0], + padding_interior=[-2, 0]) + + with self.assertRaisesRegex( + ValueError, 'resulting padded dimension has negative size -1 '): + xla.pad( + a, + padding_value=7, + padding_low=[-3, 0], + padding_high=[0, 0], + padding_interior=[0, 0]) + @test_util.disable_mlir_bridge('Not supported yet') def testReduce(self): for dtype in set(self.numeric_types).intersection( diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 0e780a763d3..00d1fefa941 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <vector> + #include "absl/algorithm/container.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -407,7 +409,79 @@ REGISTER_OP("XlaPad") .Output("output: T") .Attr("T: type") .Attr("Tindices: {int32, int64}") - .SetShapeFn(UnchangedRank) + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle input_shape_handle = c->input(0); + if (!c->FullyDefined(input_shape_handle)) { + return UnchangedRank(c); + } + const int32 op_rank = c->Rank(input_shape_handle); + + shape_inference::ShapeHandle padding_shape_handle = c->input(1); + if (!c->RankKnown(padding_shape_handle) || + c->Rank(padding_shape_handle) != 0) { + return errors::InvalidArgument( + "padding_value input must be scalar, found rank ", + c->Rank(padding_shape_handle)); + } + const Tensor* padding_low_tensor = c->input_tensor(2); + const Tensor* padding_high_tensor = c->input_tensor(3); + const Tensor* padding_interior_tensor = c->input_tensor(4); + if (padding_low_tensor == nullptr || padding_high_tensor == nullptr || 
padding_interior_tensor == nullptr) { + return UnchangedRank(c); + } + + if (padding_low_tensor->shape().dims() != 1 || + padding_low_tensor->shape().dim_size(0) != op_rank) { + return errors::InvalidArgument( + "padding_low must be a 1D tensor of size ", op_rank); + } + if (padding_high_tensor->shape().dims() != 1 || + padding_high_tensor->shape().dim_size(0) != op_rank) { + return errors::InvalidArgument( + "padding_high must be a 1D tensor of size ", op_rank); + } + if (padding_interior_tensor->shape().dims() != 1 || + padding_interior_tensor->shape().dim_size(0) != op_rank) { + return errors::InvalidArgument( + "padding_interior must be a 1D tensor of size ", op_rank); + } + std::vector<shape_inference::DimensionHandle> output_dims; + output_dims.reserve(op_rank); + for (int64 i = 0; i < op_rank; ++i) { + int64 low, high, interior; + TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low)); + TF_RETURN_IF_ERROR( + c->GetScalarFromTensor(padding_high_tensor, i, &high)); + TF_RETURN_IF_ERROR( + c->GetScalarFromTensor(padding_interior_tensor, i, &interior)); + if (interior < 0) { + return errors::InvalidArgument( + "padding_interior must contain only non-negative values, found ", + interior); + } + + shape_inference::DimensionHandle orig_size_handle = + c->Dim(input_shape_handle, i); + if (c->ValueKnown(orig_size_handle)) { + auto orig_dim = c->Value(orig_size_handle); + int64 new_dim = orig_dim + low + high; + if (orig_dim > 0) { + new_dim += interior * (orig_dim - 1); + } + if (new_dim < 0) { + return errors::InvalidArgument( + "resulting padded dimension has negative size ", new_dim); + } + output_dims.emplace_back(c->MakeDim(new_dim)); + } else { + output_dims.emplace_back(c->UnknownDim()); + } + } + + c->set_output(0, c->MakeShape(output_dims)); + return Status::OK(); + }) .Doc(R"doc( Wraps the XLA Pad operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#pad @@ -415,9 +489,13 @@ Wraps the XLA Pad operator, documented at input: A `Tensor` of 
type T. padding_value: A scalar `Tensor` of type T. -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. -padding_interior: the padding to apply between each input element. +padding_low: the padding to apply at the start of each input dimensions. Must + be a compile-time constant 1D tensor of length equal to rank of input. +padding_high: the padding to apply at the end of each input dimension. Must + be a compile-time constant 1D tensor of length equal to rank of input. +padding_interior: the padding to apply between each input element. Must + be a compile-time constant 1D tensor of length equal to rank of input, + containing only non-negative values. output: A `Tensor` of type T. )doc"); diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 2d81b294372..721c20b7491 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -896,10 +896,10 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) { return errors::InvalidArgument("Input must be scalar but has rank ", rank); } - if (t->dtype() == DT_INT32) { + if (t->dtype() == DataType::DT_INT32) { *val = t->scalar<int32>()(); return Status::OK(); - } else if (t->dtype() == DT_INT64) { + } else if (t->dtype() == DataType::DT_INT64) { *val = t->scalar<int64>()(); return Status::OK(); } else { @@ -907,6 +907,35 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) { } } +Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64 idx, + int64* val) { + // Caller must ensure that <t> is not NULL. 
+ const int rank = t->dims(); + if (rank != 1) { + return errors::InvalidArgument("Input must be 1D but has rank ", rank); + } + + if (t->dtype() == DataType::DT_INT32) { + auto flat_t = t->flat<int32>(); + if (idx < 0 || idx >= flat_t.size()) { + return errors::InvalidArgument("Invalid index ", idx, + " for Tensor of size ", flat_t.size()); + } + *val = flat_t(idx); + return Status::OK(); + } else if (t->dtype() == DataType::DT_INT64) { + auto flat_t = t->flat<int64>(); + if (idx < 0 || idx >= flat_t.size()) { + return errors::InvalidArgument("Invalid index ", idx, + " for Tensor of size ", flat_t.size()); + } + *val = flat_t(idx); + return Status::OK(); + } else { + return errors::InvalidArgument("Tensor input must be int32 or int64."); + } +} + // Returns a new dimension whose value is given by a scalar input tensor. Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { int64 val; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index a7c72ebe294..be73a3df5ab 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -490,10 +490,14 @@ class InferenceContext { inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } // Returns in <val> a scalar value from an input tensor <t>. The input tensor - // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the + // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the // input tensor is not NULL. Status GetScalarFromTensor(const Tensor* t, int64* val); + // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or + // int64 elements. Caller must ensure that the input tensor is not NULL. + Status GetScalarFromTensor(const Tensor* t, int64 idx, int64* val); + // Returns a new dimension whose value is given by a scalar input tensor. // The input tensor must be in host memory, since it is dereferenced to get // the value.