Fix bugs in TensorArray gradients + unit tests.

Two major bugs:

* Multiple calls to tf.gradients will create multiple TensorArrayWrites to the
  same gradient slot from the exact same gradient source.  This gets treated as
  a write + an add, and as a result gradients of TensorArrayRead can be
  double-counted.  The solution is to create a separate TensorArray for each
  call to tf.gradients.  This can be done by:
  1. looking at the name of the input gradient to e.g. TensorArrayRead
  2. slicing off the prefix (e.g. "gradients", "gradients_1", etc)
  3. passing this to the tensor_array_grad op which
  4. uses this as a suffix to the original name when creating or looking up a new
     TensorArray object for storing the gradients.

* The initial gradient TensorArrayWrite performed a shallow copy of the
  write tensor.  Since we support aggregation to the same slot from different
  sources, modifying the PersistentTensor in place can affect the original
  tensor elsewhere.

  Instead, we now specifically disallow multiple reads / packs from a TensorArray.
  This simplifies the code immensely and removes the need to support gradient
  aggregation.  It also makes the interface much more functional.  Once a
  Tensor has been read out of the TensorArray, it can be used in several places.
  However, gradient aggregation can be performed outside the TensorArray.

Additional improvements:

* TensorArray constructor is now "read_once", which means after a Read or
  Pack operation, the reference inside TensorArray to the Tensor is removed.  This
  frees up memory early.
Change: 113193579
This commit is contained in:
Eugene Brevdo 2016-01-27 13:17:52 -08:00 committed by Vijay Vasudevan
parent e36590b4d5
commit 697084c97b
8 changed files with 430 additions and 331 deletions

View File

@ -1,58 +0,0 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/tensor_array.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
namespace tensor_array {
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \
template <> \
Status TensorArrayWriteOrAdd<T, Device>(OpKernelContext * ctx, Tensor * sum, \
const Tensor* current, \
const Tensor* add) { \
functor::Add2Functor<Device, T> add_functor; \
add_functor(ctx->template eigen_device<Device>(), sum->flat<T>(), \
current->flat<T>(), add->flat<T>()); \
return Status::OK(); \
}
#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU
#if GOOGLE_CUDA
#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU
#endif // GOOGLE_CUDA
#undef TENSOR_ARRAY_WRITE_OR_ADD
} // namespace tensor_array
} // namespace tensorflow

View File

@ -34,144 +34,181 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
namespace tensor_array {
// Full implementations are in tensor_array.cc
template <typename T, typename Device>
Status TensorArrayWriteOrAdd(OpKernelContext* ctx, Tensor* sum,
const Tensor* current, const Tensor* add) {
return errors::InvalidArgument("TensorArrayWriteOrAdd type not supported: ",
DataTypeString(DataTypeToEnum<T>::value));
};
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \
template <> \
Status TensorArrayWriteOrAdd<T, Device>(OpKernelContext * ctx, Tensor * sum, \
const Tensor* current, \
const Tensor* add);
#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU
#if GOOGLE_CUDA
#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU
#endif // GOOGLE_CUDA
#undef TENSOR_ARRAY_WRITE_OR_ADD
} // namespace tensor_array
// The TensorArray object keeps an array of PersistentTensors. It
// allows reading from the array and writing to the array.
//
// Important properties:
// * Reading and writing to a particular index in the TensorArray
// is allowed at most once per index.
// * Upon reading an entry, that entry is cleared from the array and
// marked as read. This allows removal of Tensor from memory
// as soon as it is not needed. Its shape is saved.
// * No deep copies of any PersistentTensor are ever made.
// * Reading and Writing to the array is protected by a mutex.
// All operations on a TensorArray are thread-safe.
// * A TensorArray may be preemptively closed, which releases all
// memory associated with it.
//
// These properties together allow the TensorArray to work as a
// functional object and makes gradient computation easy. For
// example:
// * Write-Once semantics mean the gradient of a TensorArray Read never has to
// worry which of multiple writes to that index the gradient value
// is meant for.
// * Read-Once semantics mean the TensorArray never sees
// multiple writes to the same index as part of gradient aggregation.
//
class TensorArray : public ResourceBase {
public:
// Construct a TensorArray for holding Tensors of type 'dtype' with
// 'N' elements. While the underlying storage is a std::vector and
// can hold more than MAX_INT entries, in practice we do not expect
// users to construct this many Tensors for storage in a TensorArray.
TensorArray(const DataType& dtype, const Tensor& handle, int32 N)
: dead_(false), dtype_(dtype), handle_(handle), tensor_array_(N) {}
: dtype_(dtype), handle_(handle), closed_(false), tensors_(N) {}
Status Write(const int32 index, const PersistentTensor& value) {
// Write PersistentTensor 'value' to index 'index'.
//
// Preconditions:
// * The TensorArray is not closed
// * The index is in [0, N)
// * The dtype of the Tensor in 'value' matches the TensorArray's dtype.
// * The Tensor at 'index' has not yet been written to.
//
// Side effects:
// * The underlying Tensor in 'value' has a new reference to it.
// * Index 'index' is marked as written.
//
// Note, value is passed as a pointer because we its underlying
// Tesnor's shape is accessed. Otherwise it is not modified.
Status Write(OpKernelContext* ctx, const int32 index,
PersistentTensor* value) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(LockedReturnIfDead());
if (index < 0 || index >= tensor_array_.size()) {
return errors::InvalidArgument("Tried to write to index ", index,
" but array size is: ",
tensor_array_.size());
TF_RETURN_IF_ERROR(LockedReturnIfClosed());
if (index < 0 || static_cast<size_t>(index) >= tensors_.size()) {
return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
": Tried to write to index ", index,
" but array size is: ", tensors_.size());
}
if (tensor_array_[index].IsInitialized()) {
TensorAndState& t = tensors_[index];
if (t.written) {
return errors::InvalidArgument(
"Could not write to TensorArray index ", index,
"TensorArray ", handle_.vec<string>()(1),
": Could not write to TensorArray index ", index,
" because it has already been written to.");
}
tensor_array_[index] = value;
return Status::OK();
}
template <typename Device, typename T>
Status WriteOrAdd(OpKernelContext* ctx, const int32 index,
PersistentTensor value) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(LockedReturnIfDead());
if (index < 0 || index >= tensor_array_.size()) {
return errors::InvalidArgument("Tried to write to index ", index,
" but array size is: ",
tensor_array_.size());
}
if (tensor_array_[index].IsInitialized()) {
Tensor* sum = tensor_array_[index].AccessTensor(ctx);
const Tensor* current = tensor_array_[index].AccessTensor(ctx);
const Tensor* add = value.AccessTensor(ctx);
if (!current->shape().IsSameSize(add->shape())) {
return errors::InvalidArgument("Cannot add to index ", index,
" because shapes are inconsistent: ",
current->shape().DebugString(), " vs. ",
add->shape().DebugString());
}
return tensor_array::TensorArrayWriteOrAdd<T, Device>(ctx, sum, current,
add);
} else {
tensor_array_[index] = value;
}
Tensor* value_t = value->AccessTensor(ctx);
if (value_t->dtype() != dtype_) {
return errors::InvalidArgument(
"TensorArray ", handle_.vec<string>()(1),
": Could not write to TensorArray index ", index,
" because the value dtype is ", DataTypeString(value_t->dtype()),
" but TensorArray dtype is ", DataTypeString(dtype_), ".");
}
t.tensor = *value;
t.shape = value_t->shape();
t.written = true;
return Status::OK();
}
// Read from index 'index' into PersistentTensor 'value'.
//
// Preconditions:
// * The TensorArray is not closed
// * The index is in [0, N)
// * The Tensor at 'index' has been written to.
// * The Tensor at 'index' has not already been read.
//
// Side effects:
// * The PersistentTensor at 'index' is cleared from the given index.
// * The reference to the underlying Tensor at 'index' is shifted to
// the returned '*value'.
// * Index 'index' is marked as read.
Status Read(const int32 index, PersistentTensor* value) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(LockedReturnIfDead());
if (index < 0 || index >= tensor_array_.size()) {
TF_RETURN_IF_ERROR(LockedReturnIfClosed());
if (index < 0 || static_cast<size_t>(index) >= tensors_.size()) {
return errors::InvalidArgument("Tried to read from index ", index,
" but array size is: ",
tensor_array_.size());
" but array size is: ", tensors_.size());
}
if (!tensor_array_[index].IsInitialized()) {
TensorAndState& t = tensors_[index];
if (t.read) {
return errors::InvalidArgument(
"Could not read from TensorArray index ", index,
"TensorArray ", handle_.vec<string>()(1), ": Could not read index ",
index, " twice because TensorArray a read-once object.");
}
if (!t.written) {
return errors::InvalidArgument(
"TensorArray ", handle_.vec<string>()(1),
": Could not read from TensorArray index ", index,
" because it has not yet been written to.");
}
*value = tensor_array_[index];
*value = t.tensor;
t.read = true;
t.tensor = PersistentTensor();
return Status::OK();
}
inline int32 Size() {
// Return the Size of the TensorArray.
Status Size(int32* size) {
mutex_lock l(mu_);
DCHECK(!dead_);
return tensor_array_.size();
}
inline Status LockedReturnIfDead() const {
if (dead_) {
return errors::InvalidArgument("Tensor ", handle_.vec<string>()(1),
" has already been closed.");
}
TF_RETURN_IF_ERROR(LockedReturnIfClosed());
*size = tensors_.size();
return Status::OK();
}
DataType ElemType() { return dtype_; }
DataType ElemType() const { return dtype_; }
string DebugString() override {
mutex_lock l(mu_);
DCHECK(!dead_);
return strings::StrCat("TensorArray[", tensor_array_.size(), "]");
CHECK(!closed_);
return strings::StrCat("TensorArray[", tensors_.size(), "]");
}
inline bool IsDead() const { return dead_; }
void ClearAndMarkDead() {
inline bool IsClosed() {
mutex_lock l(mu_);
tensor_array_.clear();
dead_ = true;
return closed_;
}
// Clear the TensorArray, including any Tensor references, and mark as closed.
void ClearAndMarkClosed() {
mutex_lock l(mu_);
tensors_.clear();
closed_ = true;
}
mutex* mu() { return &mu_; }
Tensor* handle() { return &handle_; }
private:
bool dead_; // Marks that the tensor_array_ has been cleared.
mutex mu_;
Status LockedReturnIfClosed() const {
if (closed_) {
return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
" has already been closed.");
}
return Status::OK();
}
DataType dtype_;
Tensor handle_;
std::vector<PersistentTensor> tensor_array_ GUARDED_BY(mu_);
mutex mu_;
bool closed_
GUARDED_BY(mu_); // Marks that the tensor_array_ has been cleared.
// TensorAndState is used to keep track of the PersistentTensors
// stored in the TensorArray, along with their shapes, and a boolean
// that determines whether they have already been read or not.
struct TensorAndState {
TensorAndState() : written(false), read(false) {}
PersistentTensor tensor;
TensorShape shape;
bool written; // True if a Tensor has been written to the index.
bool read; // True if a Tensor has been written to and read from the index.
};
// The list of underlying PersistentTensors and states.
std::vector<TensorAndState> tensors_ GUARDED_BY(mu_);
};
} // namespace tensorflow

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/core/kernels/split_op.h"
#include "tensorflow/core/kernels/tensor_array.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@ -158,7 +158,7 @@ class TensorArrayOp : public TensorArrayCreationOp {
private:
DataType dtype_;
string tensor_array_name_;
string tensor_array_name_; // The name used to create the TensorArray.
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
};
@ -184,7 +184,9 @@ REGISTER_GPU(bfloat16);
class TensorArrayGradOp : public TensorArrayCreationOp {
public:
explicit TensorArrayGradOp(OpKernelConstruction* context)
: TensorArrayCreationOp(context) {}
: TensorArrayCreationOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("source", &source_));
}
Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm,
Tensor* tensor_array_output_handle,
@ -202,15 +204,17 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
auto output_handle = tensor_array_output_handle->flat<string>();
output_handle(0) = "_tensor_array_grads";
output_handle(1) = tensor_array_name;
output_handle(1) = strings::StrCat(tensor_array_name, "@", source_);
TensorArray* tensor_array;
int32 array_size;
TF_RETURN_IF_ERROR(rm->Lookup(container, tensor_array_name, &tensor_array));
TF_RETURN_IF_ERROR(tensor_array->Size(&array_size));
auto creator = [this, tensor_array,
auto creator = [this, tensor_array, array_size,
tensor_array_output_handle](TensorArray** ret) {
*ret = new TensorArray(tensor_array->ElemType(),
*tensor_array_output_handle, tensor_array->Size());
*tensor_array_output_handle, array_size);
return Status::OK();
};
@ -221,6 +225,11 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
}
private:
// The gradient source for creating the given
// gradient TensorArray. This should be unique to each gradients
// call. Typical values look like "gradients", "gradients_1", ...
string source_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
};
@ -237,9 +246,7 @@ template <typename Device, typename T>
class TensorArrayWriteOp : public OpKernel {
public:
explicit TensorArrayWriteOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("gradient_add", &gradient_add_));
}
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true));
@ -264,17 +271,8 @@ class TensorArrayWriteOp : public OpKernel {
" but Op is trying to write dtype ",
DataTypeString(tensor_value->dtype()), "."));
PersistentTensor persistent_tensor(*tensor_value);
if (gradient_add_) {
Status s =
tensor_array->WriteOrAdd<Device, T>(ctx, index, persistent_tensor);
OP_REQUIRES_OK(ctx, s);
} else {
OP_REQUIRES_OK(ctx, tensor_array->Write(index, persistent_tensor));
}
OP_REQUIRES_OK(ctx, tensor_array->Write(ctx, index, &persistent_tensor));
}
private:
bool gradient_add_;
};
#define REGISTER_WRITE(type) \
@ -375,7 +373,8 @@ class TensorArrayPackOp : public OpKernel {
TensorArray* tensor_array = nullptr;
OP_REQUIRES_OK(ctx, GetTensorArray("handle", ctx, &tensor_array));
const int32 array_size = tensor_array->Size();
int32 array_size;
OP_REQUIRES_OK(ctx, tensor_array->Size(&array_size));
OP_REQUIRES(
ctx, dtype_ == tensor_array->ElemType(),
errors::InvalidArgument(
@ -389,9 +388,14 @@ class TensorArrayPackOp : public OpKernel {
return;
}
PersistentTensor value_0;
OP_REQUIRES_OK(ctx, tensor_array->Read(0, &value_0));
Tensor* value_0_t = value_0.AccessTensor(ctx);
// Read all the PersistentTensors into a vector to keep track of
// their memory.
std::vector<PersistentTensor> values(array_size);
for (int i = 0; i < array_size; ++i) {
OP_REQUIRES_OK(ctx, tensor_array->Read(i, &values[i]));
}
const Tensor* value_0_t = values[0].AccessTensor(ctx);
TensorShape output_shape(value_0_t->shape());
output_shape.InsertDim(0, array_size);
@ -402,10 +406,12 @@ class TensorArrayPackOp : public OpKernel {
auto output_flat =
output_tensor->shaped<T, 2>({1, output_shape.num_elements()});
for (int i = 0; i < array_size; ++i) {
PersistentTensor value;
OP_REQUIRES_OK(ctx, tensor_array->Read(i, &value));
const Tensor* value_t = value.AccessTensor(ctx);
// Insert the first value
input_tensors_flat.emplace_back(new ConstMatrix(
value_0_t->shaped<T, 2>({1, value_0_t->NumElements()})));
for (int i = 1; i < array_size; ++i) {
const Tensor* value_t = values[i].AccessTensor(ctx);
OP_REQUIRES(
ctx, value_0_t->shape() == value_t->shape(),
errors::InvalidArgument(
@ -470,20 +476,18 @@ template <typename Device, typename T>
class TensorArrayUnpackOp : public OpKernel {
public:
explicit TensorArrayUnpackOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("gradient_add", &gradient_add_));
}
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true));
TensorArray* tensor_array = nullptr;
OP_REQUIRES_OK(ctx, GetTensorArray("handle", ctx, &tensor_array));
const Tensor* tensor_value;
OP_REQUIRES_OK(ctx, ctx->input("value", &tensor_value));
const int32 array_size = tensor_array->Size();
int32 array_size;
OP_REQUIRES_OK(ctx, tensor_array->Size(&array_size));
OP_REQUIRES(
ctx, tensor_value->dtype() == tensor_array->ElemType(),
errors::InvalidArgument("TensorArray dtype is ",
@ -522,18 +526,9 @@ class TensorArrayUnpackOp : public OpKernel {
functor::Split<Device, T>()(ctx->eigen_device<Device>(), tensor_value_i_t,
tensor_value_t, indices, sizes);
if (gradient_add_) {
Status s =
tensor_array->WriteOrAdd<Device, T>(ctx, i, persistent_tensor);
OP_REQUIRES_OK(ctx, s);
} else {
OP_REQUIRES_OK(ctx, tensor_array->Write(i, persistent_tensor));
}
OP_REQUIRES_OK(ctx, tensor_array->Write(ctx, i, &persistent_tensor));
}
}
private:
bool gradient_add_;
};
#define REGISTER_UNPACK(type) \
@ -571,13 +566,13 @@ class TensorArrayCloseOp : public OpKernel {
TensorArray* tensor_array;
OP_REQUIRES_OK(ctx, GetTensorArray("handle", ctx, &tensor_array));
// Instead of deleting this TA from the ResourceManager, we just
// clear it away and mark it as dead. The remaining memory
// clear it away and mark it as closed. The remaining memory
// consumed store its mutex and handle Tensor. This will be
// cleared out at the end of the step anyway, so it's fine to keep
// it around temporarily. The next call to GetTensorArray will
// fail because GetTensorArray checks to see if the TensorArray is
// dead or not.
tensor_array->ClearAndMarkDead();
// it around until the end of the step. Further calls to the
// TensorArray will fail because TensorArray checks internally to
// see if it is closed or not.
tensor_array->ClearAndMarkClosed();
}
};

View File

@ -396,20 +396,39 @@ via Read or Pack.
handle: The handle to the TensorArray.
size: The size of the array.
dtype: The type of the elements on the tensor_array.
tensor_array_name: Overrides the name used for the temporary tensor_array resource. Default
value is the name of the 'TensorArray' op (which is guaranteed unique).
tensor_array_name: Overrides the name used for the temporary tensor_array
resource. Default value is the name of the 'TensorArray' op (which
is guaranteed unique).
)doc");
REGISTER_OP("TensorArrayGrad")
.Input("handle: Ref(string)")
.Output("grad_handle: Ref(string)")
.Attr("source: string")
.SetIsStateful()
.Doc(R"doc(
Creates a TensorArray for storing the gradients of values in the given handle.
If the given TensorArray gradient already exists, returns a reference to it.
TensorArray gradient calls use an accumulator TensorArray object. If
multiple gradients are calculated and run in the same session, the multiple
gradient nodes may accidentally flow throuth the same accumulator TensorArray.
This double counts and generally breaks the TensorArray gradient flow.
The solution is to identify which gradient call this particular
TensorArray gradient is being called in. This is performed by identifying
a unique string (e.g. "gradients", "gradients_1", ...) from the input
gradient Tensor's name. This string is used as a suffix when creating
the TensorArray gradient object here (the attribute `source`).
The attribute `source` is added as a suffix to the forward TensorArray's
name when performing the creation / lookup, so that each separate gradient
calculation gets its own TensorArray accumulator.
handle: The handle to the forward TensorArray.
source: The gradient source string, used to decide which gradient TensorArray
to return.
)doc");
REGISTER_OP("TensorArrayWrite")
@ -419,8 +438,6 @@ REGISTER_OP("TensorArrayWrite")
.Input("flow_in: float")
.Output("flow_out: float")
.Attr("T: type")
.Attr("gradient_add: bool = false")
.SetIsStateful()
.Doc(R"doc(
Push an element onto the tensor_array.
@ -429,8 +446,6 @@ index: The position to write to inside the TensorArray.
value: The tensor to write to the TensorArray.
flow_in: A float scalar that enforces proper chaining of operations.
flow_out: A float scalar that enforces proper chaining of operations.
gradient_add: Used for gradient back-propagation. If the value has already
been written to the handle, validate input shape and add to it.
)doc");
REGISTER_OP("TensorArrayRead")
@ -439,7 +454,6 @@ REGISTER_OP("TensorArrayRead")
.Input("flow_in: float")
.Output("value: dtype")
.Attr("dtype: type")
.SetIsStateful()
.Doc(R"doc(
Read an element from the TensorArray.
@ -454,7 +468,6 @@ REGISTER_OP("TensorArrayPack")
.Input("flow_in: float")
.Output("value: dtype")
.Attr("dtype: type")
.SetIsStateful()
.Doc(R"doc(
Pack the elements from the TensorArray.
@ -473,16 +486,12 @@ REGISTER_OP("TensorArrayUnpack")
.Input("flow_in: float")
.Output("flow_out: float")
.Attr("T: type")
.Attr("gradient_add: bool = false")
.SetIsStateful()
.Doc(R"doc(
Unpack the data from the input value into TensorArray elements.
handle: The handle to a TensorArray.
value: The concatenated tensor to write to the TensorArray.
flow_in: A float scalar that enforces proper chaining of operations.
gradient_add: Used for gradient back-propagation. If values are already
written to the handle, validate shapes and add to them.
flow_out: A float scalar that enforces proper chaining of operations.
)doc");

View File

@ -9082,7 +9082,7 @@ op {
default_value {
s: ""
}
description: "Overrides the name used for the temporary tensor_array resource. Default\nvalue is the name of the \'TensorArray\' op (which is guaranteed unique)."
description: "Overrides the name used for the temporary tensor_array\nresource. Default value is the name of the \'TensorArray\' op (which\nis guaranteed unique)."
}
summary: "An array of Tensors of given size, with data written via Write and read"
description: "via Read or Pack."
@ -9112,8 +9112,13 @@ op {
type: DT_STRING
is_ref: true
}
attr {
name: "source"
type: "string"
description: "The gradient source string, used to decide which gradient TensorArray\nto return."
}
summary: "Creates a TensorArray for storing the gradients of values in the given handle."
description: "If the given TensorArray gradient already exists, returns a reference to it."
description: "If the given TensorArray gradient already exists, returns a reference to it.\n\nTensorArray gradient calls use an accumulator TensorArray object. If\nmultiple gradients are calculated and run in the same session, the multiple\ngradient nodes may accidentally flow throuth the same accumulator TensorArray.\nThis double counts and generally breaks the TensorArray gradient flow.\n\nThe solution is to identify which gradient call this particular\nTensorArray gradient is being called in. This is performed by identifying\na unique string (e.g. \"gradients\", \"gradients_1\", ...) from the input\ngradient Tensor\'s name. This string is used as a suffix when creating\nthe TensorArray gradient object here (the attribute `source`).\n\nThe attribute `source` is added as a suffix to the forward TensorArray\'s\nname when performing the creation / lookup, so that each separate gradient\ncalculation gets its own TensorArray accumulator."
is_stateful: true
}
op {
@ -9141,7 +9146,6 @@ op {
}
summary: "Pack the elements from the TensorArray."
description: "All elements must have the same shape."
is_stateful: true
}
op {
name: "TensorArrayRead"
@ -9171,7 +9175,6 @@ op {
description: "The type of the elem that is returned."
}
summary: "Read an element from the TensorArray."
is_stateful: true
}
op {
name: "TensorArrayUnpack"
@ -9200,16 +9203,7 @@ op {
name: "T"
type: "type"
}
attr {
name: "gradient_add"
type: "bool"
default_value {
b: false
}
description: "Used for gradient back-propagation. If values are already\nwritten to the handle, validate shapes and add to them."
}
summary: "Unpack the data from the input value into TensorArray elements."
is_stateful: true
}
op {
name: "TensorArrayWrite"
@ -9243,16 +9237,7 @@ op {
name: "T"
type: "type"
}
attr {
name: "gradient_add"
type: "bool"
default_value {
b: false
}
description: "Used for gradient back-propagation. If the value has already\nbeen written to the handle, validate input shape and add to it."
}
summary: "Push an element onto the tensor_array."
is_stateful: true
}
op {
name: "TextLineReader"

View File

@ -26,6 +26,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
@ -155,7 +156,7 @@ class TensorArrayTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu) as sess:
h = data_flow_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
g_h = h.grad()
g_h = h.grad("grad")
w0 = h.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [[1.0]])
@ -189,8 +190,8 @@ class TensorArrayTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu) as sess:
h = data_flow_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
g_h_0 = h.grad()
g_h_1 = h.grad()
g_h_0 = h.grad("grad")
g_h_1 = h.grad("grad")
with tf.control_dependencies([g_h_0.write(0, [[4.0, 5.0]]).flow]):
# Write with one gradient handle, read with another copy of it
@ -277,27 +278,6 @@ class TensorArrayTest(tf.test.TestCase):
self._testTensorArrayWriteMultipleFails(use_gpu=False)
self._testTensorArrayWriteMultipleFails(use_gpu=True)
def _testTensorArrayWriteGradientAddMultipleAddsType(self, use_gpu, dtype):
with self.test_session(use_gpu=use_gpu):
h = data_flow_ops.TensorArray(
dtype=dtype, tensor_array_name="foo", size=3)
h._gradient_add = True
c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype)
w0 = h.write(2, c(3.0))
w1 = w0.write(2, c(4.0))
self.assertAllEqual(c(7.00), w1.read(2).eval())
def _testTensorArrayWriteGradientAddMultipleAdds(self, use_gpu):
for dtype in [tf.int32, tf.int64, tf.float32, tf.float64, tf.complex64]:
self._testTensorArrayWriteGradientAddMultipleAddsType(use_gpu, dtype)
def testTensorArrayWriteGradientAddMultipleAdds(self):
self._testTensorArrayWriteGradientAddMultipleAdds(use_gpu=False)
self._testTensorArrayWriteGradientAddMultipleAdds(use_gpu=True)
def _testMultiTensorArray(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
h1 = data_flow_ops.TensorArray(
@ -346,7 +326,6 @@ class TensorArrayTest(tf.test.TestCase):
w1 = w0.write(1, value_1)
r0 = w1.read(0)
r1 = w1.read(1)
r0_2 = w1.read(0)
# Test individual components' gradients
grad_just_r0 = tf.gradients(
@ -354,12 +333,6 @@ class TensorArrayTest(tf.test.TestCase):
grad_just_r0_vals = sess.run(grad_just_r0)
self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0])
grad_r0_r0_2 = tf.gradients(
ys=[r0, r0_2], xs=[value_0],
grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])])
grad_r0_r0_2_vals = sess.run(grad_r0_r0_2)
self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0])
grad_just_r1 = tf.gradients(
ys=[r1], xs=[value_1], grad_ys=[c(-2.0)])
grad_just_r1_vals = sess.run(grad_just_r1)
@ -367,11 +340,11 @@ class TensorArrayTest(tf.test.TestCase):
# Test combined gradients
grad = tf.gradients(
ys=[r0, r0_2, r1], xs=[value_0, value_1],
grad_ys=[c(-1.0), c(-2.0), c([[2.0, 3.0]])])
ys=[r0, r1], xs=[value_0, value_1],
grad_ys=[c(-1.0), c([[2.0, 3.0]])])
grad_vals = sess.run(grad)
self.assertEqual(len(grad_vals), 2)
self.assertAllClose(c(-3.0), grad_vals[0])
self.assertAllClose(c(-1.0), grad_vals[0])
self.assertAllEqual(c([[2.0, 3.0]]), grad_vals[1])
def _testTensorArrayGradientWriteRead(self, use_gpu):
@ -382,34 +355,6 @@ class TensorArrayTest(tf.test.TestCase):
self._testTensorArrayGradientWriteRead(False)
self._testTensorArrayGradientWriteRead(True)
def _testTensorArrayGradientWritePackAndRead(self, use_gpu):
with self.test_session(use_gpu=use_gpu) as sess:
h = data_flow_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
value_0 = tf.constant([-1.0, 1.0])
value_1 = tf.constant([-10.0, 10.0])
w0 = h.write(0, value_0)
w1 = w0.write(1, value_1)
p0 = w1.pack()
r0 = w1.read(0)
# Test gradient accumulation between read(0) and pack()
grad_r = tf.gradients(
ys=[p0, r0], xs=[value_0, value_1],
grad_ys=[
[[2.0, 3.0], [4.0, 5.0]],
[-0.5, 1.5]])
grad_vals = sess.run(grad_r) # 2 + 2 entries
self.assertAllClose([2.0 - 0.5, 3.0 + 1.5], grad_vals[0])
self.assertAllEqual([4.0, 5.0], grad_vals[1])
def testTensorArrayGradientWritePackAndRead(self):
self._testTensorArrayGradientWritePackAndRead(False)
self._testTensorArrayGradientWritePackAndRead(True)
def _testTensorArrayGradientUnpackRead(self, use_gpu):
with self.test_session(use_gpu=use_gpu) as sess:
h = data_flow_ops.TensorArray(
@ -419,17 +364,16 @@ class TensorArrayTest(tf.test.TestCase):
w = h.unpack(value)
r0 = w.read(0)
r0_1 = w.read(0)
r1 = w.read(1)
# Test combined gradients + aggregation of read(0)
grad = tf.gradients(
ys=[r0, r0_1, r1], xs=[value], grad_ys=
[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]])
ys=[r0, r1], xs=[value], grad_ys=
[[2.0, 3.0], [4.0, 5.0]])
grad_vals = sess.run(grad)
self.assertEqual(len(grad_vals), 1)
self.assertAllClose([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0])
self.assertAllClose([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
def testTensorArrayGradientUnpackRead(self):
self._testTensorArrayGradientUnpackRead(False)
@ -454,7 +398,8 @@ class TensorArrayTest(tf.test.TestCase):
w1 = w0.write(1, [3.0])
w1.close().run() # Expected to run without problems
with self.assertRaisesOpError(r"Tensor foo has already been closed."):
with self.assertRaisesOpError(
r"TensorArray foo has already been closed."):
with tf.control_dependencies([w1.close()]):
w1.write(2, 3.0).flow.eval()
@ -462,6 +407,108 @@ class TensorArrayTest(tf.test.TestCase):
self._testWriteCloseTensorArray(use_gpu=False)
self._testWriteCloseTensorArray(use_gpu=True)
def _testWhileLoopWritePackGradients(self, use_gpu):
with self.test_session(use_gpu=use_gpu) as sess:
v0 = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5))
var = tf.Variable(np.arange(100, 105, dtype=np.float32))
state0 = tf.identity([1.0] * 5)
h = data_flow_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
time_0 = tf.identity(0)
def body(time, flow, state):
sliced = tf.slice(v0, begin=tf.pack([time, 0]), size=[1, -1])
sliced = tf.squeeze(sliced)
out = sliced + var + state
state += sliced
h_n = h
h_n._flow = flow
h_n = h_n.write(time, out)
return (time+1, h_n.flow, state)
(unused_0, final_flow, unused_2) = control_flow_ops.While(
cond=lambda time, unused_1, unused_2: time < 3,
body=body,
loop_vars=(time_0, h.flow, state0),
parallel_iterations=3)
h._flow = final_flow
vout = h.pack()
grad_val = -np.arange(3*5, dtype=np.float32).reshape(3, 5)
v0_grad = tf.gradients([vout], [v0], [grad_val])[0]
state0_grad = tf.gradients([vout], [state0], [grad_val])[0]
var_grad = tf.gradients([vout], [var], [grad_val])[0]
tf.initialize_all_variables().run()
state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
sess.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad]))
just_v0_grad_t, = sess.run([v0_grad])
# state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
# vout = [ v0[0] + var + state[0] |
# v0[1] + var + state[1] |
# v0[2] + var + state[2] ]
# = [ v0[0] + var + state0 |
# v0[1] + var + state0 + v0[0] |
# v0[2] + var + state0 + v0[0] + v0[1] ]
#
# d(vout[0])/d(v0) = [1 | 0 | 0 ]
# d(vout[1])/d(v0) = [1 | 1 | 0 ]
# d(vout[2])/d(v0) = [1 | 1 | 1 ]
# d(vout)/d(var) = [1 | 1 | 1]
# d(vout)/d(state0) = [ 1 | 1 | 1 ]
state_per_time = np.array([
state0_t,
state0_t + v0_t[0, :],
state0_t + v0_t[0, :] + v0_t[1, :]])
# Compare forward prop
self.assertAllClose(v0_t + var_t + state_per_time, vout_t)
# Compare backward prop
expected_v0_grad_t = np.array([
grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
grad_val[1, :] + grad_val[2, :],
grad_val[2, :]])
self.assertAllEqual(expected_v0_grad_t, v0_grad_t)
self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t)
self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)
def testWhileLoopWritePackGradients(self):
self._testWhileLoopWritePackGradients(use_gpu=False)
self._testWhileLoopWritePackGradients(use_gpu=True)
def _testSumOfTwoReadVariablesWithoutRepeatGrad(self, use_gpu):
with self.test_session(use_gpu=use_gpu) as sess:
a = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1)
b = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1 + 3*5)
ta = data_flow_ops.TensorArray(dtype=tf.float32, size=2)
ta = ta.write(0, a, name="write_a")
ta = ta.write(1, b, name="write_b")
c = (ta.read(0, name="read_a_0") + # a + b
ta.read(1, name="read_b_0"))
g0 = -(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1)
grad_a = tf.gradients([c], [a], [g0])[0] # d(a+b)/da = 1
grad_b = tf.gradients([c], [b], [g0])[0] # d(a+b)/db = 1
# Test gradients calculated individually
grad_a_t, = sess.run([grad_a])
self.assertAllEqual(grad_a_t, g0)
grad_b_t, = sess.run([grad_b])
self.assertAllEqual(grad_b_t, g0)
# Test gradients calculated jointly
joint_grad_a_t, joint_grad_b_t = sess.run([grad_a, grad_b])
self.assertAllEqual(joint_grad_a_t, g0)
self.assertAllEqual(joint_grad_b_t, g0)
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
self._testSumOfTwoReadVariablesWithoutRepeatGrad(use_gpu=False)
self._testSumOfTwoReadVariablesWithoutRepeatGrad(use_gpu=True)
if __name__ == "__main__":
tf.test.main()

View File

@ -81,24 +81,82 @@ ops.NoGradient("TensorArrayGrad")
ops.NoGradient("TensorArrayClose")
def _GetGradSource(op_or_tensor):
"""Identify which call to tf.gradients created this gradient op or tensor.
TensorArray gradient calls use an accumulator TensorArray object. If
multiple gradients are calculated and run in the same session, the multiple
gradient nodes may accidentally flow throuth the same accumulator TensorArray.
This double counting breaks the TensorArray gradient flow.
The solution is to identify which gradient call this particular
TensorArray*Grad is being called in, by looking at the input gradient
tensor's name, and create or lookup an accumulator gradient TensorArray
associated with this specific call. This solves any confusion and ensures
different gradients from the same forward graph get their own accumulators.
This function creates the unique label associated with the tf.gradients call
that is used to create the gradient TensorArray.
Args:
op_or_tensor: `Tensor` or `Operation` which is an input to a
TensorArray*Grad call.
Returns:
A python string, the unique label associated with this particular
gradients calculation.
Raises:
ValueError: If not called within a gradients calculation.
"""
if not op_or_tensor.name.startswith("gradients"):
raise ValueError(
"Expected op/tensor name to start with gradients, got: %s"
% op_or_tensor.name)
return op_or_tensor.name.split("/")[0]
@ops.RegisterGradient("TensorArrayRead")
def _TensorArrayReadGrad(op, grad):
"""Gradient for TensorArrayRead.
Args:
op: Forward TensorArrayRead op.
grad: Gradient `Tensor` to TensorArrayRead.
Returns:
A flow `Tensor`, which can be used in control dependencies to
force the write of `grad` to the gradient `TensorArray`.
"""
handle = op.inputs[0]
index = op.inputs[1]
dtype = op.get_attr("dtype")
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad()
grad_source = _GetGradSource(grad)
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad(
source=grad_source)
w_g = g.write(index, grad)
return [None, None, w_g.flow]
@ops.RegisterGradient("TensorArrayWrite")
def _TensorArrayWriteGrad(op, flow):
"""Gradient for TensorArrayWrite.
Args:
op: Forward TensorArrayWrite op.
flow: Gradient `Tensor` flow to TensorArrayWrite.
Returns:
A grad `Tensor`, the gradient created in an upstream ReadGrad or PackGrad.
"""
# handle is the output store_handle of TensorArrayReadGrad or
# the handle output of TensorArrayWriteGrad. we must use this one.
handle = op.inputs[0]
index = op.inputs[1]
dtype = op.get_attr("T")
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad()
grad_source = _GetGradSource(flow)
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad(
source=grad_source)
with ops.control_dependencies([flow]):
grad = g.read(index)
return [None, None, grad, flow]
@ -106,20 +164,41 @@ def _TensorArrayWriteGrad(op, flow):
@ops.RegisterGradient("TensorArrayPack")
def _TensorArrayPackGrad(op, grad):
"""Gradient for TensorArrayPack.
Args:
op: Forward TensorArrayPack op.
grad: Gradient `Tensor` to TensorArrayPack.
Returns:
A flow `Tensor`, which can be used in control dependencies to
force the write of `grad` to the gradient `TensorArray`.
"""
handle = op.inputs[0]
dtype = op.get_attr("dtype")
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad()
grad_source = _GetGradSource(grad)
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad(
source=grad_source)
u_g = g.unpack(grad)
return [None, u_g.flow]
@ops.RegisterGradient("TensorArrayUnpack")
def _TensorArrayUnpackGrad(op, flow):
# handle is the output store_handle of TensorArrayReadGrad or
# the handle output of TensorArrayUnpackGrad. we must use this one.
"""Gradient for TensorArrayUnpack.
Args:
op: Forward TensorArrayUnpack op.
flow: Gradient `Tensor` flow to TensorArrayUnpack.
Returns:
A grad `Tensor`, the gradient created in upstream ReadGrads or PackGrad.
"""
handle = op.inputs[0]
dtype = op.get_attr("T")
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad()
grad_source = _GetGradSource(flow)
g = data_flow_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad(
source=grad_source)
with ops.control_dependencies([flow]):
grad = g.pack()
return [None, grad, flow]

View File

@ -557,12 +557,13 @@ class TensorArray(object):
"""
def __init__(
self, dtype, size, tensor_array_name=None, handle=None, name=None):
self, dtype, size=None, tensor_array_name=None, handle=None, name=None):
"""Construct a new TensorArray or wrap an existing TensorArray handle.
Args:
dtype: (required) data type of the TensorArray.
size: (required) int32 scalar `Tensor`: the size of the TensorArray.
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
Required if handle is not provided.
tensor_array_name: (optional) Python string: the name of the TensorArray.
This is used when creating the TensorArray handle. If this value is
set, handle should be None.
@ -572,10 +573,15 @@ class TensorArray(object):
Raises:
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
if handle and tensor_array_name:
raise ValueError(
"Cannot construct with both handle and tensor_array_name")
if handle and not isinstance(handle, ops.Tensor):
raise TypeError("Handle must be a Tensor")
if handle is None and size is None:
raise ValueError("Size must be provided if handle is not provided")
with ops.op_scope([handle, size], name, "TensorArray") as scope:
if handle:
@ -587,7 +593,6 @@ class TensorArray(object):
self._flow = constant_op.constant(0, dtype=_dtypes.float32)
self._dtype = dtype
self._gradient_add = False
@property
def flow(self):
@ -599,50 +604,50 @@ class TensorArray(object):
"""The reference to the TensorArray."""
return self._handle
def grad(self):
g = TensorArray(
dtype=self._dtype,
size=-1,
handle=gen_data_flow_ops._tensor_array_grad(self._handle))
g._gradient_add = True
def grad(self, source):
g_handle = gen_data_flow_ops._tensor_array_grad(
handle=self._handle, source=source)
g = TensorArray(dtype=self._dtype, size=None, handle=g_handle)
return g
def read(self, index):
def read(self, index, name=None):
"""Read the value at location `index` in the TensorArray."""
value = gen_data_flow_ops._tensor_array_read(
handle=self._handle, index=index, flow_in=self._flow, dtype=self._dtype)
handle=self._handle, index=index, flow_in=self._flow, dtype=self._dtype,
name=name)
return value
def write(self, index, value):
def write(self, index, value, name=None):
"""Write `value` into index `index` of the TensorArray."""
flow_out = gen_data_flow_ops._tensor_array_write(
handle=self._handle, index=index, value=value, flow_in=self._flow,
gradient_add=self._gradient_add)
name=name)
# Size below is ignored
ta = TensorArray(dtype=self._dtype, size=-1, handle=self._handle)
ta._gradient_add = self._gradient_add
ta._flow = flow_out
return ta
def pack(self):
def pack(self, name=None):
"""Return the values in the TensorArray as a packed `Tensor`."""
value = gen_data_flow_ops._tensor_array_pack(
handle=self._handle, flow_in=self._flow, dtype=self._dtype)
handle=self._handle, flow_in=self._flow, dtype=self._dtype,
name=name)
return value
def unpack(self, value):
def unpack(self, value, name=None):
"""Packs the values of a `Tensor` in the TensorArray."""
flow_out = gen_data_flow_ops._tensor_array_unpack(
handle=self._handle, value=value, flow_in=self._flow,
gradient_add=self._gradient_add)
name=name)
ta = TensorArray(dtype=self._dtype, size=-1, handle=self._handle)
ta._gradient_add = self._gradient_add
ta._flow = flow_out
return ta
def close(self):
def close(self, name=None):
"""Close the current TensorArray."""
return gen_data_flow_ops._tensor_array_close(handle=self._handle)
return gen_data_flow_ops._tensor_array_close(
handle=self._handle, name=name)
ops.NoGradient("LookupTableFind")