Make TensorList objects Refcounted.

This drastically reduces the amount of refcounting of individual tensors inside
a TensorList when a TensorList variant is copied to a Variable or
MutableDenseHashTable (and back).  The same applies to operations like tf.stack
that operate on Variant tensors and perform implicit Variant copies.

While this change adds a level of indirection to the TensorList object
by storing the tensor vector in a heap-allocated RefCounted object,
it also shrinks TensorList below the tf::Variant inlining
threshold.  This in turn removes a level of heap indirection and
should cancel out any performance regressions for existing
TensorList operations and small lists.

PiperOrigin-RevId: 259464769
Author: Eugene Brevdo, 2019-07-22 21:07:01 -07:00 (committed by TensorFlower Gardener)
parent 4f910ac64b
commit 91425cf597
3 changed files with 305 additions and 149 deletions
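
Before the diffs, here is a minimal standalone sketch of the copy semantics this change introduces. It is not TensorFlow code: the FakeTensor, RefCountedVec, and MiniList names are hypothetical stand-ins for tf::Tensor, core::RefCounted, and TensorList. Copying the list only bumps a reference count on the shared container, Copy() duplicates the vector itself, and RefCountIsOne() is the check the new kernels use before reusing storage in place.

#include <cassert>
#include <string>
#include <vector>

// Toy stand-ins; names are hypothetical, not TensorFlow types.
struct FakeTensor { std::string name; };

class RefCountedVec {
 public:
  std::vector<FakeTensor> values;
  void Ref() { ++refs_; }
  void Unref() { if (--refs_ == 0) delete this; }
  bool RefCountIsOne() const { return refs_ == 1; }
 private:
  int refs_ = 1;  // Starts owned by its creator.
};

// Mirrors the new TensorList layout: metadata by value, tensors behind one
// refcounted pointer, so copying the list never touches individual tensors.
class MiniList {
 public:
  MiniList() : tensors_(new RefCountedVec) {}
  MiniList(const MiniList& o) : tensors_(o.tensors_) { tensors_->Ref(); }
  MiniList& operator=(const MiniList& o) {
    if (this != &o) { tensors_->Unref(); tensors_ = o.tensors_; tensors_->Ref(); }
    return *this;
  }
  ~MiniList() { if (tensors_) tensors_->Unref(); }

  std::vector<FakeTensor>& tensors() { return tensors_->values; }
  bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }

  // Like TensorList::Copy(): duplicates the vector, not the tensors' buffers.
  MiniList Copy() const {
    MiniList out;
    out.tensors_->values = tensors_->values;
    return out;
  }

 private:
  RefCountedVec* tensors_;
};

int main() {
  MiniList a;
  a.tensors().push_back({"t0"});

  MiniList shared = a;               // O(1): bumps one refcount, shares storage.
  shared.tensors().push_back({"t1"});
  assert(a.tensors().size() == 2);   // a sees the push: storage is shared.
  assert(!a.RefCountIsOne());        // So in-place reuse would be unsafe here.

  MiniList copy = a.Copy();          // Copies the vector; still shallow per tensor.
  copy.tensors().push_back({"t2"});
  assert(a.tensors().size() == 2);   // a is unaffected by the copy's push.
  return 0;
}

The sketch is why the commit only needs one refcount bump per list copy: the per-element Tensor refcounts are only touched when the vector itself is duplicated via Copy().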

View File

@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <limits>
 #include "tensorflow/core/framework/allocator.h"
 #define EIGEN_USE_THREADS
@@ -21,8 +22,6 @@ limitations under the License.
 #define EIGEN_USE_GPU
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "tensorflow/core/kernels/list_kernels.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -30,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/list_kernels.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/util.h"
@@ -38,20 +38,16 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
-// Variant compatible type for a list of tensors. This is mutable but instances
-// should never be mutated after stored in a variant tensor.
-TensorList::TensorList(const TensorList& other)
-    : tensors(other.tensors),
-      element_shape(other.element_shape),
-      element_dtype(other.element_dtype),
-      max_num_elements(other.max_num_elements) {}
+TensorList::~TensorList() {
+  if (tensors_) tensors_->Unref();
+}
 void TensorList::Encode(VariantTensorData* data) const {
   data->set_type_name(TypeName());
   std::vector<size_t> invalid_indices;
-  for (size_t i = 0; i < tensors.size(); i++) {
-    if (tensors.at(i).dtype() != DT_INVALID) {
-      *data->add_tensors() = tensors.at(i);
+  for (size_t i = 0; i < tensors().size(); i++) {
+    if (tensors().at(i).dtype() != DT_INVALID) {
+      *data->add_tensors() = tensors().at(i);
     } else {
       invalid_indices.push_back(i);
     }
@@ -78,11 +74,11 @@ static Status TensorListDeviceCopy(
   to->element_shape = from.element_shape;
   to->element_dtype = from.element_dtype;
   to->max_num_elements = from.max_num_elements;
-  to->tensors.reserve(from.tensors.size());
-  for (const Tensor& t : from.tensors) {
-    to->tensors.emplace_back(t.dtype());
+  to->tensors().reserve(from.tensors().size());
+  for (const Tensor& t : from.tensors()) {
+    to->tensors().emplace_back(t.dtype());
     if (t.dtype() != DT_INVALID) {
-      TF_RETURN_IF_ERROR(copy(t, &to->tensors.back()));
+      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
     }
   }
   return Status::OK();
@@ -116,16 +112,16 @@ bool TensorList::Decode(const VariantTensorData& data) {
   }
   size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
-  tensors.reserve(total_num_tensors);
+  tensors().reserve(total_num_tensors);
   std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
   std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
   for (size_t i = 0; i < total_num_tensors; i++) {
     if (invalid_indices_it != invalid_indices.end() &&
         *invalid_indices_it == i) {
-      tensors.emplace_back(Tensor(DT_INVALID));
+      tensors().emplace_back(Tensor(DT_INVALID));
       invalid_indices_it++;
     } else if (tensors_it != data.tensors().end()) {
-      tensors.emplace_back(*tensors_it);
+      tensors().emplace_back(*tensors_it);
       tensors_it++;
     } else {
       // VariantTensorData is corrupted.
@@ -201,19 +197,31 @@ Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
       input_index, output_index, DT_VARIANT, TensorShape{},
       c->input_memory_type(input_index), AllocatorAttributes());
   Tensor* output_tensor;
-  if (maybe_output != nullptr) {
-    // Woohoo, forwarding succeeded!
+  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
+      maybe_output->NumElements() == 1) {
     output_tensor = maybe_output.get();
-    c->set_output(output_index, *output_tensor);
-  } else {
-    // If forwarding is not possible allocate a new output tensor and copy
-    // the `input_list` to it.
-    AllocatorAttributes attr;
-    attr.set_on_host(true);
-    TF_RETURN_IF_ERROR(
-        c->allocate_output(output_index, {}, &output_tensor, attr));
-    output_tensor->scalar<Variant>()() = input_list;
+    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
+    if (tmp_out == nullptr) {
+      return errors::InvalidArgument(
+          "Expected input ", input_index, " to be a TensorList but saw ",
+          output_tensor->scalar<Variant>()().TypeName());
+    }
+    if (tmp_out->RefCountIsOne()) {
+      // Woohoo, forwarding succeeded!
+      c->set_output(output_index, *output_tensor);
+      *output_list = tmp_out;
+      return Status::OK();
+    }
   }
+  // If forwarding is not possible allocate a new output tensor and copy
+  // the `input_list` to it.
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
+  TF_RETURN_IF_ERROR(
+      c->allocate_output(output_index, {}, &output_tensor, attr));
+  output_tensor->scalar<Variant>()() = input_list.Copy();
   *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
   return Status::OK();
 }
@@ -295,15 +303,15 @@ class TensorListPushBack : public OpKernel {
     if (l->max_num_elements != -1) {
       OP_REQUIRES(
-          c, l->tensors.size() < l->max_num_elements,
+          c, l->tensors().size() < l->max_num_elements,
           errors::InvalidArgument("Tried to push item into a full list",
-                                  " list size: ", l->tensors.size(),
+                                  " list size: ", l->tensors().size(),
                                   " max_num_elements: ", l->max_num_elements));
     }
     TensorList* output_list = nullptr;
     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
-    output_list->tensors.push_back(input);
+    output_list->tensors().push_back(input);
   }
  private:
@@ -330,7 +338,7 @@ class TensorListLength : public OpKernel {
     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
     Tensor* result;
     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<int32>()() = l->tensors.size();
+    result->scalar<int32>()() = l->tensors().size();
   }
 };
@@ -399,7 +407,7 @@ class TensorListReserve : public OpKernel {
     TensorList output;
     output.element_shape = element_shape;
     output.element_dtype = element_dtype_;
-    output.tensors.resize(num_elements, Tensor(DT_INVALID));
+    output.tensors().resize(num_elements, Tensor(DT_INVALID));
     Tensor* result;
     AllocatorAttributes attr;
     attr.set_on_host(true);
@@ -440,32 +448,37 @@ class TensorListResize : public OpKernel {
         c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                          c->input_memory_type(0), AllocatorAttributes());
     if (maybe_result != nullptr) {
-      maybe_result->scalar<Variant>()().get<TensorList>()->tensors.resize(
-          size, Tensor(DT_INVALID));
-      c->set_output(0, *maybe_result);
-    } else {
-      Tensor* result;
-      AllocatorAttributes attr;
-      attr.set_on_host(true);
-      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
-      TensorList output_list;
-      output_list.element_shape = input_list->element_shape;
-      output_list.element_dtype = input_list->element_dtype;
-      output_list.max_num_elements = input_list->max_num_elements;
-      if (size > input_list->tensors.size()) {
-        output_list.tensors.insert(output_list.tensors.begin(),
-                                   input_list->tensors.begin(),
-                                   input_list->tensors.end());
-        // Add DT_INVALID tensors to the end of the list if the requested size
-        // is larger than the list length.
-        output_list.tensors.resize(size, Tensor(DT_INVALID));
-      } else {
-        output_list.tensors.insert(output_list.tensors.begin(),
-                                   input_list->tensors.begin(),
-                                   input_list->tensors.begin() + size);
+      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
+      if (out->RefCountIsOne()) {
+        // We are able to forward the input.
+        out->tensors().resize(size, Tensor(DT_INVALID));
+        c->set_output(0, *maybe_result);
+        return;
       }
-      result->scalar<Variant>()() = std::move(output_list);
     }
+    // We were not able to forward the input. Will have to resize from scratch.
+    Tensor* result;
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+    TensorList output_list;
+    output_list.element_shape = input_list->element_shape;
+    output_list.element_dtype = input_list->element_dtype;
+    output_list.max_num_elements = input_list->max_num_elements;
+    if (size > input_list->tensors().size()) {
+      output_list.tensors().insert(output_list.tensors().begin(),
+                                   input_list->tensors().begin(),
+                                   input_list->tensors().end());
+      // Add DT_INVALID tensors to the end of the list if the requested size
+      // is larger than the list length.
+      output_list.tensors().resize(size, Tensor(DT_INVALID));
+    } else {
+      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
+    }
+    result->scalar<Variant>()() = std::move(output_list);
   }
 };
@@ -495,9 +508,9 @@ class TensorListSetItem : public OpKernel {
                     " but list elements ",
                     DataTypeString(l->element_dtype)));
     int32 index = c->input(1).scalar<int32>()();
-    OP_REQUIRES(c, index < l->tensors.size(),
+    OP_REQUIRES(c, index < l->tensors().size(),
                 errors::InvalidArgument("Trying to modify element ", index,
-                                        " in a list with ", l->tensors.size(),
+                                        " in a list with ", l->tensors().size(),
                                         " elements."));
     const Tensor& value = c->input(2);
     OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
@@ -508,7 +521,7 @@ class TensorListSetItem : public OpKernel {
                     " list shape: ", l->element_shape.DebugString()));
     TensorList* output_list = nullptr;
     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
-    output_list->tensors[index] = value;
+    output_list->tensors()[index] = value;
   }
  private:
@@ -560,11 +573,26 @@ class TensorListConcatLists : public OpKernel {
     const Tensor& tl_a = c->input(0);
     const Tensor& tl_b = c->input(1);
-    Tensor* output;
-    if (tl_alias) {
-      c->set_output(0, *tl_alias);
-      output = tl_alias.get();
-    } else {
+    Tensor* output = nullptr;
+    bool ok_to_alias = tl_alias != nullptr;
+    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
+        tl_alias->NumElements() > 0) {
+      auto tl_a_t = tl_alias->flat<Variant>();
+      for (int64 i = 0; i < tl_alias->NumElements(); ++i) {
+        TensorList* aliased = tl_a_t(i).get<TensorList>();
+        if (aliased == nullptr || !aliased->RefCountIsOne()) {
+          ok_to_alias = false;
+          break;
+        }
+      }
+      if (ok_to_alias) {
+        c->set_output(0, *tl_alias);
+        output = tl_alias.get();
+      }
+    }
+    if (!ok_to_alias) {
+      // Couldn't alias the entire Tensor. We'll be conservative and not try
+      // to alias individual batch entries.
       attr.set_on_host(true);
       OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
     }
@@ -573,45 +601,42 @@ class TensorListConcatLists : public OpKernel {
     auto tl_a_t = tl_a.flat<Variant>();
     auto tl_b_t = tl_b.flat<Variant>();
-    for (int64 b = 0; b < tl_a.NumElements(); ++b) {
-      const TensorList* l_a = tl_a_t(b).get<TensorList>();
-      const TensorList* l_b = tl_b_t(b).get<TensorList>();
+    for (int64 i = 0; i < tl_a.NumElements(); ++i) {
+      const TensorList* l_a = tl_a_t(i).get<TensorList>();
+      const TensorList* l_b = tl_b_t(i).get<TensorList>();
       OP_REQUIRES(
           c, l_a != nullptr,
-          errors::InvalidArgument("input_a is not a TensorList at index ", b,
-                                  ". Saw: '", tl_a_t(b).DebugString(), "'"));
+          errors::InvalidArgument("input_a is not a TensorList at index ", i,
+                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
       OP_REQUIRES(
           c, l_b != nullptr,
-          errors::InvalidArgument("input_b is not a TensorList at index ", b,
-                                  ". Saw: '", tl_b_t(b).DebugString(), "'"));
+          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
-                     "input_a[", b, "].dtype != element_dtype. Saw: ",
+                     "input_a[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
-                     "input_b[", b, "].dtype != element_dtype. Saw: ",
+                     "input_b[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
-                     b, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
+                     i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
-      if (tl_alias) {
-        TensorList* out = output_t(b).get<TensorList>();
-        DCHECK(out != nullptr) << "Expected output to alias input_a, but it "
-                                  "doesn't contain a TensorList at index "
-                               << b;
-        std::copy(l_b->tensors.begin(), l_b->tensors.end(),
-                  std::back_inserter(out->tensors));
+      if (ok_to_alias) {
+        TensorList* out = output_t(i).get<TensorList>();
+        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
+                  std::back_inserter(out->tensors()));
       } else {
-        TensorList out = *l_a;
-        std::copy(l_b->tensors.begin(), l_b->tensors.end(),
-                  std::back_inserter(out.tensors));
-        output_t(b) = std::move(out);
+        TensorList out = l_a->Copy();
+        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
+        output_t(i) = std::move(out);
       }
     }
   }

View File

@@ -31,7 +31,9 @@ limitations under the License.
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/util/tensor_ops_util.h"
 #include "tensorflow/core/util/util.h"
@@ -41,12 +43,85 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 // Variant compatible type for a list of tensors. This is mutable but instances
 // should never be mutated after stored in a variant tensor.
-struct TensorList {
+//
+// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
+// which are accessible via TensorList::tensors(). Because it is refcounted,
+// straight copies of the form:
+//
+//    TensorList b = a;
+//    b.tensors().push_back(t);  // WARNING: This modifies a.tensors().
+//
+// Do not create a true copy of the underlying container - but instead increment
+// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
+// TensorList should be considered similar to the tf::Tensor object.
+//
+// In order to get a copy of the underlying list, use the Copy method:
+//
+//    TensorList b = a.Copy();
+//    b.tensors().push_back(t);  // This does not modify a.tensors().
+//
+// Note that this is not a deep copy: the memory locations of the underlying
+// tensors will still point to the same locations of the corresponding tensors
+// in the original. To truly perform a deep copy, Device and Type-specific
+// code needs to be applied to the underlying tensors as usual.
+//
+// The most important implication of RefCounted TLs is that OpKernels
+// wishing to reuse TensorList inputs as outputs via context->forward_input()
+// need to perform an additional check on the refcount of the TensorList,
+// to ensure aliasing can be performed safely. For example:
+//
+//    bool can_alias = false;
+//    auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
+//    if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
+//      auto* tl = fw->scalar<Variant>()().get<TensorList>();
+//      if (tl && tl->RefCountIsOne()) {
+//        can_alias = true;
+//      }
+//    }
+//
+class TensorList {
  public:
-  TensorList() {}
-  TensorList(const TensorList& other);
+  TensorList() : tensors_(new Tensors) {}
+  ~TensorList();
+
+  TensorList(const TensorList& other)
+      : element_shape(other.element_shape),
+        element_dtype(other.element_dtype),
+        max_num_elements(other.max_num_elements),
+        tensors_(other.tensors_) {
+    tensors_->Ref();
+  }
+
+  TensorList(TensorList&& rhs)
+      : element_shape(std::move(rhs.element_shape)),
+        element_dtype(rhs.element_dtype),
+        max_num_elements(rhs.max_num_elements),
+        tensors_(rhs.tensors_) {
+    rhs.tensors_ = nullptr;
+  }
+
+  TensorList& operator=(const TensorList& rhs) {
+    if (this == &rhs) return *this;
+    element_shape = rhs.element_shape;
+    element_dtype = rhs.element_dtype;
+    max_num_elements = rhs.max_num_elements;
+    tensors_->Unref();
+    tensors_ = rhs.tensors_;
+    tensors_->Ref();
+    return *this;
+  }
+
+  TensorList& operator=(TensorList&& rhs) {
+    if (this == &rhs) return *this;
+    element_shape = rhs.element_shape;
+    element_dtype = rhs.element_dtype;
+    max_num_elements = rhs.max_num_elements;
+    std::swap(tensors_, rhs.tensors_);
+    return *this;
+  }
 
   static const char kTypeName[];
   string TypeName() const { return kTypeName; }
   void Encode(VariantTensorData* data) const;
@@ -56,14 +131,47 @@ struct TensorList {
   // TODO(apassos) fill this out
   string DebugString() const { return "TensorList"; }
-  std::vector<Tensor> tensors;
   PartialTensorShape element_shape;
   DataType element_dtype;
   // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
   // of `tensors` is unbounded.
   int max_num_elements = -1;
+
+  // Access to the underlying tensor container.
+  std::vector<Tensor>& tensors() { return tensors_->values_; }
+  const std::vector<Tensor>& tensors() const { return tensors_->values_; }
+
+  // Get a new TensorList containing a copy of the underlying tensor container.
+  TensorList Copy() const {
+    TensorList out;
+    out.element_shape = element_shape;
+    out.element_dtype = element_dtype;
+    out.max_num_elements = max_num_elements;
+    // This performs a copy of the std::vector.
+    out.tensors_->values_ = tensors_->values_;
+    return out;
+  }
+
+  // Is this TensorList the only one with a reference to the underlying
+  // container?
+  bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
+
+ private:
+  class Tensors : public core::RefCounted {
+   public:
+    std::vector<Tensor> values_;
+  };
+  Tensors* tensors_;
 };
+
+#if defined(PLATFORM_GOOGLE)
+// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
+static_assert(Variant::CanInlineType<TensorList>(),
+              "Must be able to inline TensorList into a Variant");
+#endif
 
 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
 Status GetElementShapeFromInput(OpKernelContext* c,
@@ -96,18 +204,19 @@ class TensorListStack : public OpKernel {
             "Invalid data types; op elements ", DataTypeString(element_dtype_),
             " but list elements ", DataTypeString(tensor_list->element_dtype)));
     if (num_elements_ != -1) {
-      OP_REQUIRES(c, tensor_list->tensors.size() == num_elements_,
+      OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
                   errors::InvalidArgument(
                       "Operation expected a list with ", num_elements_,
                       " elements but got a list with ",
-                      tensor_list->tensors.size(), " elements."));
+                      tensor_list->tensors().size(), " elements."));
     }
     PartialTensorShape partial_element_shape;
     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
                                                &partial_element_shape));
     OP_REQUIRES(
         c,
-        partial_element_shape.IsFullyDefined() || !tensor_list->tensors.empty(),
+        partial_element_shape.IsFullyDefined() ||
+            !tensor_list->tensors().empty(),
         errors::InvalidArgument("Tried to stack elements of an empty ",
                                 "list with non-fully-defined element_shape: ",
                                 partial_element_shape.DebugString()));
@@ -115,8 +224,8 @@ class TensorListStack : public OpKernel {
     // Check that `element_shape` input tensor is compatible with the shapes of
     // element tensors.
     if (!tensor_list->element_shape.IsFullyDefined()) {
-      for (int i = 0; i < tensor_list->tensors.size(); ++i) {
-        const Tensor& t = tensor_list->tensors[i];
+      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
+        const Tensor& t = tensor_list->tensors()[i];
         if (t.dtype() != DT_INVALID) {
           PartialTensorShape tmp = partial_element_shape;
           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
@@ -133,7 +242,7 @@ class TensorListStack : public OpKernel {
             "tensors and has a non-fully-defined element_shape: ",
             partial_element_shape.DebugString()));
     TensorShape output_shape = element_shape;
-    output_shape.InsertDim(0, tensor_list->tensors.size());
+    output_shape.InsertDim(0, tensor_list->tensors().size());
     Tensor* output;
     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
     if (output->NumElements() == 0) {
@@ -141,9 +250,9 @@ class TensorListStack : public OpKernel {
     }
     ConstMatrixVector inputs_flat;
-    inputs_flat.reserve(tensor_list->tensors.size());
+    inputs_flat.reserve(tensor_list->tensors().size());
     Tensor zeros;
-    for (const auto& t : tensor_list->tensors) {
+    for (const auto& t : tensor_list->tensors()) {
       if (t.dtype() != DT_INVALID) {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             t.shaped<T, 2>({1, t.NumElements()})));
@@ -195,12 +304,12 @@ class TensorListGetItem : public OpKernel {
                     " but list elements ",
                     DataTypeString(l->element_dtype)));
     int32 index = c->input(1).scalar<int32>()();
-    OP_REQUIRES(c, index < l->tensors.size(),
+    OP_REQUIRES(c, index < l->tensors().size(),
                 errors::InvalidArgument("Trying to access element ", index,
-                                        " in a list with ", l->tensors.size(),
+                                        " in a list with ", l->tensors().size(),
                                         " elements."));
-    if (l->tensors[index].dtype() != DT_INVALID) {
-      c->set_output(0, l->tensors[index]);
+    if (l->tensors()[index].dtype() != DT_INVALID) {
+      c->set_output(0, l->tensors()[index]);
     } else {
       PartialTensorShape partial_element_shape;
       OP_REQUIRES_OK(
@@ -216,7 +325,7 @@ class TensorListGetItem : public OpKernel {
       // In that mode TensorArray sets the array's element_shape on the first
       // write call. We could do something similar here if needed.
       if (!partial_element_shape.IsFullyDefined()) {
-        for (const Tensor& t : l->tensors) {
+        for (const Tensor& t : l->tensors()) {
           if (t.dtype() != DT_INVALID) {
            PartialTensorShape tmp = partial_element_shape;
            OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
@@ -260,10 +369,10 @@ class TensorListPopBack : public OpKernel {
                     " but list elements ",
                     DataTypeString(l->element_dtype)));
-    OP_REQUIRES(c, !l->tensors.empty(),
+    OP_REQUIRES(c, !l->tensors().empty(),
                 errors::InvalidArgument("Trying to pop from an empty list."));
-    const Tensor& t = l->tensors.back();
+    const Tensor& t = l->tensors().back();
     if (t.dtype() != DT_INVALID) {
       c->set_output(1, t);
     } else {
@@ -288,7 +397,7 @@ class TensorListPopBack : public OpKernel {
     TensorList* output_list = nullptr;
     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
-    output_list->tensors.pop_back();
+    output_list->tensors().pop_back();
   }
  private:
@@ -347,7 +456,7 @@ class TensorListConcat : public OpKernel {
     // If the TensorList is empty, element_shape_except_first_dim_ must be fully
     // defined.
     OP_REQUIRES(c,
-                !tensor_list->tensors.empty() ||
+                !tensor_list->tensors().empty() ||
                     element_shape_except_first_dim_.IsFullyDefined(),
                 errors::InvalidArgument(
                     "All except the first dimension must be fully defined ",
@@ -364,8 +473,8 @@ class TensorListConcat : public OpKernel {
     if (!tensor_list->element_shape.IsFullyDefined()) {
       bool check_dim = (first_dim == -1);
       int64 inferred_first_dim = first_dim;
-      for (int i = 0; i < tensor_list->tensors.size(); ++i) {
-        const Tensor& t = tensor_list->tensors[i];
+      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
+        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = element_shape_except_first_dim_;
          OP_REQUIRES(
@@ -407,14 +516,14 @@ class TensorListConcat : public OpKernel {
       OP_REQUIRES_OK(
           c,
           c->allocate_output(
-              1, TensorShape({static_cast<int64>(tensor_list->tensors.size())}),
+              1, TensorShape({static_cast<int64>(tensor_list->tensors().size())}),
              &lengths_tensor));
     auto lengths_tensor_vec = lengths_tensor->vec<int64>();
     int64 leading_dim = 0;
-    for (size_t i = 0; i < tensor_list->tensors.size(); i++) {
+    for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
       int64 dim;
-      if (tensor_list->tensors[i].dtype() != DT_INVALID) {
-        dim = tensor_list->tensors[i].shape().dim_size(0);
+      if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
+        dim = tensor_list->tensors()[i].shape().dim_size(0);
       } else {
         // If leading_dims is not provided or does not contain an entry for
         // index i use the inferred `first_dim` if set.
@@ -449,12 +558,12 @@ class TensorListConcat : public OpKernel {
     }
     ConstMatrixVector inputs_flat;
-    inputs_flat.reserve(tensor_list->tensors.size());
+    inputs_flat.reserve(tensor_list->tensors().size());
     // Store the zeros tensors in a vector to prevent them from being GC'ed till
     // concat is complete.
     std::vector<Tensor> zeros_vec;
-    for (int i = 0; i < tensor_list->tensors.size(); i++) {
-      const Tensor& element_tensor = tensor_list->tensors[i];
+    for (int i = 0; i < tensor_list->tensors().size(); i++) {
+      const Tensor& element_tensor = tensor_list->tensors()[i];
       if (element_tensor.dtype() != DT_INVALID) {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
@@ -536,7 +645,7 @@ class TensorListSplit : public OpKernel {
                 errors::InvalidArgument(
                     "Expected lengths to be a vector, received shape: ",
                     lengths.shape().DebugString()));
-    output_list.tensors.reserve(lengths.shape().dim_size(0));
+    output_list.tensors().reserve(lengths.shape().dim_size(0));
     int64 start = 0;
     int64 end = 0;
     for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
@@ -557,7 +666,7 @@ class TensorListSplit : public OpKernel {
       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
       aligned.flat<T>().device(c->eigen_device<Device>()) =
           tmp.unaligned_flat<T>();
-      output_list.tensors.emplace_back(aligned);
+      output_list.tensors().emplace_back(aligned);
     }
     OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
                 errors::InvalidArgument(
@@ -599,7 +708,7 @@ class TensorListGather : public OpKernel {
     if (!tensor_list->element_shape.IsFullyDefined()) {
       for (int index = 0; index < indices.NumElements(); ++index) {
         const int i = indices.flat<int32>()(index);
-        const Tensor& t = tensor_list->tensors[i];
+        const Tensor& t = tensor_list->tensors()[i];
         if (t.dtype() != DT_INVALID) {
           PartialTensorShape tmp = partial_element_shape;
           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
@@ -629,10 +738,10 @@ class TensorListGather : public OpKernel {
     for (int index = 0; index < indices.NumElements(); ++index) {
       const int i = indices.flat<int32>()(index);
       OP_REQUIRES(
-          c, i < tensor_list->tensors.size(),
+          c, i < tensor_list->tensors().size(),
           errors::InvalidArgument("Index ", i, " out o range; list only has ",
-                                  tensor_list->tensors.size(), " elements."));
-      const Tensor& t = tensor_list->tensors[i];
+                                  tensor_list->tensors().size(), " elements."));
+      const Tensor& t = tensor_list->tensors()[i];
       if (t.dtype() != DT_INVALID) {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             t.shaped<T, 2>({1, t.NumElements()})));
@@ -693,7 +802,7 @@ class TensorListFromTensor : public OpKernel {
                     "Specified a list with shape ", element_shape.DebugString(),
                     " from a tensor with shape ", output_shape.DebugString()));
     output_list.element_shape = element_shape;
-    output_list.tensors.reserve(t.shape().dim_size(0));
+    output_list.tensors().reserve(t.shape().dim_size(0));
     for (int i = 0; i < t.shape().dim_size(0); ++i) {
       Tensor tmp = t.Slice(i, i + 1);
       TensorShape tmp_shape = tmp.shape();
@@ -706,7 +815,7 @@ class TensorListFromTensor : public OpKernel {
       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
       aligned.flat<T>().device(c->eigen_device<Device>()) =
           tmp.unaligned_flat<T>();
-      output_list.tensors.push_back(aligned);
+      output_list.tensors().push_back(aligned);
     }
     output_tensor->scalar<Variant>()() = std::move(output_list);
   }
@@ -732,7 +841,7 @@ Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
     // many small ones.
     aligned.flat<T>().device(c->eigen_device<Device>()) =
         tmp.unaligned_flat<T>();
-    std::swap(list->tensors[i], aligned);
+    std::swap(list->tensors()[i], aligned);
   }
   return Status::OK();
 }
@@ -777,8 +886,8 @@ class TensorListScatterIntoExistingList : public OpKernel {
             ? -1
            : *std::max_element(indices_vec.data(),
                                indices_vec.data() + indices.NumElements());
-    if (max_index + 1 > output_list->tensors.size()) {
-      output_list->tensors.resize(max_index + 1);
+    if (max_index + 1 > output_list->tensors().size()) {
+      output_list->tensors().resize(max_index + 1);
     }
     // Scatter the values.
@@ -845,8 +954,8 @@ class TensorListScatter : public OpKernel {
         highest_index = i;
       }
     }
-    output_list.tensors.resize(std::max(highest_index + 1, num_elements),
+    output_list.tensors().resize(std::max(highest_index + 1, num_elements),
                                Tensor(DT_INVALID));
    }
    OP_REQUIRES_OK(c,
@@ -875,19 +984,19 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
   TF_RETURN_IF_ERROR(
       a.element_shape.MergeWith(b.element_shape, &out->element_shape));
-  if (a.tensors.size() != b.tensors.size()) {
+  if (a.tensors().size() != b.tensors().size()) {
     return errors::InvalidArgument(
         "Trying to add two lists of tensors with different lengths. One is ",
-        a.tensors.size(), " and the other is ", b.tensors.size());
+        a.tensors().size(), " and the other is ", b.tensors().size());
   }
-  out->tensors.reserve(a.tensors.size());
-  for (int i = 0; i < a.tensors.size(); ++i) {
-    const Tensor& a_tensor = a.tensors[i];
-    const Tensor& b_tensor = b.tensors[i];
+  out->tensors().reserve(a.tensors().size());
+  for (int i = 0; i < a.tensors().size(); ++i) {
+    const Tensor& a_tensor = a.tensors()[i];
+    const Tensor& b_tensor = b.tensors()[i];
     Tensor out_tensor;
     TF_RETURN_IF_ERROR(
         BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
-    out->tensors.push_back(out_tensor);
+    out->tensors().push_back(out_tensor);
   }
   return Status::OK();
 }
@@ -897,11 +1006,11 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                            TensorList* y) {
   y->element_dtype = x.element_dtype;
   y->element_shape = x.element_shape;
-  y->tensors.reserve(x.tensors.size());
-  for (const Tensor& t : x.tensors) {
+  y->tensors().reserve(x.tensors().size());
+  for (const Tensor& t : x.tensors()) {
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
-    y->tensors.emplace_back(out_tensor);
+    y->tensors().emplace_back(out_tensor);
   }
   return Status::OK();
 }
@@ -936,7 +1045,19 @@ class TensorListPushBackBatch : public OpKernel {
         0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
         DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
-    const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
+    bool ok_to_alias = tls_alias != nullptr;
+    if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
+        tls_alias->NumElements() > 0) {
+      auto alias_t = tls_alias->flat<Variant>();
+      for (int i = 0; i < tls_alias->NumElements(); ++i) {
+        TensorList* tl_i = alias_t(i).get<TensorList>();
+        if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
+          ok_to_alias = false;
+          break;
+        }
+      }
+    }
+    const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);
     OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
                 errors::InvalidArgument(
@@ -979,7 +1100,7 @@ class TensorListPushBackBatch : public OpKernel {
     Tensor* result;
-    if (tls_alias) {
+    if (ok_to_alias) {
       result = tls_alias.get();
       c->set_output(0, *result);
     } else {
@@ -998,8 +1119,8 @@ class TensorListPushBackBatch : public OpKernel {
     auto result_t = result->vec<Variant>();
     for (int64 b = 0; b < batch_size; ++b) {
-      if (!tls_alias) {
-        result_t(b) = *tl_batch[b];
+      if (!ok_to_alias) {
+        result_t(b) = tl_batch[b]->Copy();
       }
       TensorList* output = result_t(b).get<TensorList>();
       DCHECK(output != nullptr);
@@ -1011,7 +1132,7 @@ class TensorListPushBackBatch : public OpKernel {
       auto frame_t = frame->flat<T>();
       frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
       }
-      output->tensors.push_back(std::move(*frame));
+      output->tensors().push_back(std::move(*frame));
     }
   }
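
The static_assert on Variant::CanInlineType<TensorList>() added above ties back to the commit message's point about the inlining threshold. The standalone sketch below (toy types and an arbitrary 32-byte threshold, not TensorFlow's actual values or API) shows how moving the element vector behind a single pointer can shrink a list type enough for a small-object-optimized variant to store it inline instead of on the heap.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy small-object check, loosely in the spirit of tf::Variant's inline
// storage. The 32-byte threshold is arbitrary, not TensorFlow's real value.
constexpr std::size_t kInlineSize = 32;

template <typename T>
constexpr bool CanInline() {
  return sizeof(T) <= kInlineSize && alignof(T) <= alignof(std::max_align_t);
}

// "Before": the element vector is stored by value inside the list object.
struct ListWithInlineVector {
  std::vector<int> tensors;   // typically 24 bytes on a 64-bit ABI
  int64_t metadata[3];        // stand-in for shape / dtype / max_num_elements
};

// "After": the vector lives behind one (refcounted) pointer.
struct ListWithPointer {
  void* tensors;              // 8 bytes
  int64_t metadata[3];
};

int main() {
  std::cout << "by value:   " << sizeof(ListWithInlineVector)
            << " bytes, inline=" << CanInline<ListWithInlineVector>() << "\n";
  std::cout << "by pointer: " << sizeof(ListWithPointer)
            << " bytes, inline=" << CanInline<ListWithPointer>() << "\n";
  // With these toy sizes only the pointer-based layout fits under the
  // threshold, so a variant could store it inline (no extra heap hop).
  static_assert(CanInline<ListWithPointer>(),
                "pointer-based list should fit the toy inline buffer");
  return 0;
}

Under these assumptions, the extra pointer indirection added by the refcounted container is paid back because the variant no longer needs its own heap allocation to hold the list, which is the trade-off the commit message describes.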

View File

@@ -53,7 +53,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         max_num_elements=max_num_elements)
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(self.evaluate(e), 1.0)
+    l = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
+    l, e = self.evaluate((l, e))
+    self.assertAllEqual(l, [])
+    self.assertAllEqual(e, 1.0)
 
   @parameterized.named_parameters(("NoMaxNumElements", None),
                                   ("WithMaxNumElements", 2))
@@ -94,7 +97,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     l = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)
     _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+    l = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
+    l, e = self.evaluate((l, e))
     self.assertAllEqual(e, np.zeros((2, 3)))
+    self.assertAllEqual(l, np.zeros((3, 2, 3)))
 
   def testPopUninitializedTensorUseSpecifiedElementShape(self):
     l = list_ops.tensor_list_reserve(
@@ -954,14 +960,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     l_concat_11 = list_ops.tensor_list_concat_lists(
         l_batch_1, l_batch_1, element_dtype=dtypes.float32)
 
+    expected_0 = [[1.0, 2.0], [-1.0]]
+    expected_1 = [[-1.0], [1.0, 2.0]]
     expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
     expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
     expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
     expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]
 
     for i, (concat, expected) in enumerate(zip(
-        [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
-        [expected_00, expected_01, expected_10, expected_11])):
+        [l_batch_0, l_batch_1,
+         l_concat_00, l_concat_01, l_concat_10, l_concat_11],
+        [expected_0, expected_1,
+         expected_00, expected_01, expected_10, expected_11])):
       splitted = array_ops.unstack(concat)
       splitted_stacked_ret = self.evaluate(
           (list_ops.tensor_list_stack(splitted[0], dtypes.float32),