Define TensorList class in a separate library

Defining TensorList outside the list_kernels library allows clients to use the TensorList class without also depending on the kernels that operate on it.

PiperOrigin-RevId: 291474941
Change-Id: Iaab9d6c077b6a6c6236896c80d53ac8196472a82
Author: Smit Hinsu, 2020-01-24 17:34:08 -08:00 (committed by TensorFlower Gardener)
parent e977b1cfa6
commit 85cc56784d
5 changed files with 301 additions and 234 deletions
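
For context on what the split enables, here is a minimal sketch (illustrative only, not part of this change) of a client that needs just the TensorList container and can now depend on the new tensor_list target instead of list_kernels. The helper name and values are hypothetical.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/tensor_list.h"

namespace tensorflow {

// Hypothetical helper: builds a TensorList and stores it in a DT_VARIANT
// scalar tensor, without touching any of the list kernels.
Tensor MakeVariantWithList() {
  TensorList list;
  list.element_dtype = DT_FLOAT;
  list.element_shape = PartialTensorShape({2});

  Tensor element(DT_FLOAT, TensorShape({2}));
  element.vec<float>()(0) = 1.0f;
  element.vec<float>()(1) = 2.0f;
  list.tensors().push_back(element);

  Tensor variant_tensor(DT_VARIANT, TensorShape({}));
  variant_tensor.scalar<Variant>()() = std::move(list);
  return variant_tensor;
}

}  // namespace tensorflow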

tensorflow/core/kernels/BUILD

@@ -2818,6 +2818,19 @@ tf_kernel_library(
],
)
cc_library(
name = "tensor_list",
srcs = ["tensor_list.cc"],
hdrs = ["tensor_list.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:tensor_shape_proto_cc",
"//tensorflow/core/lib/core:refcount",
],
)
tf_kernel_library(
name = "list_kernels",
srcs = ["list_kernels.cc"],
@@ -2829,6 +2842,7 @@ tf_kernel_library(
deps = [
":concat_lib",
":fill_functor",
":tensor_list",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",

tensorflow/core/kernels/list_kernels.cc

@@ -39,107 +39,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
TensorList::~TensorList() {
if (tensors_) tensors_->Unref();
}
void TensorList::Encode(VariantTensorData* data) const {
data->set_type_name(TypeName());
std::vector<size_t> invalid_indices;
for (size_t i = 0; i < tensors().size(); i++) {
if (tensors().at(i).dtype() != DT_INVALID) {
*data->add_tensors() = tensors().at(i);
} else {
invalid_indices.push_back(i);
}
}
string metadata;
// TODO(b/118838800): Add a proto for storing the metadata.
// Metadata format:
// <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto>
core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
for (size_t i : invalid_indices) {
core::PutVarint64(&metadata, static_cast<uint64>(i));
}
core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
TensorShapeProto element_shape_proto;
element_shape.AsProto(&element_shape_proto);
element_shape_proto.AppendToString(&metadata);
data->set_metadata(metadata);
}
static Status TensorListDeviceCopy(
const TensorList& from, TensorList* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
to->element_shape = from.element_shape;
to->element_dtype = from.element_dtype;
to->max_num_elements = from.max_num_elements;
to->tensors().reserve(from.tensors().size());
for (const Tensor& t : from.tensors()) {
to->tensors().emplace_back(t.dtype());
if (t.dtype() != DT_INVALID) {
TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
}
}
return Status::OK();
}
#define REGISTER_LIST_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);
bool TensorList::Decode(const VariantTensorData& data) {
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
// that we do not have to copy each tensor individually below. This would
// require changing VariantTensorData::tensors() as well.
string metadata;
data.get_metadata(&metadata);
uint64 scratch;
StringPiece iter(metadata);
std::vector<size_t> invalid_indices;
core::GetVarint64(&iter, &scratch);
size_t num_invalid_tensors = static_cast<size_t>(scratch);
invalid_indices.resize(num_invalid_tensors);
for (size_t i = 0; i < num_invalid_tensors; i++) {
core::GetVarint64(&iter, &scratch);
invalid_indices[i] = static_cast<size_t>(scratch);
}
size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
tensors().reserve(total_num_tensors);
std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
for (size_t i = 0; i < total_num_tensors; i++) {
if (invalid_indices_it != invalid_indices.end() &&
*invalid_indices_it == i) {
tensors().emplace_back(Tensor(DT_INVALID));
invalid_indices_it++;
} else if (tensors_it != data.tensors().end()) {
tensors().emplace_back(*tensors_it);
tensors_it++;
} else {
// VariantTensorData is corrupted.
return false;
}
}
core::GetVarint64(&iter, &scratch);
element_dtype = static_cast<DataType>(scratch);
core::GetVarint64(&iter, &scratch);
max_num_elements = static_cast<int>(scratch);
TensorShapeProto element_shape_proto;
element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
element_shape = PartialTensorShape(element_shape_proto);
return true;
}
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
if (t.shape() == TensorShape({})) {
if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
@@ -257,8 +156,6 @@ class EmptyTensorList : public OpKernel {
DataType element_dtype_;
};
const char TensorList::kTypeName[] = "tensorflow::TensorList";
REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
EmptyTensorList);

tensorflow/core/kernels/list_kernels.h

@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -41,137 +42,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
// Variant compatible type for a list of tensors. This is mutable but instances
// should never be mutated after stored in a variant tensor.
//
// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
// which are accessible via TensorList::tensors(). Because it is refcounted,
// straight copies of the form:
//
// TensorList b = a;
// b.tensors().push_back(t); // WARNING: This modifies a.tensors().
//
// Do not create a true copy of the underlying container - but instead increment
// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
// TensorList should be considered similar to the tf::Tensor object.
//
// In order to get a copy of the underlying list, use the Copy method:
//
// TensorList b = a.Copy();
// b.tensors().push_back(t); // This does not modify a.tensors().
//
// Note that this is not a deep copy: the memory locations of the underlying
// tensors will still point to the same locations of the corresponding tensors
// in the original. To truly perform a deep copy, Device and Type-specific
// code needs to be applied to the underlying tensors as usual.
//
// The most important implication of RefCounted TLs is that OpKernels
// wishing to reuse TensorList inputs as outputs via context->forward_input()
// need to perform an additional check on the refcount of the TensorList,
// to ensure aliasing can be performed safely. For example:
//
// bool can_alias = false;
// auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
// auto* tl = fw->scalar<Variant>()().get<TensorList>();
// if (tl && tl->RefCountIsOne()) {
// can_alias = true;
// }
// }
//
class TensorList {
public:
TensorList() : tensors_(new Tensors) {}
~TensorList();
TensorList(const TensorList& other)
: element_shape(other.element_shape),
element_dtype(other.element_dtype),
max_num_elements(other.max_num_elements),
tensors_(other.tensors_) {
tensors_->Ref();
}
TensorList(TensorList&& rhs)
: element_shape(std::move(rhs.element_shape)),
element_dtype(rhs.element_dtype),
max_num_elements(rhs.max_num_elements),
tensors_(rhs.tensors_) {
rhs.tensors_ = nullptr;
}
TensorList& operator=(const TensorList& rhs) {
if (this == &rhs) return *this;
element_shape = rhs.element_shape;
element_dtype = rhs.element_dtype;
max_num_elements = rhs.max_num_elements;
tensors_->Unref();
tensors_ = rhs.tensors_;
tensors_->Ref();
return *this;
}
TensorList& operator=(TensorList&& rhs) {
if (this == &rhs) return *this;
element_shape = rhs.element_shape;
element_dtype = rhs.element_dtype;
max_num_elements = rhs.max_num_elements;
std::swap(tensors_, rhs.tensors_);
return *this;
}
static const char kTypeName[];
string TypeName() const { return kTypeName; }
void Encode(VariantTensorData* data) const;
bool Decode(const VariantTensorData& data);
// TODO(apassos) fill this out
string DebugString() const { return "TensorList"; }
PartialTensorShape element_shape;
DataType element_dtype;
// The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
// of `tensors` is unbounded.
int max_num_elements = -1;
// Access to the underlying tensor container.
std::vector<Tensor>& tensors() { return tensors_->values_; }
const std::vector<Tensor>& tensors() const { return tensors_->values_; }
// Get a new TensorList containing a copy of the underlying tensor container.
TensorList Copy() const {
TensorList out;
out.element_shape = element_shape;
out.element_dtype = element_dtype;
out.max_num_elements = max_num_elements;
// This performs a copy of the std::vector.
out.tensors_->values_ = tensors_->values_;
return out;
}
// Is this TensorList the only one with a reference to the underlying
// container?
bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
private:
class Tensors : public core::RefCounted {
public:
std::vector<Tensor> values_;
};
Tensors* tensors_;
};
#if defined(PLATFORM_GOOGLE)
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
static_assert(Variant::CanInlineType<TensorList>(),
"Must be able to inline TensorList into a Variant");
#endif
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
Status GetElementShapeFromInput(OpKernelContext* c,

tensorflow/core/kernels/tensor_list.cc (new file)

@@ -0,0 +1,127 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/coding.h"
namespace tensorflow {
TensorList::~TensorList() {
if (tensors_) tensors_->Unref();
}
void TensorList::Encode(VariantTensorData* data) const {
data->set_type_name(TypeName());
std::vector<size_t> invalid_indices;
for (size_t i = 0; i < tensors().size(); i++) {
if (tensors().at(i).dtype() != DT_INVALID) {
*data->add_tensors() = tensors().at(i);
} else {
invalid_indices.push_back(i);
}
}
string metadata;
// TODO(b/118838800): Add a proto for storing the metadata.
// Metadata format:
// <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto>
core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
for (size_t i : invalid_indices) {
core::PutVarint64(&metadata, static_cast<uint64>(i));
}
core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
TensorShapeProto element_shape_proto;
element_shape.AsProto(&element_shape_proto);
element_shape_proto.AppendToString(&metadata);
data->set_metadata(metadata);
}
static Status TensorListDeviceCopy(
const TensorList& from, TensorList* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
to->element_shape = from.element_shape;
to->element_dtype = from.element_dtype;
to->max_num_elements = from.max_num_elements;
to->tensors().reserve(from.tensors().size());
for (const Tensor& t : from.tensors()) {
to->tensors().emplace_back(t.dtype());
if (t.dtype() != DT_INVALID) {
TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
}
}
return Status::OK();
}
#define REGISTER_LIST_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);
bool TensorList::Decode(const VariantTensorData& data) {
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
// that we do not have to copy each tensor individually below. This would
// require changing VariantTensorData::tensors() as well.
string metadata;
data.get_metadata(&metadata);
uint64 scratch;
StringPiece iter(metadata);
std::vector<size_t> invalid_indices;
core::GetVarint64(&iter, &scratch);
size_t num_invalid_tensors = static_cast<size_t>(scratch);
invalid_indices.resize(num_invalid_tensors);
for (size_t i = 0; i < num_invalid_tensors; i++) {
core::GetVarint64(&iter, &scratch);
invalid_indices[i] = static_cast<size_t>(scratch);
}
size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
tensors().reserve(total_num_tensors);
std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
for (size_t i = 0; i < total_num_tensors; i++) {
if (invalid_indices_it != invalid_indices.end() &&
*invalid_indices_it == i) {
tensors().emplace_back(Tensor(DT_INVALID));
invalid_indices_it++;
} else if (tensors_it != data.tensors().end()) {
tensors().emplace_back(*tensors_it);
tensors_it++;
} else {
// VariantTensorData is corrupted.
return false;
}
}
core::GetVarint64(&iter, &scratch);
element_dtype = static_cast<DataType>(scratch);
core::GetVarint64(&iter, &scratch);
max_num_elements = static_cast<int>(scratch);
TensorShapeProto element_shape_proto;
element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
element_shape = PartialTensorShape(element_shape_proto);
return true;
}
const char TensorList::kTypeName[] = "tensorflow::TensorList";
} // namespace tensorflow
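
To make the metadata format described in Encode concrete, here is a small round-trip sketch (illustrative only, not part of the change) for a list containing an uninitialized slot. The function name and values are hypothetical.

#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/tensor_list.h"

namespace tensorflow {

void EncodeDecodeSketch() {
  TensorList a;
  a.element_dtype = DT_INT32;
  a.element_shape = PartialTensorShape({-1});
  a.tensors().push_back(Tensor(DT_INVALID));  // a "hole" at index 0
  Tensor t(DT_INT32, TensorShape({2}));
  t.vec<int32>()(0) = 7;
  t.vec<int32>()(1) = 8;
  a.tensors().push_back(t);

  VariantTensorData data;
  a.Encode(&data);
  // data.tensors() now holds only the one valid tensor; the metadata string
  // records <1><0> (one invalid slot, at index 0), followed by element_dtype,
  // max_num_elements, and the serialized element_shape proto.

  TensorList b;
  bool ok = b.Decode(data);
  // On success b has two slots again: index 0 is a DT_INVALID placeholder and
  // index 1 is the int32 tensor copied from data.tensors().
  (void)ok;
}

}  // namespace tensorflow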

tensorflow/core/kernels/tensor_list.h (new file)

@@ -0,0 +1,159 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
#define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
#include <utility>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
// Variant compatible type for a list of tensors. This is mutable but instances
// should never be mutated after stored in a variant tensor.
//
// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
// which are accessible via TensorList::tensors(). Because it is refcounted,
// straight copies of the form:
//
// TensorList b = a;
// b.tensors().push_back(t); // WARNING: This modifies a.tensors().
//
// Do not create a true copy of the underlying container - but instead increment
// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
// TensorList should be considered similar to the tf::Tensor object.
//
// In order to get a copy of the underlying list, use the Copy method:
//
// TensorList b = a.Copy();
// b.tensors().push_back(t); // This does not modify a.tensors().
//
// Note that this is not a deep copy: the memory locations of the underlying
// tensors will still point to the same locations of the corresponding tensors
// in the original. To truly perform a deep copy, Device and Type-specific
// code needs to be applied to the underlying tensors as usual.
//
// The most important implication of RefCounted TLs is that OpKernels
// wishing to reuse TensorList inputs as outputs via context->forward_input()
// need to perform an additional check on the refcount of the TensorList,
// to ensure aliasing can be performed safely. For example:
//
// bool can_alias = false;
// auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
// auto* tl = fw->scalar<Variant>()().get<TensorList>();
// if (tl && tl->RefCountIsOne()) {
// can_alias = true;
// }
// }
//
class TensorList {
public:
TensorList() : tensors_(new Tensors) {}
~TensorList();
TensorList(const TensorList& other)
: element_shape(other.element_shape),
element_dtype(other.element_dtype),
max_num_elements(other.max_num_elements),
tensors_(other.tensors_) {
tensors_->Ref();
}
TensorList(TensorList&& rhs)
: element_shape(std::move(rhs.element_shape)),
element_dtype(rhs.element_dtype),
max_num_elements(rhs.max_num_elements),
tensors_(rhs.tensors_) {
rhs.tensors_ = nullptr;
}
TensorList& operator=(const TensorList& rhs) {
if (this == &rhs) return *this;
element_shape = rhs.element_shape;
element_dtype = rhs.element_dtype;
max_num_elements = rhs.max_num_elements;
tensors_->Unref();
tensors_ = rhs.tensors_;
tensors_->Ref();
return *this;
}
TensorList& operator=(TensorList&& rhs) {
if (this == &rhs) return *this;
element_shape = rhs.element_shape;
element_dtype = rhs.element_dtype;
max_num_elements = rhs.max_num_elements;
std::swap(tensors_, rhs.tensors_);
return *this;
}
static const char kTypeName[];
string TypeName() const { return kTypeName; }
void Encode(VariantTensorData* data) const;
bool Decode(const VariantTensorData& data);
// TODO(apassos) fill this out
string DebugString() const { return "TensorList"; }
PartialTensorShape element_shape;
DataType element_dtype;
// The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
// of `tensors` is unbounded.
int max_num_elements = -1;
// Access to the underlying tensor container.
std::vector<Tensor>& tensors() { return tensors_->values_; }
const std::vector<Tensor>& tensors() const { return tensors_->values_; }
// Get a new TensorList containing a copy of the underlying tensor container.
TensorList Copy() const {
TensorList out;
out.element_shape = element_shape;
out.element_dtype = element_dtype;
out.max_num_elements = max_num_elements;
// This performs a copy of the std::vector.
out.tensors_->values_ = tensors_->values_;
return out;
}
// Is this TensorList the only one with a reference to the underlying
// container?
bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
private:
class Tensors : public core::RefCounted {
public:
std::vector<Tensor> values_;
};
Tensors* tensors_;
};
#if defined(PLATFORM_GOOGLE)
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
static_assert(Variant::CanInlineType<TensorList>(),
"Must be able to inline TensorList into a Variant");
#endif
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
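
A short sketch (illustrative, not part of the change) of the aliasing behavior the class comment warns about: plain copies share the refcounted container, Copy() produces an independent container, and RefCountIsOne() is the check kernels use before mutating a list in place.

#include "tensorflow/core/kernels/tensor_list.h"

namespace tensorflow {

void CopySemanticsSketch() {
  TensorList a;
  a.tensors().push_back(Tensor(DT_FLOAT, TensorShape({})));

  TensorList b = a;  // shares a's refcounted container
  b.tensors().push_back(Tensor(DT_FLOAT, TensorShape({})));
  // a.tensors().size() == 2: a and b alias the same std::vector<Tensor>.
  // a.RefCountIsOne() is now false, so a kernel must not mutate a in place.

  TensorList c = a.Copy();  // new container; the tensors still share buffers
  c.tensors().clear();
  // a.tensors().size() is still 2; only c's container changed.
}

}  // namespace tensorflow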