From eb1fe50da445d3880b588215f6fadcc7f48dd3ff Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 12 Jul 2017 06:58:36 -0700 Subject: [PATCH] [TF:XLA] Add initial implementation of the Stack operators to the TF/XLA bridge. Stacks are used when computing loop gradients. PiperOrigin-RevId: 161659980 --- tensorflow/compiler/tests/BUILD | 13 + tensorflow/compiler/tests/stack_ops_test.py | 104 +++++++ tensorflow/compiler/tf2xla/const_analysis.cc | 1 + tensorflow/compiler/tf2xla/kernels/BUILD | 1 + tensorflow/compiler/tf2xla/kernels/arg_op.cc | 3 + .../compiler/tf2xla/kernels/stack_ops.cc | 259 ++++++++++++++++++ .../compiler/tf2xla/kernels/while_op.cc | 4 + .../compiler/tf2xla/xla_compilation_device.h | 1 + tensorflow/compiler/tf2xla/xla_compiler.cc | 1 + tensorflow/compiler/tf2xla/xla_compiler.h | 6 +- 10 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/tests/stack_ops_test.py create mode 100644 tensorflow/compiler/tf2xla/kernels/stack_ops.cc diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 432b24756d2..4cd2137eda8 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -366,6 +366,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stack_ops_test", + size = "small", + srcs = ["stack_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "tensor_array_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py new file mode 100644 index 00000000000..2b9c2279737 --- /dev/null +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.stack_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.platform import test + + +class StackOpTest(XLATestCase): + + def testStackPushPop(self): + with self.test_session(), self.test_scope(): + size = array_ops.placeholder(dtypes.int32) + v = array_ops.placeholder(dtypes.float32) + h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push_v2(h, v) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) + + def testStackPushPopSwap(self): + with self.test_session(), self.test_scope(): + a = np.arange(2000) + x = array_ops.placeholder(dtypes.float32) + h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + self.assertAllClose(a, c1.eval({x: a})) + + def testMultiStack(self): + with self.test_session(), self.test_scope(): + v = array_ops.placeholder(dtypes.float32) + h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_push_v2(h1, v) + with ops.control_dependencies([c1]): + c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar") + c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + with ops.control_dependencies([c2]): + c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + r = c1 + c2 + self.assertAllClose(9.0, r.eval({v: 4.0})) + + def testSameNameStacks(self): + """Different stacks with the same name do not interfere.""" + with self.test_session() as sess, self.test_scope(): + v1 = array_ops.placeholder(dtypes.float32) + v2 = array_ops.placeholder(dtypes.float32) + h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + + c1 = gen_data_flow_ops._stack_push_v2(h1, v1) + with ops.control_dependencies([c1]): + c2 = gen_data_flow_ops._stack_push_v2(h2, v2) + with ops.control_dependencies([c2]): + pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + + out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0}) + self.assertAllClose(out1, 4.0) + self.assertAllClose(out2, 5.0) + + def testCloseStack(self): + with self.test_session() as sess, self.test_scope(): + size = array_ops.placeholder(dtypes.int32) + h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_close_v2(h) + sess.run(c1, {size: 5}) + + def testPushCloseStack(self): + with self.test_session() as sess, self.test_scope(): + v = array_ops.placeholder(dtypes.float32) + h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push_v2(h, v) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_close_v2(h) + sess.run(c1, {v: [[4.0, 5.0]]}) + + +if __name__ 
== "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 36a6c90af4f..d98cf829bb6 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -81,6 +81,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Split", "split_dim"}, {"SplitV", "split_dim"}, {"SplitV", "size_splits"}, + {"StackV2", "max_size"}, {"StridedSlice", "begin"}, {"StridedSlice", "end"}, {"StridedSlice", "strides"}, diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 96b4fdfec6d..35bc6b5a24e 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -54,6 +54,7 @@ tf_kernel_library( "softmax_op.cc", "spacetobatch_op.cc", "split_op.cc", + "stack_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 6ad72c6219e..11565465129 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -60,6 +60,9 @@ class ArgOp : public XlaOpKernel { case XlaCompiler::Argument::kTensorArray: kind = XlaResource::kTensorArray; break; + case XlaCompiler::Argument::kStack: + kind = XlaResource::kStack; + break; default: CHECK(false); } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc new file mode 100644 index 00000000000..6a6f65c02cc --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA Stack operators. 
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
+                     TensorShape* stack_shape) {
+  auto shape_or_status = builder->GetShape(resource->value);
+  if (!shape_or_status.ok()) {
+    return shape_or_status.status();
+  }
+  xla::Shape shape = *shape_or_status.ValueOrDie();
+  TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
+  return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
+                               stack_shape);
+}
+
+// Since the element shape is not provided to the Stack operator,
+// we lazily initialize the Stack at the time of the first write.
+//
+// If a Stack `resource` has not been initialized, constructs storage for the
+// Stack with elements of `elem_shape`. For both initialized and
+// uninitialized Stacks, checks that the tensor has a type compatible with
+// 'dtype' and shape compatible with 'elem_shape'.
+//
+// TODO(phawkins): consider changing the API of the stack operators to
+// allow an optional element shape at stack construction time.
+Status MaybeInitializeStack(xla::ComputationBuilder* builder,
+                            XlaResource* resource, DataType dtype,
+                            const TensorShape& elem_shape) {
+  if (resource->type != dtype) {
+    return errors::InvalidArgument(
+        "Stack dtype is ", DataTypeString(resource->type),
+        " but op has dtype ", DataTypeString(dtype), ".");
+  }
+
+  TensorShape stack_shape;
+  stack_shape.AddDim(resource->tensor_array_size);
+  stack_shape.AppendShape(elem_shape);
+
+  if (resource->value.handle() == 0) {
+    // Stack has not been initialized.
+    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
+    resource->value =
+        builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
+                        builder->ConstantR0<int32>(0)});
+  } else {
+    // Check that the expected shape matches the actual shape.
+    TensorShape actual_shape;
+    TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &actual_shape));
+    if (stack_shape != actual_shape) {
+      return errors::InvalidArgument(
+          "Mismatched Stack shapes: ", stack_shape.DebugString(), " vs ",
+          actual_shape.DebugString());
+    }
+  }
+  return Status::OK();
+}
+
+// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
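+// The result is a rank-1 int32 vector of length `count + 1`, suitable as the
+// start_indices operand of the DynamicSlice / DynamicUpdateSlice calls below.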
+xla::ComputationDataHandle PadIndexWithZeros(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    int count) {
+  xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
+  std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
+  xs[0] = builder->Reshape(x, {1});
+  return builder->ConcatInDim(xs, 0);
+}
+
+class StackOp : public XlaOpKernel {
+ public:
+  explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    int64 size;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
+    OP_REQUIRES(
+        ctx, size >= 0,
+        errors::InvalidArgument(
+            "XLA compilation requires a fixed stack size upper bound."));
+
+    // We defer initializing the Stack resource until we see the first push.
+    // Otherwise we do not know the shape of the stack elements.
+    xla::ComputationDataHandle value;
+    XlaContext& xc = XlaContext::Get(ctx);
+    XlaResource* resource;
+    string name = strings::StrCat("Stack: ", stack_name_);
+    OP_REQUIRES_OK(ctx,
+                   xc.CreateResource(XlaResource::kStack, -1, std::move(name),
+                                     dtype_, value, &resource));
+    resource->tensor_array_size = size;
+    ctx->SetResourceOutput(0, resource);
+  }
+
+ private:
+  DataType dtype_;
+  string stack_name_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
+};
+
+REGISTER_XLA_OP(Name("StackV2"), StackOp);
+
+class StackPushOp : public XlaOpKernel {
+ public:
+  explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    TensorShape elem_shape = ctx->InputShape(1);
+
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+    // Initializes the Stack, if the element shape was not already known.
+    OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));
+
+    xla::ComputationDataHandle ta = b->GetTupleElement(resource->value, 0);
+    xla::ComputationDataHandle index = b->GetTupleElement(resource->value, 1);
+    xla::ComputationDataHandle value = ctx->Input(1);
+
+    // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+    auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+
+    TensorShape slice_shape = elem_shape;
+    slice_shape.InsertDim(0, 1LL);
+    auto update = b->Reshape(value, slice_shape.dim_sizes());
+
+    // TODO(phawkins): We don't check the index is in bounds --- there is no
+    // error mechanism in XLA.
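+    // Write the new element into row `index` of the buffer and advance the
+    // stack pointer by one.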
+    resource->value =
+        b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
+                  b->Add(index, b->ConstantR0<int32>(1))});
+
+    ctx->SetOutput(0, value);
+  }
+
+ private:
+  DataType dtype_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
+};
+
+REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);
+
+class StackPopOp : public XlaOpKernel {
+ public:
+  explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+    OP_REQUIRES(ctx, resource->type == dtype_,
+                errors::InvalidArgument(
+                    "Stack dtype is ", DataTypeString(resource->type),
+                    " but Op requested dtype ", DataTypeString(dtype_), "."));
+
+    // There is a somewhat subtle issue here: "uninitialized" means we have not
+    // yet seen a push in the order in which we compile operators, not the
+    // order in which we run them. In practice the two orders are the same for
+    // the sole user of the stack operators (loop gradients).
+    OP_REQUIRES(ctx, resource->value.handle() != 0,
+                errors::InvalidArgument("Stack pop on uninitialized stack"));
+
+    TensorShape stack_shape;
+    OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));
+
+    xla::ComputationDataHandle state = resource->value;
+    xla::ComputationDataHandle ta = b->GetTupleElement(state, 0);
+    xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
+
+    index = b->Sub(index, b->ConstantR0<int32>(1));
+    resource->value = b->Tuple({ta, index});
+
+    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+    auto start_indices = PadIndexWithZeros(b, index, stack_shape.dims() - 1);
+
+    auto slice_shape = stack_shape.dim_sizes();
+    slice_shape[0] = 1LL;
+
+    // TODO(phawkins): We don't check the index is in bounds --- there is no
+    // error mechanism in XLA.
+    xla::ComputationDataHandle read =
+        b->DynamicSlice(ta, start_indices, slice_shape);
+
+    // Remove the leading '1' dimension.
+    std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+    ctx->SetOutput(0, b->Reshape(read, value_shape));
+  }
+
+ private:
+  DataType dtype_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
+};
+
+REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);
+
+class StackCloseOp : public XlaOpKernel {
+ public:
+  explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    // Do nothing.
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
+};
+
+REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);
+
+}  // anonymous namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 6d1af500535..2c2031fc761 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -56,6 +57,9 @@ Status MakeXlaCompilerArgumentsFromInputs( case XlaResource::kTensorArray: arg.kind = XlaCompiler::Argument::kTensorArray; break; + case XlaResource::kStack: + arg.kind = XlaCompiler::Argument::kStack; + break; case XlaResource::kInvalid: CHECK(false); } diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index e4f43f1950d..ec28bdccda4 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -70,6 +70,7 @@ struct XlaResource { kInvalid, kVariable, kTensorArray, + kStack, }; Kind kind = kInvalid; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index d1dc85a0ebf..adabb3574c9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -266,6 +266,7 @@ Status BuildArguments(const std::vector& args, switch (args[i].kind) { case XlaCompiler::Argument::kVariable: case XlaCompiler::Argument::kTensorArray: + case XlaCompiler::Argument::kStack: context_arg.is_resource = true; if (args[i].initialized) { resources.push_back(i); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 968f667bbac..429515e2b76 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -85,7 +85,7 @@ class XlaCompiler { // Argument is a compile-time constant. No associated runtime parameter. kConstant, - // Argument is a variable resource. Has an associated runtime parameter + // Argument is a Variable resource. Has an associated runtime parameter // iff `initialized` is true. kVariable, @@ -93,6 +93,10 @@ class XlaCompiler { // iff `initialized` is true. kTensorArray, + // Argument is a Stack resource. Has an associated runtime parameter + // iff `initialized` is true. + kStack, + // Argument is a run-time parameter. kParameter, };