Add shape inference rule for XlaPad.
Before the change the output of XlaPad had a known rank but unknown dimensions. PiperOrigin-RevId: 349600804 Change-Id: Ie5e2177499ac4a9d4a5ca3c890adc9798a46fbf1
This commit is contained in:
parent
0e5003a4ee
commit
d16e734273
@ -251,6 +251,92 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
def testPadShapeInference(self):
|
||||
a = array_ops.placeholder(np.float32, shape=(2, 3))
|
||||
|
||||
c = xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[2, 1],
|
||||
padding_high=[1, 2],
|
||||
padding_interior=[1, 4])
|
||||
|
||||
self.assertEqual(c.shape, tensor_shape.TensorShape([6, 14]))
|
||||
|
||||
c = xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[2, -2],
|
||||
padding_high=[1, -2],
|
||||
padding_interior=[1, 2])
|
||||
|
||||
self.assertEqual(c.shape, tensor_shape.TensorShape([6, 3]))
|
||||
|
||||
# 0-sized input dimension and interior padding
|
||||
c = xla.pad(
|
||||
array_ops.placeholder(np.float32, shape=(2, 0)),
|
||||
padding_value=7,
|
||||
padding_low=[2, 1],
|
||||
padding_high=[1, 1],
|
||||
padding_interior=[1, 2])
|
||||
|
||||
self.assertEqual(c.shape, tensor_shape.TensorShape([6, 2]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'padding_value input must be scalar, found rank 1 '):
|
||||
xla.pad(
|
||||
a,
|
||||
padding_value=[0, 1],
|
||||
padding_low=[0, 0],
|
||||
padding_high=[0, 0],
|
||||
padding_interior=[0, 0])
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'padding_low must be a 1D tensor of size 2 '):
|
||||
xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[0, 0, 0],
|
||||
padding_high=[0, 0],
|
||||
padding_interior=[0, 0])
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'padding_high must be a 1D tensor of size 2 '):
|
||||
xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[0, 0],
|
||||
padding_high=[0, 0, 0],
|
||||
padding_interior=[0, 0])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'padding_interior must be a 1D tensor of size 2 '):
|
||||
xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[0, 0],
|
||||
padding_high=[0, 0],
|
||||
padding_interior=[0])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'padding_interior must contain only non-negative values, found -2 '):
|
||||
xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[0, 0],
|
||||
padding_high=[0, 0],
|
||||
padding_interior=[-2, 0])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'resulting padded dimension has negative size -1 '):
|
||||
xla.pad(
|
||||
a,
|
||||
padding_value=7,
|
||||
padding_low=[-3, 0],
|
||||
padding_high=[0, 0],
|
||||
padding_interior=[0, 0])
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testReduce(self):
|
||||
for dtype in set(self.numeric_types).intersection(
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
@ -407,7 +409,79 @@ REGISTER_OP("XlaPad")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(UnchangedRank)
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle input_shape_handle = c->input(0);
|
||||
if (!c->FullyDefined(input_shape_handle)) {
|
||||
return UnchangedRank(c);
|
||||
}
|
||||
const int32 op_rank = c->Rank(input_shape_handle);
|
||||
|
||||
shape_inference::ShapeHandle padding_shape_handle = c->input(1);
|
||||
if (!c->RankKnown(padding_shape_handle) ||
|
||||
c->Rank(padding_shape_handle) != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"padding_value input must be scalar, found rank ",
|
||||
c->Rank(padding_shape_handle));
|
||||
}
|
||||
const Tensor* padding_low_tensor = c->input_tensor(2);
|
||||
const Tensor* padding_high_tensor = c->input_tensor(3);
|
||||
const Tensor* padding_interior_tensor = c->input_tensor(4);
|
||||
if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
|
||||
padding_interior_tensor == nullptr) {
|
||||
return UnchangedRank(c);
|
||||
}
|
||||
|
||||
if (padding_low_tensor->shape().dims() != 1 ||
|
||||
padding_low_tensor->shape().dim_size(0) != op_rank) {
|
||||
return errors::InvalidArgument(
|
||||
"padding_low must be a 1D tensor of size ", op_rank);
|
||||
}
|
||||
if (padding_high_tensor->shape().dims() != 1 ||
|
||||
padding_high_tensor->shape().dim_size(0) != op_rank) {
|
||||
return errors::InvalidArgument(
|
||||
"padding_high must be a 1D tensor of size ", op_rank);
|
||||
}
|
||||
if (padding_interior_tensor->shape().dims() != 1 ||
|
||||
padding_interior_tensor->shape().dim_size(0) != op_rank) {
|
||||
return errors::InvalidArgument(
|
||||
"padding_interior must be a 1D tensor of size ", op_rank);
|
||||
}
|
||||
std::vector<shape_inference::DimensionHandle> output_dims;
|
||||
output_dims.reserve(op_rank);
|
||||
for (int64 i = 0; i < op_rank; ++i) {
|
||||
int64 low, high, interior;
|
||||
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->GetScalarFromTensor(padding_high_tensor, i, &high));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
|
||||
if (interior < 0) {
|
||||
return errors::InvalidArgument(
|
||||
"padding_interior must contain only non-negative values, found ",
|
||||
interior);
|
||||
}
|
||||
|
||||
shape_inference::DimensionHandle orig_size_handle =
|
||||
c->Dim(input_shape_handle, i);
|
||||
if (c->ValueKnown(orig_size_handle)) {
|
||||
auto orig_dim = c->Value(orig_size_handle);
|
||||
int64 new_dim = orig_dim + low + high;
|
||||
if (orig_dim > 0) {
|
||||
new_dim += interior * (orig_dim - 1);
|
||||
}
|
||||
if (new_dim < 0) {
|
||||
return errors::InvalidArgument(
|
||||
"resulting padded dimension has negative size ", new_dim);
|
||||
}
|
||||
output_dims.emplace_back(c->MakeDim(new_dim));
|
||||
} else {
|
||||
output_dims.emplace_back(c->UnknownDim());
|
||||
}
|
||||
}
|
||||
|
||||
c->set_output(0, c->MakeShape(output_dims));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA Pad operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#pad
|
||||
@ -415,9 +489,13 @@ Wraps the XLA Pad operator, documented at
|
||||
|
||||
input: A `Tensor` of type T.
|
||||
padding_value: A scalar `Tensor` of type T.
|
||||
padding_low: the padding to apply at the start of each input dimensions
|
||||
padding_high: the padding to apply at the end of each input dimension.
|
||||
padding_interior: the padding to apply between each input element.
|
||||
padding_low: the padding to apply at the start of each input dimensions. Must
|
||||
be a compile-time constant 1D tensor of length equal to rank of input.
|
||||
padding_high: the padding to apply at the end of each input dimension. Must
|
||||
be a compile-time constant 1D tensor of length equal to rank of input.
|
||||
padding_interior: the padding to apply between each input element. Must
|
||||
be a compile-time constant 1D tensor of length equal to rank of input,
|
||||
containing only non-negative values.
|
||||
output: A `Tensor` of type T.
|
||||
)doc");
|
||||
|
||||
|
@ -896,10 +896,10 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
|
||||
return errors::InvalidArgument("Input must be scalar but has rank ", rank);
|
||||
}
|
||||
|
||||
if (t->dtype() == DT_INT32) {
|
||||
if (t->dtype() == DataType::DT_INT32) {
|
||||
*val = t->scalar<int32>()();
|
||||
return Status::OK();
|
||||
} else if (t->dtype() == DT_INT64) {
|
||||
} else if (t->dtype() == DataType::DT_INT64) {
|
||||
*val = t->scalar<int64>()();
|
||||
return Status::OK();
|
||||
} else {
|
||||
@ -907,6 +907,35 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
|
||||
}
|
||||
}
|
||||
|
||||
Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64 idx,
|
||||
int64* val) {
|
||||
// Caller must ensure that <t> is not NULL.
|
||||
const int rank = t->dims();
|
||||
if (rank != 1) {
|
||||
return errors::InvalidArgument("Input must be 1D but has rank ", rank);
|
||||
}
|
||||
|
||||
if (t->dtype() == DataType::DT_INT32) {
|
||||
auto flat_t = t->flat<int32>();
|
||||
if (idx < 0 || idx >= flat_t.size()) {
|
||||
return errors::InvalidArgument("Invalid index ", idx,
|
||||
" for Tensor of size ", flat_t.size());
|
||||
}
|
||||
*val = flat_t(idx);
|
||||
return Status::OK();
|
||||
} else if (t->dtype() == DataType::DT_INT64) {
|
||||
auto flat_t = t->flat<int64>();
|
||||
if (idx < 0 || idx >= flat_t.size()) {
|
||||
return errors::InvalidArgument("Invalid index ", idx,
|
||||
" for Tensor of size ", flat_t.size());
|
||||
}
|
||||
*val = flat_t(idx);
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::InvalidArgument("Tensor input must be int32 or int64.");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a new dimension whose value is given by a scalar input tensor.
|
||||
Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
|
||||
int64 val;
|
||||
|
@ -490,10 +490,14 @@ class InferenceContext {
|
||||
inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
|
||||
|
||||
// Returns in <val> a scalar value from an input tensor <t>. The input tensor
|
||||
// must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the
|
||||
// must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the
|
||||
// input tensor is not NULL.
|
||||
Status GetScalarFromTensor(const Tensor* t, int64* val);
|
||||
|
||||
// Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
|
||||
// int64 elements. Caller must ensure that the input tensor is not NULL.
|
||||
Status GetScalarFromTensor(const Tensor* t, int64 idx, int64* val);
|
||||
|
||||
// Returns a new dimension whose value is given by a scalar input tensor.
|
||||
// The input tensor must be in host memory, since it is dereferenced to get
|
||||
// the value.
|
||||
|
Loading…
Reference in New Issue
Block a user