[TF:XLA] Add initial implementation of the Stack operators to the TF/XLA bridge. Stacks are used when computing loop gradients.

PiperOrigin-RevId: 161659980
This commit is contained in:
Peter Hawkins 2017-07-12 06:58:36 -07:00 committed by TensorFlower Gardener
parent 8f66dd24d1
commit eb1fe50da4
10 changed files with 392 additions and 1 deletions

View File

@ -366,6 +366,19 @@ tf_xla_py_test(
],
)
# Small Python test target built from stack_ops_test.py, declared with the
# tf_xla_py_test macro like the neighbouring test targets in this file.
tf_xla_py_test(
name = "stack_ops_test",
size = "small",
srcs = ["stack_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "tensor_array_ops_test",
size = "small",

View File

@ -0,0 +1,104 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.stack_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.platform import test
class StackOpTest(XLATestCase):
  # Exercises the StackV2 / StackPushV2 / StackPopV2 / StackCloseV2 ops inside
  # the XLA test scope provided by XLATestCase.

  def testStackPushPop(self):
    # A value pushed onto a stack comes back unchanged from the following pop.
    with self.test_session(), self.test_scope():
      max_size = array_ops.placeholder(dtypes.int32)
      elem = array_ops.placeholder(dtypes.float32)
      handle = gen_data_flow_ops._stack_v2(
          max_size, dtypes.float32, stack_name="foo")
      push = gen_data_flow_ops._stack_push_v2(handle, elem)
      # The pop must be ordered after the push.
      with ops.control_dependencies([push]):
        popped = gen_data_flow_ops._stack_pop_v2(handle, dtypes.float32)
      feed = {max_size: 5, elem: [[4.0, 5.0]]}
      self.assertAllClose([[4.0, 5.0]], popped.eval(feed))

  def testStackPushPopSwap(self):
    # Same round trip as above, but the push requests swap_memory=True and the
    # element is a 2000-entry vector.
    with self.test_session(), self.test_scope():
      expected = np.arange(2000)
      elem = array_ops.placeholder(dtypes.float32)
      handle = gen_data_flow_ops._stack_v2(
          5, dtypes.float32, stack_name="foo")
      push = gen_data_flow_ops._stack_push_v2(handle, elem, swap_memory=True)
      with ops.control_dependencies([push]):
        popped = gen_data_flow_ops._stack_pop_v2(handle, dtypes.float32)
      self.assertAllClose(expected, popped.eval({elem: expected}))

  def testMultiStack(self):
    # Two differently named stacks are used in one graph; the popped values
    # are summed.
    with self.test_session(), self.test_scope():
      fed_value = array_ops.placeholder(dtypes.float32)

      foo_handle = gen_data_flow_ops._stack_v2(
          5, dtypes.float32, stack_name="foo")
      foo_push = gen_data_flow_ops._stack_push_v2(foo_handle, fed_value)
      with ops.control_dependencies([foo_push]):
        foo_pop = gen_data_flow_ops._stack_pop_v2(foo_handle, dtypes.float32)

      bar_handle = gen_data_flow_ops._stack_v2(
          5, dtypes.float32, stack_name="bar")
      bar_push = gen_data_flow_ops._stack_push_v2(bar_handle, 5.0)
      with ops.control_dependencies([bar_push]):
        bar_pop = gen_data_flow_ops._stack_pop_v2(bar_handle, dtypes.float32)

      total = foo_pop + bar_pop
      self.assertAllClose(9.0, total.eval({fed_value: 4.0}))

  def testSameNameStacks(self):
    """Different stacks with the same name do not interfere."""
    with self.test_session() as sess, self.test_scope():
      first_value = array_ops.placeholder(dtypes.float32)
      second_value = array_ops.placeholder(dtypes.float32)
      first_handle = gen_data_flow_ops._stack_v2(
          5, dtypes.float32, stack_name="foo")
      second_handle = gen_data_flow_ops._stack_v2(
          5, dtypes.float32, stack_name="foo")

      # Chain the operations: push #1, then push #2, then both pops.
      first_push = gen_data_flow_ops._stack_push_v2(first_handle, first_value)
      with ops.control_dependencies([first_push]):
        second_push = gen_data_flow_ops._stack_push_v2(
            second_handle, second_value)
      with ops.control_dependencies([second_push]):
        pop1 = gen_data_flow_ops._stack_pop_v2(first_handle, dtypes.float32)
        pop2 = gen_data_flow_ops._stack_pop_v2(second_handle, dtypes.float32)

      feed = {first_value: 4.0, second_value: 5.0}
      out1, out2 = sess.run([pop1, pop2], feed)
      self.assertAllClose(out1, 4.0)
      self.assertAllClose(out2, 5.0)

  def testCloseStack(self):
    # Closing a stack that was never pushed to runs without error.
    with self.test_session() as sess, self.test_scope():
      max_size = array_ops.placeholder(dtypes.int32)
      handle = gen_data_flow_ops._stack_v2(
          max_size, dtypes.float32, stack_name="foo")
      close = gen_data_flow_ops._stack_close_v2(handle)
      sess.run(close, {max_size: 5})

  def testPushCloseStack(self):
    # Closing a stack after a push runs without error.
    with self.test_session() as sess, self.test_scope():
      elem = array_ops.placeholder(dtypes.float32)
      handle = gen_data_flow_ops._stack_v2(
          5, dtypes.float32, stack_name="foo")
      push = gen_data_flow_ops._stack_push_v2(handle, elem)
      with ops.control_dependencies([push]):
        close = gen_data_flow_ops._stack_close_v2(handle)
      sess.run(close, {elem: [[4.0, 5.0]]})
# When executed as a script, hand control to the TensorFlow test runner
# (`test` is tensorflow.python.platform.test, imported above).
if __name__ == "__main__":
test.main()

View File

@ -81,6 +81,7 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Split", "split_dim"},
{"SplitV", "split_dim"},
{"SplitV", "size_splits"},
{"StackV2", "max_size"},
{"StridedSlice", "begin"},
{"StridedSlice", "end"},
{"StridedSlice", "strides"},

View File

@ -54,6 +54,7 @@ tf_kernel_library(
"softmax_op.cc",
"spacetobatch_op.cc",
"split_op.cc",
"stack_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
"tile_ops.cc",

View File

@ -60,6 +60,9 @@ class ArgOp : public XlaOpKernel {
case XlaCompiler::Argument::kTensorArray:
kind = XlaResource::kTensorArray;
break;
case XlaCompiler::Argument::kStack:
kind = XlaResource::kStack;
break;
default:
CHECK(false);
}

View File

@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// XLA Stack operators.
#include <limits>
#include <vector>
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Queries `builder` for the XLA shape of the value currently held by
// `resource`, and stores the shape of that value's tuple element 0 in
// `*stack_shape`.
//
// Returns the error reported by ComputationBuilder::GetShape if the query
// fails; the TF_RET_CHECK fails if the resource value is not a tuple.
Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
                     TensorShape* stack_shape) {
  auto shape_lookup = builder->GetShape(resource->value);
  if (shape_lookup.ok()) {
    const xla::Shape tuple_shape = *shape_lookup.ValueOrDie();
    TF_RET_CHECK(xla::ShapeUtil::IsTuple(tuple_shape));
    const xla::Shape& buffer_shape =
        xla::ShapeUtil::GetTupleElementShape(tuple_shape, 0);
    return XLAShapeToTensorShape(buffer_shape, stack_shape);
  }
  return shape_lookup.status();
}
// The Stack operator is not told the shape of its elements, so the backing
// storage is created lazily, at the time of the first write.
//
// Always verifies that `dtype` equals the dtype recorded on `resource`.
// Then:
//   * if `resource` has no value yet (its handle is 0), builds the initial
//     value: a tuple of a zero-filled buffer of shape
//     [resource->tensor_array_size] + elem_shape and an int32 scalar 0;
//   * otherwise, verifies that the existing buffer has exactly that shape.
//
// Returns InvalidArgument on a dtype or shape mismatch.
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
Status MaybeInitializeStack(xla::ComputationBuilder* builder,
                            XlaResource* resource, DataType dtype,
                            const TensorShape& elem_shape) {
  if (resource->type != dtype) {
    return errors::InvalidArgument(
        "Stack dtype is ", DataTypeString(resource->type), " but op has dtype ",
        DataTypeString(dtype), ".");
  }

  // Shape the buffer is expected to have: [max elements] + element shape.
  TensorShape expected_shape;
  expected_shape.AddDim(resource->tensor_array_size);
  expected_shape.AppendShape(elem_shape);

  const bool already_initialized = resource->value.handle() != 0;
  if (!already_initialized) {
    // First write: materialize (zero-filled buffer, 0) as the stack state.
    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
    xla::ComputationDataHandle buffer =
        builder->Broadcast(zero, expected_shape.dim_sizes());
    xla::ComputationDataHandle initial_index = builder->ConstantR0<int32>(0);
    resource->value = builder->Tuple({buffer, initial_index});
    return Status::OK();
  }

  // Already initialized: the stored buffer must match the expected shape.
  TensorShape actual_shape;
  TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &actual_shape));
  if (expected_shape != actual_shape) {
    return errors::InvalidArgument(
        "Mismatched Stack shapes: ", expected_shape.DebugString(), " vs ",
        actual_shape.DebugString());
  }
  return Status::OK();
}
// Builds a rank-1 vector of `count + 1` entries: `x` (reshaped to a single
// element vector) followed by `count` int32 zeros. 'x' must have 1 element.
xla::ComputationDataHandle PadIndexWithZeros(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    int count) {
  // One shared constant is reused for every padding slot.
  const xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});

  std::vector<xla::ComputationDataHandle> pieces;
  pieces.reserve(count + 1);
  pieces.push_back(builder->Reshape(x, {1}));
  pieces.insert(pieces.end(), count, zero);
  return builder->ConcatInDim(pieces, 0);
}
class StackOp : public XlaOpKernel {
public:
explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_));
}
void Compile(XlaOpKernelContext* ctx) override {
int64 size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
OP_REQUIRES(
ctx, size >= 0,
errors::InvalidArgument(
"XLA compilation requires a fixed stack size upper bound."));
// We defer initializing the Stack resource until we see the first push.
// Otherwise we do not know the shape of the stack elements.
xla::ComputationDataHandle value;
XlaContext& xc = XlaContext::Get(ctx);
XlaResource* resource;
string name = strings::StrCat("Stack: ", stack_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
value, &resource));
resource->tensor_array_size = size;
ctx->SetResourceOutput(0, resource);
}
private:
DataType dtype_;
string stack_name_;
TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};
REGISTER_XLA_OP(Name("StackV2"), StackOp);
class StackPushOp : public XlaOpKernel {
public:
explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
TensorShape elem_shape = ctx->InputShape(1);
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
// Initializes the Stack, if the element shape was not already known.
OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));
xla::ComputationDataHandle ta = b->GetTupleElement(resource->value, 0);
xla::ComputationDataHandle index = b->GetTupleElement(resource->value, 1);
xla::ComputationDataHandle value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
auto update = b->Reshape(value, slice_shape.dim_sizes());
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
resource->value =
b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
b->Add(index, b->ConstantR0<int32>(1))});
ctx->SetOutput(0, value);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};
REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);
class StackPopOp : public XlaOpKernel {
public:
explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES(ctx, resource->type == dtype_,
errors::InvalidArgument(
"Stack dtype is ", DataTypeString(resource->type),
" but Op requested dtype ", DataTypeString(dtype_), "."));
// There is a somewhat subtle issue here: here "uninitialized" means we have
// not yet seen a pop in the order that we compile operators, not the order
// that we run them. However, in practice the two orders should be the same
// for the sole user of the stack operators (loop gradients).
OP_REQUIRES(ctx, resource->value.handle() != 0,
errors::InvalidArgument("Stack pop on uninitialized stack"));
TensorShape stack_shape;
OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));
xla::ComputationDataHandle state = resource->value;
xla::ComputationDataHandle ta = b->GetTupleElement(state, 0);
xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
index = b->Sub(index, b->ConstantR0<int32>(1));
resource->value = b->Tuple({ta, index});
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices = PadIndexWithZeros(b, index, stack_shape.dims() - 1);
auto slice_shape = stack_shape.dim_sizes();
slice_shape[0] = 1LL;
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
xla::ComputationDataHandle read =
b->DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
ctx->SetOutput(0, b->Reshape(read, value_shape));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};
REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);
// XLA kernel for StackCloseV2. Compiling it emits no XLA operations.
class StackCloseOp : public XlaOpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  // Intentionally a no-op.
  void Compile(XlaOpKernelContext* /*ctx*/) override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -56,6 +57,9 @@ Status MakeXlaCompilerArgumentsFromInputs(
case XlaResource::kTensorArray:
arg.kind = XlaCompiler::Argument::kTensorArray;
break;
case XlaResource::kStack:
arg.kind = XlaCompiler::Argument::kStack;
break;
case XlaResource::kInvalid:
CHECK(false);
}

View File

@ -70,6 +70,7 @@ struct XlaResource {
kInvalid,
kVariable,
kTensorArray,
kStack,
};
Kind kind = kInvalid;

View File

@ -266,6 +266,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
switch (args[i].kind) {
case XlaCompiler::Argument::kVariable:
case XlaCompiler::Argument::kTensorArray:
case XlaCompiler::Argument::kStack:
context_arg.is_resource = true;
if (args[i].initialized) {
resources.push_back(i);

View File

@ -85,7 +85,7 @@ class XlaCompiler {
// Argument is a compile-time constant. No associated runtime parameter.
kConstant,
// Argument is a variable resource. Has an associated runtime parameter
// Argument is a Variable resource. Has an associated runtime parameter
// iff `initialized` is true.
kVariable,
@ -93,6 +93,10 @@ class XlaCompiler {
// iff `initialized` is true.
kTensorArray,
// Argument is a Stack resource. Has an associated runtime parameter
// iff `initialized` is true.
kStack,
// Argument is a run-time parameter.
kParameter,
};