[XLA] Implement MirrorPad op.

Additionally, this change:

* Improves the shape inference error message for concatenate.
* Adds a helper to Literal that gets an integral value converted to int64.

PiperOrigin-RevId: 163829437
This commit is contained in:
Chris Leary 2017-08-01 07:59:31 -07:00 committed by TensorFlower Gardener
parent c7b674fa28
commit 3cc5fc0886
9 changed files with 232 additions and 2 deletions

View File

@@ -650,6 +650,80 @@ class BinaryOpsTest(XLATestCase):
[0, 0, 0, 0, 0, 0]],
dtype=dtype))
def testMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
for dtype in self.numeric_types:
self._testBinary(
mirror_pad,
np.array(
[
[1, 2, 3], #
[4, 5, 6], #
],
dtype=dtype),
np.array([[
1,
1,
], [2, 2]], dtype=np.int32),
expected=np.array(
[
[6, 5, 4, 5, 6, 5, 4], #
[3, 2, 1, 2, 3, 2, 1], #
[6, 5, 4, 5, 6, 5, 4], #
[3, 2, 1, 2, 3, 2, 1]
],
dtype=dtype))
self._testBinary(
mirror_pad,
np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
np.array([[0, 0], [0, 0]], dtype=np.int32),
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
self._testBinary(
mirror_pad,
np.array(
[
[1, 2, 3], #
[4, 5, 6], #
[7, 8, 9]
],
dtype=dtype),
np.array([[2, 2], [0, 0]], dtype=np.int32),
expected=np.array(
[
[7, 8, 9], #
[4, 5, 6], #
[1, 2, 3], #
[4, 5, 6], #
[7, 8, 9], #
[4, 5, 6], #
[1, 2, 3]
],
dtype=dtype))
self._testBinary(
mirror_pad,
np.array(
[
[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]],
], dtype=dtype),
np.array([[0, 0], [1, 1], [1, 1]], dtype=np.int32),
expected=np.array(
[
[
[5, 4, 5, 6, 5], #
[2, 1, 2, 3, 2], #
[5, 4, 5, 6, 5], #
[2, 1, 2, 3, 2], #
],
[
[11, 10, 11, 12, 11], #
[8, 7, 8, 9, 8], #
[11, 10, 11, 12, 11], #
[8, 7, 8, 9, 8], #
]
],
dtype=dtype))
def testReshape(self):
for dtype in self.numeric_types:
self._testBinary(

View File

@@ -62,6 +62,7 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Min", "reduction_indices"},
{"OneHot", "depth"},
{"Pad", "paddings"},
{"MirrorPad", "paddings"},
{"Prod", "reduction_indices"},
{"RandomStandardNormal", "shape"},
{"RandomUniform", "shape"},

View File

@@ -35,6 +35,7 @@ tf_kernel_library(
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
"mirror_pad_op.cc",
"no_op.cc",
"one_hot_op.cc",
"pack_op.cc",

View File

@@ -0,0 +1,98 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
namespace tensorflow {
namespace {
// XLA kernel for the TensorFlow MirrorPad op (REFLECT mode only).
class MirrorPadOp : public XlaOpKernel {
 public:
  explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  // Returns `t` mirror-padded (REFLECT mode) according to `pad_literal`,
  // a [rank, 2] matrix of integral (low, high) padding amounts per
  // dimension. Works one dimension at a time: the reversed tensor supplies
  // the mirrored slices that are concatenated on either side of the
  // accumulated result.
  xla::StatusOr<xla::ComputationDataHandle> DoMirrorPad(
      const xla::ComputationDataHandle& t, const xla::Shape& original_shape,
      const xla::Literal& pad_literal, xla::ComputationBuilder* b) {
    xla::ComputationDataHandle accum = t;
    for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
         --dimno) {
      auto t_rev = b->Rev(accum, {dimno});
      TF_ASSIGN_OR_RETURN(int64 lhs_padding,
                          pad_literal.GetIntegralAsS64({dimno, 0}));
      TF_ASSIGN_OR_RETURN(int64 rhs_padding,
                          pad_literal.GetIntegralAsS64({dimno, 1}));
      int64 dim_size = original_shape.dimensions(dimno);
      // REFLECT mode mirrors about the edge element without repeating it,
      // so each padding amount must lie in [0, dim_size - 1]; otherwise the
      // slice bounds computed below would be invalid.
      if (lhs_padding < 0 || lhs_padding > dim_size - 1 || rhs_padding < 0 ||
          rhs_padding > dim_size - 1) {
        return errors::InvalidArgument(
            "paddings must be non-negative and less than the dimension size "
            "in REFLECT mode: got (", lhs_padding, ", ", rhs_padding,
            ") for dimension ", dimno, " of size ", dim_size);
      }
      // In the reversed tensor, the original elements [1, lhs_padding]
      // appear at positions [dim_size - 1 - lhs_padding, dim_size - 1), in
      // exactly the order needed for the low-side pad; the high-side pad is
      // symmetric, taken from positions [1, 1 + rhs_padding).
      auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding,
                                   dim_size - 1, 1, dimno);
      auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
      accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno);
    }
    return accum;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape pad_shape = ctx->InputShape(1);

    // Only REFLECT is implemented; SYMMETRIC would repeat the edge element.
    MirrorPadMode mode;
    OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
    OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT,
                xla::Unimplemented(
                    "Only REFLECT MirrorPad mode is currently supported"));

    const int dims = input_shape.dims();
    OP_REQUIRES(
        ctx,
        TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                pad_shape.DebugString()));
    // Legacy scalars may be treated as rank 1, so a [1, 2] paddings matrix
    // is accepted for them.
    const int fixed_dims =
        (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
            ? 1
            : dims;
    OP_REQUIRES(
        ctx, fixed_dims == pad_shape.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs",
            pad_shape.DebugString(), " ", input_shape.DebugString()));

    // Evaluate the 'padding' constant input, reshaping to a matrix.
    xla::Literal pad_literal;
    OP_REQUIRES_OK(
        ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));

    xla::ComputationBuilder* b = ctx->builder();
    auto in0 = ctx->Input(0);
    xla::StatusOr<std::unique_ptr<xla::Shape>> in0_shape = b->GetShape(in0);
    OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
    xla::StatusOr<xla::ComputationDataHandle> accum_status =
        DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b);

    OP_REQUIRES_OK(ctx, accum_status.status());

    ctx->SetOutput(0, accum_status.ValueOrDie());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp);
};

REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp);
} // namespace
} // namespace tensorflow

View File

@@ -284,6 +284,25 @@ ComputationDataHandle ComputationBuilder::Slice(
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::SliceInDim(
    const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
    int64 stride, int64 dimno) {
  // Fetch the operand's shape; on failure record the error and hand back an
  // invalid handle, following the builder's usual error-propagation style.
  StatusOr<std::unique_ptr<Shape>> operand_shape_or = GetShape(operand);
  if (!operand_shape_or.ok()) {
    NoteError(operand_shape_or.status());
    return ComputationDataHandle{};
  }
  const Shape& operand_shape = *operand_shape_or.ValueOrDie();
  const int64 rank = ShapeUtil::Rank(operand_shape);

  // Start from a degenerate slice that covers every dimension in full...
  std::vector<int64> start_indices(rank, 0);
  std::vector<int64> limit_indices(operand_shape.dimensions().begin(),
                                   operand_shape.dimensions().end());
  std::vector<int64> stride_per_dim(rank, 1);
  // ...then narrow only the requested dimension.
  start_indices[dimno] = start_index;
  limit_indices[dimno] = limit_index;
  stride_per_dim[dimno] = stride;
  return Slice(operand, start_indices, limit_indices, stride_per_dim);
}
ComputationDataHandle ComputationBuilder::DynamicSlice(
const ComputationDataHandle& operand,
const ComputationDataHandle& start_indices,

View File

@@ -217,6 +217,16 @@ class ComputationBuilder {
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> stride);
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
// array[:, 2:4:1, :]
ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',

View File

@@ -503,6 +503,28 @@ string Literal::GetAsString(
}
}
StatusOr<int64> Literal::GetIntegralAsS64(
    tensorflow::gtl::ArraySlice<int64> multi_index) const {
  // Dispatch on the element type, widening each integral value to int64.
  // Non-integral element types yield a FailedPrecondition error.
  const PrimitiveType element_type = shape().element_type();
  switch (element_type) {
    case PRED:
      return Get<bool>(multi_index);
    // Unsigned types. NOTE: a U64 value above the int64 maximum does not
    // round-trip through this conversion.
    case U8:
      return Get<uint8>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    // Signed types widen (or pass through) losslessly.
    case S32:
      return Get<int32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    default:
      return FailedPrecondition(
          "Array element type is not integral: %s",
          PrimitiveType_Name(element_type).c_str());
  }
}
int64 Literal::LinearIndex(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);

View File

@@ -390,6 +390,11 @@ class Literal {
// into text.
string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index) const;
// As Get(), but determines the correct type and converts the value into
// int64.
StatusOr<int64> GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const;
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);

View File

@@ -269,9 +269,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"cannot concatenate arrays that differ in dimensions other than "
"the one being concatenated (the other array dimensions must be "
"the same): %s vs %s",
"the same): %s vs %s in dimension %lld",
ShapeUtil::HumanString(*arg_shape).c_str(),
ShapeUtil::HumanString(*shape).c_str());
ShapeUtil::HumanString(*shape).c_str(), dimension);
}
}
}