[TF:XLA] Add initial implementation of the Stack operators to the TF/XLA bridge. Stacks are used when computing loop gradients.
PiperOrigin-RevId: 161659980
commit eb1fe50da4 (parent 8f66dd24d1)
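These operators are not normally constructed by hand: they are emitted by the gradient of tf.while_loop, which saves forward-pass intermediates on a stack and restores them during the backward pass, and the kernels added here let the TF/XLA bridge compile such graphs. A minimal sketch of the kind of graph that exercises them, assuming TF 1.x graph mode (any XLA device/jit placement is omitted and is an assumption, not part of this commit):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[])

def cond(i, acc):
  return i < 5

def body(i, acc):
  # `acc` from each iteration must be saved for the backward pass; the
  # while_loop gradient machinery stores it with StackPushV2 and reads it
  # back with StackPopV2.
  return i + 1, acc * x

_, y = tf.while_loop(cond, body, [tf.constant(0), tf.constant(1.0)])
dy_dx = tf.gradients(y, x)[0]  # The gradient graph contains the stack ops.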
@@ -366,6 +366,19 @@ tf_xla_py_test(
    ],
)

tf_xla_py_test(
    name = "stack_ops_test",
    size = "small",
    srcs = ["stack_ops_test.py"],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:data_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:platform_test",
    ],
)

tf_xla_py_test(
    name = "tensor_array_ops_test",
    size = "small",
tensorflow/compiler/tests/stack_ops_test.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.stack_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.platform import test


class StackOpTest(XLATestCase):

  def testStackPushPop(self):
    with self.test_session(), self.test_scope():
      size = array_ops.placeholder(dtypes.int32)
      v = array_ops.placeholder(dtypes.float32)
      h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo")
      c = gen_data_flow_ops._stack_push_v2(h, v)
      with ops.control_dependencies([c]):
        c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
      self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]}))

  def testStackPushPopSwap(self):
    with self.test_session(), self.test_scope():
      a = np.arange(2000)
      x = array_ops.placeholder(dtypes.float32)
      h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
      c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True)
      with ops.control_dependencies([c]):
        c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
      self.assertAllClose(a, c1.eval({x: a}))

  def testMultiStack(self):
    with self.test_session(), self.test_scope():
      v = array_ops.placeholder(dtypes.float32)
      h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
      c1 = gen_data_flow_ops._stack_push_v2(h1, v)
      with ops.control_dependencies([c1]):
        c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32)
      h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar")
      c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0)
      with ops.control_dependencies([c2]):
        c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32)
      r = c1 + c2
      self.assertAllClose(9.0, r.eval({v: 4.0}))

  def testSameNameStacks(self):
    """Different stacks with the same name do not interfere."""
    with self.test_session() as sess, self.test_scope():
      v1 = array_ops.placeholder(dtypes.float32)
      v2 = array_ops.placeholder(dtypes.float32)
      h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
      h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")

      c1 = gen_data_flow_ops._stack_push_v2(h1, v1)
      with ops.control_dependencies([c1]):
        c2 = gen_data_flow_ops._stack_push_v2(h2, v2)
      with ops.control_dependencies([c2]):
        pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32)
        pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32)

      out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0})
      self.assertAllClose(out1, 4.0)
      self.assertAllClose(out2, 5.0)

  def testCloseStack(self):
    with self.test_session() as sess, self.test_scope():
      size = array_ops.placeholder(dtypes.int32)
      h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo")
      c1 = gen_data_flow_ops._stack_close_v2(h)
      sess.run(c1, {size: 5})

  def testPushCloseStack(self):
    with self.test_session() as sess, self.test_scope():
      v = array_ops.placeholder(dtypes.float32)
      h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
      c = gen_data_flow_ops._stack_push_v2(h, v)
      with ops.control_dependencies([c]):
        c1 = gen_data_flow_ops._stack_close_v2(h)
      sess.run(c1, {v: [[4.0, 5.0]]})


if __name__ == "__main__":
  test.main()
@@ -81,6 +81,7 @@ Status BackwardsConstAnalysis(const Graph& g,
      {"Split", "split_dim"},
      {"SplitV", "split_dim"},
      {"SplitV", "size_splits"},
      {"StackV2", "max_size"},
      {"StridedSlice", "begin"},
      {"StridedSlice", "end"},
      {"StridedSlice", "strides"},
@@ -54,6 +54,7 @@ tf_kernel_library(
        "softmax_op.cc",
        "spacetobatch_op.cc",
        "split_op.cc",
        "stack_ops.cc",
        "strided_slice_op.cc",
        "tensor_array_ops.cc",
        "tile_ops.cc",
@@ -60,6 +60,9 @@ class ArgOp : public XlaOpKernel {
      case XlaCompiler::Argument::kTensorArray:
        kind = XlaResource::kTensorArray;
        break;
      case XlaCompiler::Argument::kStack:
        kind = XlaResource::kStack;
        break;
      default:
        CHECK(false);
    }
tensorflow/compiler/tf2xla/kernels/stack_ops.cc (new file, 259 lines)
@@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA Stack operators.

#include <limits>
#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
                     TensorShape* stack_shape) {
  auto shape_or_status = builder->GetShape(resource->value);
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  xla::Shape shape = *shape_or_status.ValueOrDie();
  TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
  return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
                               stack_shape);
}

// Since the element shape is not provided to the Stack operator,
// we lazily initialize the Stack at the time of the first write.
//
// If a Stack `resource` has not been initialized, constructs storage for the
// Stack with elements of `elem_shape`. For both initialized and
// uninitialized Stacks, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
Status MaybeInitializeStack(xla::ComputationBuilder* builder,
                            XlaResource* resource, DataType dtype,
                            const TensorShape& elem_shape) {
  if (resource->type != dtype) {
    return errors::InvalidArgument(
        "Stack dtype is ", DataTypeString(resource->type),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  TensorShape stack_shape;
  stack_shape.AddDim(resource->tensor_array_size);
  stack_shape.AppendShape(elem_shape);

  if (resource->value.handle() == 0) {
    // Stack has not been initialized.
    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
    resource->value =
        builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
                        builder->ConstantR0<int32>(0)});
  } else {
    // Checks the expected shape matches the actual shape.
    TensorShape actual_shape;
    TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &actual_shape));
    if (stack_shape != actual_shape) {
      return errors::InvalidArgument(
          "Mismatched Stack shapes: ", stack_shape.DebugString(), " vs ",
          actual_shape.DebugString());
    }
  }
  return Status::OK();
}

// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
xla::ComputationDataHandle PadIndexWithZeros(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    int count) {
  xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
  std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
  xs[0] = builder->Reshape(x, {1});
  return builder->ConcatInDim(xs, 0);
}

class StackOp : public XlaOpKernel {
 public:
  explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64 size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
    OP_REQUIRES(
        ctx, size >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed stack size upper bound."));

    // We defer initializing the Stack resource until we see the first push.
    // Otherwise we do not know the shape of the stack elements.
    xla::ComputationDataHandle value;
    XlaContext& xc = XlaContext::Get(ctx);
    XlaResource* resource;
    string name = strings::StrCat("Stack: ", stack_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name),
                               dtype_, value, &resource));
    resource->tensor_array_size = size;
    ctx->SetResourceOutput(0, resource);
  }

 private:
  DataType dtype_;
  string stack_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_XLA_OP(Name("StackV2"), StackOp);

class StackPushOp : public XlaOpKernel {
 public:
  explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();
    TensorShape elem_shape = ctx->InputShape(1);

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // Initializes the Stack, if the element shape was not already known.
    OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));

    xla::ComputationDataHandle ta = b->GetTupleElement(resource->value, 0);
    xla::ComputationDataHandle index = b->GetTupleElement(resource->value, 1);
    xla::ComputationDataHandle value = ctx->Input(1);

    // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
    auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());

    TensorShape slice_shape = elem_shape;
    slice_shape.InsertDim(0, 1LL);
    auto update = b->Reshape(value, slice_shape.dim_sizes());

    // TODO(phawkins): We don't check the index is in bounds --- there is no
    // error mechanism in XLA.
    resource->value =
        b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
                  b->Add(index, b->ConstantR0<int32>(1))});

    ctx->SetOutput(0, value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};

REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);

class StackPopOp : public XlaOpKernel {
 public:
  explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    OP_REQUIRES(ctx, resource->type == dtype_,
                errors::InvalidArgument(
                    "Stack dtype is ", DataTypeString(resource->type),
                    " but Op requested dtype ", DataTypeString(dtype_), "."));

    // There is a somewhat subtle issue here: here "uninitialized" means we
    // have not yet seen a pop in the order that we compile operators, not the
    // order that we run them. However, in practice the two orders should be
    // the same for the sole user of the stack operators (loop gradients).
    OP_REQUIRES(ctx, resource->value.handle() != 0,
                errors::InvalidArgument("Stack pop on uninitialized stack"));

    TensorShape stack_shape;
    OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));

    xla::ComputationDataHandle state = resource->value;
    xla::ComputationDataHandle ta = b->GetTupleElement(state, 0);
    xla::ComputationDataHandle index = b->GetTupleElement(state, 1);

    index = b->Sub(index, b->ConstantR0<int32>(1));
    resource->value = b->Tuple({ta, index});

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
    auto start_indices = PadIndexWithZeros(b, index, stack_shape.dims() - 1);

    auto slice_shape = stack_shape.dim_sizes();
    slice_shape[0] = 1LL;

    // TODO(phawkins): We don't check the index is in bounds --- there is no
    // error mechanism in XLA.
    xla::ComputationDataHandle read =
        b->DynamicSlice(ta, start_indices, slice_shape);

    // Remove the leading '1' dimension.
    std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
    ctx->SetOutput(0, b->Reshape(read, value_shape));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};

REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);

class StackCloseOp : public XlaOpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Do nothing.
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);

}  // anonymous namespace
}  // namespace tensorflow
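The kernels above compile a Stack resource to a pure value: an XLA tuple of a fixed-size buffer with tensor_array_size leading entries plus an int32 position. StackPushV2 is a DynamicUpdateSlice at start_indices [index, 0, ..., 0] followed by incrementing the position; StackPopV2 decrements the position and reads one leading slice back with DynamicSlice. A rough NumPy model of that state transition (the function names are illustrative only, not part of the XLA builder API):

import numpy as np

def stack_init(max_size, elem_shape, dtype=np.float32):
  # Mirrors MaybeInitializeStack: a zero-filled buffer plus a position of 0.
  return np.zeros((max_size,) + elem_shape, dtype=dtype), 0

def stack_push(state, value):
  buf, index = state
  buf = buf.copy()        # functional update, like DynamicUpdateSlice
  buf[index] = value      # start_indices = [index, 0, ..., 0]
  return buf, index + 1

def stack_pop(state):
  buf, index = state
  index -= 1              # decrement first, then slice out one leading row
  return (buf, index), buf[index]

state = stack_init(5, (2,))
state = stack_push(state, np.array([4.0, 5.0]))
state, top = stack_pop(state)
assert np.allclose(top, [4.0, 5.0])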
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"

@@ -56,6 +57,9 @@ Status MakeXlaCompilerArgumentsFromInputs(
      case XlaResource::kTensorArray:
        arg.kind = XlaCompiler::Argument::kTensorArray;
        break;
      case XlaResource::kStack:
        arg.kind = XlaCompiler::Argument::kStack;
        break;
      case XlaResource::kInvalid:
        CHECK(false);
    }
@@ -70,6 +70,7 @@ struct XlaResource {
    kInvalid,
    kVariable,
    kTensorArray,
    kStack,
  };

  Kind kind = kInvalid;
@@ -266,6 +266,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
    switch (args[i].kind) {
      case XlaCompiler::Argument::kVariable:
      case XlaCompiler::Argument::kTensorArray:
      case XlaCompiler::Argument::kStack:
        context_arg.is_resource = true;
        if (args[i].initialized) {
          resources.push_back(i);
@@ -85,7 +85,7 @@ class XlaCompiler {
    // Argument is a compile-time constant. No associated runtime parameter.
    kConstant,

    // Argument is a Variable resource. Has an associated runtime parameter
    // iff `initialized` is true.
    kVariable,

@@ -93,6 +93,10 @@ class XlaCompiler {
    // iff `initialized` is true.
    kTensorArray,

    // Argument is a Stack resource. Has an associated runtime parameter
    // iff `initialized` is true.
    kStack,

    // Argument is a run-time parameter.
    kParameter,
  };