[XLA] Implement MirrorPad op.
Addresses #11890.

* Improves the shape inference error message for concatenate.
* Adds a helper to Literal that gets an integral value converted to int64.

PiperOrigin-RevId: 163829437
commit 3cc5fc0886
parent c7b674fa28
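For context, the op this kernel compiles is TensorFlow's REFLECT-mode pad, in which the border element itself is not repeated. A minimal usage sketch (illustration only, not part of the diff):

import tensorflow as tf

t = tf.constant([[1, 2, 3], [4, 5, 6]])
paddings = tf.constant([[1, 1], [2, 2]])
# REFLECT mode mirrors around the edge without repeating it; the result here
# has shape (4, 7), matching the first case in testMirrorPad below.
padded = tf.pad(t, paddings, "REFLECT")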
@@ -650,6 +650,80 @@ class BinaryOpsTest(XLATestCase):
                [0, 0, 0, 0, 0, 0]],
               dtype=dtype))
 
+  def testMirrorPad(self):
+    mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
+    for dtype in self.numeric_types:
+      self._testBinary(
+          mirror_pad,
+          np.array(
+              [
+                  [1, 2, 3],  #
+                  [4, 5, 6],  #
+              ],
+              dtype=dtype),
+          np.array([[
+              1,
+              1,
+          ], [2, 2]], dtype=np.int32),
+          expected=np.array(
+              [
+                  [6, 5, 4, 5, 6, 5, 4],  #
+                  [3, 2, 1, 2, 3, 2, 1],  #
+                  [6, 5, 4, 5, 6, 5, 4],  #
+                  [3, 2, 1, 2, 3, 2, 1]
+              ],
+              dtype=dtype))
+      self._testBinary(
+          mirror_pad,
+          np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+          np.array([[0, 0], [0, 0]], dtype=np.int32),
+          expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
+      self._testBinary(
+          mirror_pad,
+          np.array(
+              [
+                  [1, 2, 3],  #
+                  [4, 5, 6],  #
+                  [7, 8, 9]
+              ],
+              dtype=dtype),
+          np.array([[2, 2], [0, 0]], dtype=np.int32),
+          expected=np.array(
+              [
+                  [7, 8, 9],  #
+                  [4, 5, 6],  #
+                  [1, 2, 3],  #
+                  [4, 5, 6],  #
+                  [7, 8, 9],  #
+                  [4, 5, 6],  #
+                  [1, 2, 3]
+              ],
+              dtype=dtype))
+      self._testBinary(
+          mirror_pad,
+          np.array(
+              [
+                  [[1, 2, 3], [4, 5, 6]],
+                  [[7, 8, 9], [10, 11, 12]],
+              ], dtype=dtype),
+          np.array([[0, 0], [1, 1], [1, 1]], dtype=np.int32),
+          expected=np.array(
+              [
+                  [
+                      [5, 4, 5, 6, 5],  #
+                      [2, 1, 2, 3, 2],  #
+                      [5, 4, 5, 6, 5],  #
+                      [2, 1, 2, 3, 2],  #
+                  ],
+                  [
+                      [11, 10, 11, 12, 11],  #
+                      [8, 7, 8, 9, 8],  #
+                      [11, 10, 11, 12, 11],  #
+                      [8, 7, 8, 9, 8],  #
+                  ]
+              ],
+              dtype=dtype))
+
   def testReshape(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -62,6 +62,7 @@ Status BackwardsConstAnalysis(const Graph& g,
       {"Min", "reduction_indices"},
       {"OneHot", "depth"},
       {"Pad", "paddings"},
+      {"MirrorPad", "paddings"},
      {"Prod", "reduction_indices"},
      {"RandomStandardNormal", "shape"},
      {"RandomUniform", "shape"},
@@ -35,6 +35,7 @@ tf_kernel_library(
         "l2loss_op.cc",
         "lrn_ops.cc",
         "matmul_op.cc",
+        "mirror_pad_op.cc",
         "no_op.cc",
         "one_hot_op.cc",
         "pack_op.cc",
tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc (new file, 98 lines)
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
+
+namespace tensorflow {
+namespace {
+
+class MirrorPadOp : public XlaOpKernel {
+ public:
+  explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+  xla::StatusOr<xla::ComputationDataHandle> DoMirrorPad(
+      const xla::ComputationDataHandle& t, const xla::Shape& original_shape,
+      const xla::Literal& pad_literal, xla::ComputationBuilder* b) {
+    xla::ComputationDataHandle accum = t;
+    for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
+         --dimno) {
+      auto t_rev = b->Rev(accum, {dimno});
+      TF_ASSIGN_OR_RETURN(int64 lhs_padding,
+                          pad_literal.GetIntegralAsS64({dimno, 0}));
+      TF_ASSIGN_OR_RETURN(int64 rhs_padding,
+                          pad_literal.GetIntegralAsS64({dimno, 1}));
+      int64 dim_size = original_shape.dimensions(dimno);
+      auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding,
+                                   dim_size - 1, 1, dimno);
+      auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
+      accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno);
+    }
+    return accum;
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape input_shape = ctx->InputShape(0);
+    const TensorShape pad_shape = ctx->InputShape(1);
+
+    MirrorPadMode mode;
+    OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
+    OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT,
+                xla::Unimplemented(
+                    "Only REFLECT MirrorPad mode is currently supported"));
+
+    const int dims = input_shape.dims();
+    OP_REQUIRES(
+        ctx,
+        TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
+        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
+                                pad_shape.DebugString()));
+    const int fixed_dims =
+        (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
+            ? 1
+            : dims;
+    OP_REQUIRES(
+        ctx, fixed_dims == pad_shape.dim_size(0),
+        errors::InvalidArgument(
+            "The first dimension of paddings must be the rank of inputs",
+            pad_shape.DebugString(), " ", input_shape.DebugString()));
+
+    // Evaluate the 'padding' constant input, reshaping to a matrix.
+    xla::Literal pad_literal;
+    OP_REQUIRES_OK(
+        ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
+
+    xla::ComputationBuilder* b = ctx->builder();
+    auto in0 = ctx->Input(0);
+    xla::StatusOr<std::unique_ptr<xla::Shape>> in0_shape = b->GetShape(in0);
+    OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
+    xla::StatusOr<xla::ComputationDataHandle> accum_status =
+        DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b);
+
+    OP_REQUIRES_OK(ctx, accum_status.status());
+
+    ctx->SetOutput(0, accum_status.ValueOrDie());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp);
+};
+
+REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp);
+
+}  // namespace
+}  // namespace tensorflow
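DoMirrorPad builds the reflection one dimension at a time: reverse the accumulated result along that dimension, slice the left and right pad regions out of the reversed tensor (skipping the border element itself, hence the off-by-one indices), and concatenate. A standalone NumPy sketch of the same construction (illustrative only, not part of the commit):

import numpy as np

def reflect_pad(t, paddings):
  # Mirrors the loop in DoMirrorPad. The C++ walks dimensions from last to
  # first; the order does not affect the result, so this iterates forward.
  for dim, (lo, hi) in enumerate(paddings):
    rev = np.flip(t, axis=dim)
    n = t.shape[dim]  # original_shape.dimensions(dimno)
    lhs = rev.take(range(n - 1 - lo, n - 1), axis=dim)  # SliceInDim(t_rev, n-1-lo, n-1, 1, dim)
    rhs = rev.take(range(1, 1 + hi), axis=dim)          # SliceInDim(t_rev, 1, 1+hi, 1, dim)
    t = np.concatenate([lhs, t, rhs], axis=dim)         # ConcatInDim({lhs, accum, rhs}, dim)
  return t

reflect_pad(np.array([[1, 2, 3], [4, 5, 6]]), [(1, 1), (2, 2)])
# -> the 4x7 expected array from the first testMirrorPad case.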
@@ -284,6 +284,25 @@ ComputationDataHandle ComputationBuilder::Slice(
   return ParseOpResponse(s, &response);
 }
 
+ComputationDataHandle ComputationBuilder::SliceInDim(
+    const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
+    int64 stride, int64 dimno) {
+  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
+  if (!shape_status.ok()) {
+    NoteError(shape_status.status());
+    return ComputationDataHandle{};
+  }
+  const Shape& shape = *shape_status.ValueOrDie();
+  std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
+  std::vector<int64> limits(shape.dimensions().begin(),
+                            shape.dimensions().end());
+  std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
+  starts[dimno] = start_index;
+  limits[dimno] = limit_index;
+  strides[dimno] = stride;
+  return Slice(operand, starts, limits, strides);
+}
+
 ComputationDataHandle ComputationBuilder::DynamicSlice(
     const ComputationDataHandle& operand,
     const ComputationDataHandle& start_indices,
@@ -217,6 +217,16 @@ class ComputationBuilder {
                               tensorflow::gtl::ArraySlice<int64> limit_indices,
                               tensorflow::gtl::ArraySlice<int64> stride);
 
+  // Enqueues a slice operation in a given dimension, taking all other
+  // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
+  // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
+  // for:
+  //
+  //  array[:, 2:4:1, :]
+  ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
+                                   int64 start_index, int64 limit_index,
+                                   int64 stride, int64 dimno);
+
   // Enqueues a slice operation onto the computation that slices the 'operand'
   // from dynamic start indices which are passed in 'start_indices'.
   // The size of the slice in each dimension is passed in 'slice_sizes',
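As the implementation above shows, SliceInDim is shorthand for a full Slice whose starts, limits, and strides cover every other dimension entirely, with only dimno overridden. In NumPy terms, the example from the header comment reads (a sketch, not TF code):

import numpy as np

arr = np.arange(7 * 8 * 9, dtype=np.float32).reshape(7, 8, 9)
# SliceInDim(operand, /*start_index=*/2, /*limit_index=*/4, /*stride=*/1,
#            /*dimno=*/1) is equivalent to:
sliced = arr[:, 2:4:1, :]
assert sliced.shape == (7, 2, 9)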
@@ -503,6 +503,28 @@ string Literal::GetAsString(
   }
 }
 
+StatusOr<int64> Literal::GetIntegralAsS64(
+    tensorflow::gtl::ArraySlice<int64> multi_index) const {
+  switch (shape().element_type()) {
+    case PRED:
+      return Get<bool>(multi_index);
+    case U8:
+      return Get<uint8>(multi_index);
+    case S32:
+      return Get<int32>(multi_index);
+    case S64:
+      return Get<int64>(multi_index);
+    case U32:
+      return Get<uint32>(multi_index);
+    case U64:
+      return Get<uint64>(multi_index);
+    default:
+      return FailedPrecondition(
+          "Array element type is not integral: %s",
+          PrimitiveType_Name(shape().element_type()).c_str());
+  }
+}
+
 int64 Literal::LinearIndex(
     tensorflow::gtl::ArraySlice<int64> multi_index) const {
   return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
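GetIntegralAsS64 dispatches on the literal's element type and widens any integral (or predicate) element to a signed 64-bit value, failing on floating-point types; this is what lets MirrorPadOp read the paddings literal regardless of its integer dtype. A rough NumPy analogue (names here are illustrative, not TF API):

import numpy as np

def get_integral_as_s64(array, multi_index):
  # Accept bool and any integer dtype; reject floats, as the C++ helper does.
  if array.dtype == np.bool_ or np.issubdtype(array.dtype, np.integer):
    return int(array[multi_index])
  raise ValueError("Array element type is not integral: %s" % array.dtype)

pad = np.array([[1, 1], [2, 2]], dtype=np.int32)
assert get_integral_as_s64(pad, (0, 1)) == 1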
@@ -390,6 +390,11 @@ class Literal {
   // into text.
   string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index) const;
 
+  // As Get(), but determines the correct type and converts the value into
+  // int64.
+  StatusOr<int64> GetIntegralAsS64(
+      tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
   // Returns an identity matrix (rank 2) with the given row and column count.
   template <typename NativeT>
   static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
@@ -269,9 +269,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
         return InvalidArgument(
             "cannot concatenate arrays that differ in dimensions other than "
             "the one being concatenated (the other array dimensions must be "
-            "the same): %s vs %s",
+            "the same): %s vs %s in dimension %lld",
             ShapeUtil::HumanString(*arg_shape).c_str(),
-            ShapeUtil::HumanString(*shape).c_str());
+            ShapeUtil::HumanString(*shape).c_str(), dimension);
       }
     }
   }
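The improved concatenate error now reports the offending dimension index along with the two shapes. For instance, concatenating f32[2,3] with f32[2,4] along dimension 0 is invalid because dimension 1 differs; NumPy flags the same situation (illustration only, not TF code):

import numpy as np

a = np.zeros((2, 3), dtype=np.float32)
b = np.zeros((2, 4), dtype=np.float32)
try:
  np.concatenate([a, b], axis=0)  # invalid: dimension 1 differs, 3 vs 4
except ValueError as e:
  print(e)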