[TF:XLA] Implement support for tf.roll

This is built on DUS and requires the 'axis' input to be a constant. It
might be possible to remove that restriction, but I'm currently not sure how.

PiperOrigin-RevId: 254538077
Author: Benjamin Kramer 2019-06-22 02:18:38 -07:00, committed by TensorFlower Gardener
parent 25a55b0d0c
commit cc3bd2c8ed
4 changed files with 171 additions and 0 deletions
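The idea, sketched in NumPy (illustrative only; roll_via_concat_slice is a hypothetical helper, not part of this change): concatenate two copies of the input along the rolled axis, normalize the shift into [0, axis_size), and take a slice of the original length starting at axis_size - shift. The XLA kernel below does this per axis with ConcatInDim and DynamicSlice.

import numpy as np

def roll_via_concat_slice(a, shift, axis):
  # Illustrative model of the lowering: two copies of the tensor along
  # `axis`, then a window of the original length starting at size - shift.
  axis_size = a.shape[axis]
  # Mirrors the kernel's ((offset % size) + size) % size; Python's modulo is
  # already non-negative, but the C++/XLA remainder can be negative.
  offset = ((shift % axis_size) + axis_size) % axis_size
  doubled = np.concatenate([a, a], axis=axis)
  start = axis_size - offset
  index = [slice(None)] * a.ndim
  index[axis] = slice(start, start + axis_size)
  return doubled[tuple(index)]

a = np.arange(10)
assert np.array_equal(roll_via_concat_slice(a, 3, 0), np.roll(a, 3))
assert np.array_equal(roll_via_concat_slice(a, -13, 0), np.roll(a, -13))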

tensorflow/compiler/tests/BUILD

@@ -599,6 +599,19 @@ tf_xla_py_test(
    ],
)

tf_xla_py_test(
    name = "manip_ops_test",
    size = "small",
    srcs = ["manip_ops_test.py"],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:manip_ops",
        "//tensorflow/python:platform_test",
    ],
)

tf_xla_py_test(
    name = "matrix_band_part_test",
    size = "medium",

tensorflow/compiler/tests/manip_ops_test.py

@@ -0,0 +1,68 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for manip ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.platform import googletest
class ManipOpsTest(xla_test.XLATestCase):
"""Test cases for manip ops."""
  def _testRoll(self, a, shift, axis):
    with self.session() as session:
      with self.test_scope():
        p = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
        output = manip_ops.roll(p, shift, axis)
      result = session.run(output, {p: a})
      self.assertAllEqual(result, np.roll(a, shift, axis))

  def testNumericTypes(self):
    for t in self.numeric_types:
      self._testRoll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
      self._testRoll(
          np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -6, 6],
          [0, 1, 2])
      self._testRoll(
          np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2],
          [1, 2, 3])

  def testFloatTypes(self):
    for t in self.float_types:
      self._testRoll(np.random.rand(5).astype(t), 2, 0)
      self._testRoll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
      self._testRoll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])

  def testComplexTypes(self):
    for t in self.complex_types:
      x = np.random.rand(4, 4).astype(t)
      self._testRoll(x + 1j * x, 2, 0)
      x = np.random.rand(2, 5).astype(t)
      self._testRoll(x + 1j * x, [1, 2], [1, 0])
      x = np.random.rand(3, 2, 1, 1).astype(t)
      self._testRoll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])


if __name__ == "__main__":
  googletest.main()

tensorflow/compiler/tf2xla/kernels/BUILD

@@ -79,6 +79,7 @@ tf_kernel_library(
        "retval_op.cc",
        "reverse_op.cc",
        "reverse_sequence_op.cc",
        "roll_op.cc",
        "scan_ops.cc",
        "scatter_nd_op.cc",
        "segment_reduction_ops.cc",

tensorflow/compiler/tf2xla/kernels/roll_op.cc

@@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"

namespace tensorflow {
namespace {

class RollOp : public XlaOpKernel {
 public:
  explicit RollOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    xla::XlaOp shift = ctx->Input(1);
    const TensorShape shift_shape = ctx->InputShape(1);
    const TensorShape axis_shape = ctx->InputShape(2);

    OP_REQUIRES(ctx, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(ctx, shift_shape.dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift_shape.DebugString()));
    OP_REQUIRES(
        ctx, shift_shape.dims() == axis_shape.dims(),
        errors::InvalidArgument("shift and axis must have the same size"));

    xla::Literal axis;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &axis));

    xla::XlaOp output = ctx->Input(0);
    xla::PrimitiveType shift_type = ctx->input_xla_type(1);
    int64 num_axes = axis_shape.dims() == 0 ? 1 : axis_shape.dim_size(0);

    for (int64 i = 0; i != num_axes; ++i) {
      auto cur_axis_status = axis_shape.dims() == 0
                                 ? axis.GetIntegralAsS64({})
                                 : axis.GetIntegralAsS64({i});
      OP_REQUIRES_OK(ctx, cur_axis_status.status());
      int64 cur_axis = cur_axis_status.ValueOrDie();

      xla::XlaOp offset =
          shift_shape.dims() == 0
              ? shift
              : xla::Reshape(xla::SliceInDim(shift, /*start_index=*/i,
                                             /*limit_index=*/i + 1,
                                             /*stride=*/1, /*dimno=*/0),
                             {});
      xla::XlaOp axis_size = xla::ConstantR0WithType(
          ctx->builder(), shift_type, input_shape.dim_size(cur_axis));
      // Adjust large offsets into [0, axis_size). This also makes negative
      // offsets positive.
      offset = ((offset % axis_size) + axis_size) % axis_size;

      // Stack two copies of the dimension, then slice from the calculated
      // offset.
      xla::XlaOp concat =
          xla::ConcatInDim(ctx->builder(), {output, output}, cur_axis);
      std::vector<xla::XlaOp> start_indices(
          input_shape.dims(), xla::Zero(ctx->builder(), shift_type));
      start_indices[cur_axis] = axis_size - offset;
      output =
          xla::DynamicSlice(concat, start_indices, input_shape.dim_sizes());
    }
    ctx->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RollOp);
};

REGISTER_XLA_OP(Name("Roll").CompileTimeConstantInput("axis"), RollOp);

}  // namespace
}  // namespace tensorflow
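A quick check of the kernel's offset arithmetic with made-up numbers: for axis_size = 5 and shift = -2, the C++ remainder gives -2 % 5 == -2, so ((-2 % 5) + 5) % 5 == 3 and the dynamic slice starts at 5 - 3 == 2 in the doubled dimension, which is exactly where np.roll(a, -2) begins.

import numpy as np

a = np.arange(5)          # [0 1 2 3 4]
axis_size, offset = 5, 3  # normalized form of shift = -2
start = axis_size - offset
doubled = np.concatenate([a, a])
print(doubled[start:start + axis_size])  # [2 3 4 0 1]
print(np.roll(a, -2))                    # [2 3 4 0 1]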