Add shape inference rule for XlaPad.

Before the change, the output of XlaPad had a known rank but unknown dimensions.

PiperOrigin-RevId: 349600804
Change-Id: Ie5e2177499ac4a9d4a5ca3c890adc9798a46fbf1
This commit is contained in:
A. Unique TensorFlower 2020-12-30 14:11:44 -08:00 committed by TensorFlower Gardener
parent 0e5003a4ee
commit d16e734273
4 changed files with 204 additions and 7 deletions

View File

@ -251,6 +251,92 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
dtype=dtype))
def testPadShapeInference(self):
    """Checks static shape inference and validation errors for xla.pad."""
    two_by_three = array_ops.placeholder(np.float32, shape=(2, 3))

    # Plain positive padding: dim = size + low + high + interior * (size - 1).
    padded = xla.pad(
        two_by_three,
        padding_value=7,
        padding_low=[2, 1],
        padding_high=[1, 2],
        padding_interior=[1, 4])
    self.assertEqual(padded.shape, tensor_shape.TensorShape([6, 14]))

    # Negative edge padding is allowed and shrinks the dimension.
    padded = xla.pad(
        two_by_three,
        padding_value=7,
        padding_low=[2, -2],
        padding_high=[1, -2],
        padding_interior=[1, 2])
    self.assertEqual(padded.shape, tensor_shape.TensorShape([6, 3]))

    # 0-sized input dimension and interior padding
    padded = xla.pad(
        array_ops.placeholder(np.float32, shape=(2, 0)),
        padding_value=7,
        padding_low=[2, 1],
        padding_high=[1, 1],
        padding_interior=[1, 2])
    self.assertEqual(padded.shape, tensor_shape.TensorShape([6, 2]))

    # padding_value must be a scalar.
    with self.assertRaisesRegex(
        ValueError, 'padding_value input must be scalar, found rank 1 '):
      xla.pad(
          two_by_three,
          padding_value=[0, 1],
          padding_low=[0, 0],
          padding_high=[0, 0],
          padding_interior=[0, 0])

    # Each padding list must have one entry per input dimension.
    with self.assertRaisesRegex(ValueError,
                                'padding_low must be a 1D tensor of size 2 '):
      xla.pad(
          two_by_three,
          padding_value=7,
          padding_low=[0, 0, 0],
          padding_high=[0, 0],
          padding_interior=[0, 0])
    with self.assertRaisesRegex(ValueError,
                                'padding_high must be a 1D tensor of size 2 '):
      xla.pad(
          two_by_three,
          padding_value=7,
          padding_low=[0, 0],
          padding_high=[0, 0, 0],
          padding_interior=[0, 0])
    with self.assertRaisesRegex(
        ValueError, 'padding_interior must be a 1D tensor of size 2 '):
      xla.pad(
          two_by_three,
          padding_value=7,
          padding_low=[0, 0],
          padding_high=[0, 0],
          padding_interior=[0])

    # Interior padding may not be negative.
    with self.assertRaisesRegex(
        ValueError,
        'padding_interior must contain only non-negative values, found -2 '):
      xla.pad(
          two_by_three,
          padding_value=7,
          padding_low=[0, 0],
          padding_high=[0, 0],
          padding_interior=[-2, 0])

    # Negative edge padding may not shrink a dimension below zero.
    with self.assertRaisesRegex(
        ValueError, 'resulting padded dimension has negative size -1 '):
      xla.pad(
          two_by_three,
          padding_value=7,
          padding_low=[-3, 0],
          padding_high=[0, 0],
          padding_interior=[0, 0])
@test_util.disable_mlir_bridge('Not supported yet')
def testReduce(self):
for dtype in set(self.numeric_types).intersection(

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@ -407,7 +409,79 @@ REGISTER_OP("XlaPad")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(UnchangedRank)
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle input_shape_handle = c->input(0);
if (!c->FullyDefined(input_shape_handle)) {
return UnchangedRank(c);
}
const int32 op_rank = c->Rank(input_shape_handle);
shape_inference::ShapeHandle padding_shape_handle = c->input(1);
if (!c->RankKnown(padding_shape_handle) ||
c->Rank(padding_shape_handle) != 0) {
return errors::InvalidArgument(
"padding_value input must be scalar, found rank ",
c->Rank(padding_shape_handle));
}
const Tensor* padding_low_tensor = c->input_tensor(2);
const Tensor* padding_high_tensor = c->input_tensor(3);
const Tensor* padding_interior_tensor = c->input_tensor(4);
if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
padding_interior_tensor == nullptr) {
return UnchangedRank(c);
}
if (padding_low_tensor->shape().dims() != 1 ||
padding_low_tensor->shape().dim_size(0) != op_rank) {
return errors::InvalidArgument(
"padding_low must be a 1D tensor of size ", op_rank);
}
if (padding_high_tensor->shape().dims() != 1 ||
padding_high_tensor->shape().dim_size(0) != op_rank) {
return errors::InvalidArgument(
"padding_high must be a 1D tensor of size ", op_rank);
}
if (padding_interior_tensor->shape().dims() != 1 ||
padding_interior_tensor->shape().dim_size(0) != op_rank) {
return errors::InvalidArgument(
"padding_interior must be a 1D tensor of size ", op_rank);
}
std::vector<shape_inference::DimensionHandle> output_dims;
output_dims.reserve(op_rank);
for (int64 i = 0; i < op_rank; ++i) {
int64 low, high, interior;
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
TF_RETURN_IF_ERROR(
c->GetScalarFromTensor(padding_high_tensor, i, &high));
TF_RETURN_IF_ERROR(
c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
if (interior < 0) {
return errors::InvalidArgument(
"padding_interior must contain only non-negative values, found ",
interior);
}
shape_inference::DimensionHandle orig_size_handle =
c->Dim(input_shape_handle, i);
if (c->ValueKnown(orig_size_handle)) {
auto orig_dim = c->Value(orig_size_handle);
int64 new_dim = orig_dim + low + high;
if (orig_dim > 0) {
new_dim += interior * (orig_dim - 1);
}
if (new_dim < 0) {
return errors::InvalidArgument(
"resulting padded dimension has negative size ", new_dim);
}
output_dims.emplace_back(c->MakeDim(new_dim));
} else {
output_dims.emplace_back(c->UnknownDim());
}
}
c->set_output(0, c->MakeShape(output_dims));
return Status::OK();
})
.Doc(R"doc(
Wraps the XLA Pad operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#pad
@ -415,9 +489,13 @@ Wraps the XLA Pad operator, documented at
input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimensions
padding_high: the padding to apply at the end of each input dimension.
padding_interior: the padding to apply between each input element.
padding_low: the padding to apply at the start of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
be a compile-time constant 1D tensor of length equal to rank of input,
containing only non-negative values.
output: A `Tensor` of type T.
)doc");

View File

@ -896,10 +896,10 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
return errors::InvalidArgument("Input must be scalar but has rank ", rank);
}
if (t->dtype() == DT_INT32) {
if (t->dtype() == DataType::DT_INT32) {
*val = t->scalar<int32>()();
return Status::OK();
} else if (t->dtype() == DT_INT64) {
} else if (t->dtype() == DataType::DT_INT64) {
*val = t->scalar<int64>()();
return Status::OK();
} else {
@ -907,6 +907,35 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
}
}
// Returns in <val> the element at position <idx> of the 1D int32 or int64
// tensor <t>. Returns InvalidArgument if <t> is not rank 1, if <idx> is out
// of range, or if the element type is neither int32 nor int64.
Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64 idx,
                                             int64* val) {
  // Caller must ensure that <t> is not NULL.
  const int rank = t->dims();
  if (rank != 1) {
    return errors::InvalidArgument("Input must be 1D but has rank ", rank);
  }
  // For a rank-1 tensor NumElements() equals the flat size, so the index can
  // be validated once here instead of duplicating the check per dtype branch.
  if (idx < 0 || idx >= t->NumElements()) {
    return errors::InvalidArgument("Invalid index ", idx,
                                   " for Tensor of size ", t->NumElements());
  }
  if (t->dtype() == DataType::DT_INT32) {
    *val = t->flat<int32>()(idx);
    return Status::OK();
  } else if (t->dtype() == DataType::DT_INT64) {
    *val = t->flat<int64>()(idx);
    return Status::OK();
  } else {
    return errors::InvalidArgument("Tensor input must be int32 or int64.");
  }
}
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
int64 val;

View File

@ -490,10 +490,14 @@ class InferenceContext {
inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
// Returns in <val> a scalar value from an input tensor <t>. The input tensor
// must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the
// must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the
// input tensor is not NULL.
Status GetScalarFromTensor(const Tensor* t, int64* val);
// Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
// int64 elements. Caller must ensure that the input tensor is not NULL.
Status GetScalarFromTensor(const Tensor* t, int64 idx, int64* val);
// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.