Define TensorList class in a separate library
Defining TensorList outside list_kernels library will allow clients to use TensorList class without having to also include the kernels operating on it. PiperOrigin-RevId: 291474941 Change-Id: Iaab9d6c077b6a6c6236896c80d53ac8196472a82
This commit is contained in:
parent
e977b1cfa6
commit
85cc56784d
@ -2818,6 +2818,19 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_list",
|
||||
srcs = ["tensor_list.cc"],
|
||||
hdrs = ["tensor_list.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/framework:tensor_shape_proto_cc",
|
||||
"//tensorflow/core/lib/core:refcount",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "list_kernels",
|
||||
srcs = ["list_kernels.cc"],
|
||||
@ -2829,6 +2842,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":concat_lib",
|
||||
":fill_functor",
|
||||
":tensor_list",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/eigen3",
|
||||
|
@ -39,107 +39,6 @@ namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
TensorList::~TensorList() {
|
||||
if (tensors_) tensors_->Unref();
|
||||
}
|
||||
|
||||
void TensorList::Encode(VariantTensorData* data) const {
|
||||
data->set_type_name(TypeName());
|
||||
std::vector<size_t> invalid_indices;
|
||||
for (size_t i = 0; i < tensors().size(); i++) {
|
||||
if (tensors().at(i).dtype() != DT_INVALID) {
|
||||
*data->add_tensors() = tensors().at(i);
|
||||
} else {
|
||||
invalid_indices.push_back(i);
|
||||
}
|
||||
}
|
||||
string metadata;
|
||||
// TODO(b/118838800): Add a proto for storing the metadata.
|
||||
// Metadata format:
|
||||
// <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto>
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
|
||||
for (size_t i : invalid_indices) {
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(i));
|
||||
}
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
|
||||
TensorShapeProto element_shape_proto;
|
||||
element_shape.AsProto(&element_shape_proto);
|
||||
element_shape_proto.AppendToString(&metadata);
|
||||
data->set_metadata(metadata);
|
||||
}
|
||||
|
||||
static Status TensorListDeviceCopy(
|
||||
const TensorList& from, TensorList* to,
|
||||
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
|
||||
to->element_shape = from.element_shape;
|
||||
to->element_dtype = from.element_dtype;
|
||||
to->max_num_elements = from.max_num_elements;
|
||||
to->tensors().reserve(from.tensors().size());
|
||||
for (const Tensor& t : from.tensors()) {
|
||||
to->tensors().emplace_back(t.dtype());
|
||||
if (t.dtype() != DT_INVALID) {
|
||||
TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_LIST_COPY(DIRECTION) \
|
||||
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
|
||||
TensorListDeviceCopy)
|
||||
|
||||
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
|
||||
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
|
||||
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
|
||||
|
||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);
|
||||
|
||||
bool TensorList::Decode(const VariantTensorData& data) {
|
||||
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
|
||||
// that we do not have to copy each tensor individually below. This would
|
||||
// require changing VariantTensorData::tensors() as well.
|
||||
string metadata;
|
||||
data.get_metadata(&metadata);
|
||||
uint64 scratch;
|
||||
StringPiece iter(metadata);
|
||||
std::vector<size_t> invalid_indices;
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
size_t num_invalid_tensors = static_cast<size_t>(scratch);
|
||||
invalid_indices.resize(num_invalid_tensors);
|
||||
for (size_t i = 0; i < num_invalid_tensors; i++) {
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
invalid_indices[i] = static_cast<size_t>(scratch);
|
||||
}
|
||||
|
||||
size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
|
||||
tensors().reserve(total_num_tensors);
|
||||
std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
|
||||
std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
|
||||
for (size_t i = 0; i < total_num_tensors; i++) {
|
||||
if (invalid_indices_it != invalid_indices.end() &&
|
||||
*invalid_indices_it == i) {
|
||||
tensors().emplace_back(Tensor(DT_INVALID));
|
||||
invalid_indices_it++;
|
||||
} else if (tensors_it != data.tensors().end()) {
|
||||
tensors().emplace_back(*tensors_it);
|
||||
tensors_it++;
|
||||
} else {
|
||||
// VariantTensorData is corrupted.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
element_dtype = static_cast<DataType>(scratch);
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
max_num_elements = static_cast<int>(scratch);
|
||||
TensorShapeProto element_shape_proto;
|
||||
element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
|
||||
element_shape = PartialTensorShape(element_shape_proto);
|
||||
return true;
|
||||
}
|
||||
|
||||
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
|
||||
if (t.shape() == TensorShape({})) {
|
||||
if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
|
||||
@ -257,8 +156,6 @@ class EmptyTensorList : public OpKernel {
|
||||
DataType element_dtype_;
|
||||
};
|
||||
|
||||
const char TensorList::kTypeName[] = "tensorflow::TensorList";
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
|
||||
EmptyTensorList);
|
||||
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/kernels/concat_lib.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/tensor_list.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
@ -41,137 +42,6 @@ namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
// Variant compatible type for a list of tensors. This is mutable but instances
|
||||
// should never be mutated after stored in a variant tensor.
|
||||
//
|
||||
// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
|
||||
// which are accessible via TensorList::tensors(). Because it is refcounted,
|
||||
// straight copies of the form:
|
||||
//
|
||||
// TensorList b = a;
|
||||
// b.tensors().push_back(t); // WARNING: This modifies a.tensors().
|
||||
//
|
||||
// Do not create a true copy of the underlying container - but instead increment
|
||||
// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
|
||||
// TensorList should be considered similar to the tf::Tensor object.
|
||||
//
|
||||
// In order to get a copy of the underlying list, use the Copy method:
|
||||
//
|
||||
// TensorList b = a.Copy();
|
||||
// b.tensors().push_back(t); // This does not modify a.tensors().
|
||||
//
|
||||
// Note that this is not a deep copy: the memory locations of the underlying
|
||||
// tensors will still point to the same locations of the corresponding tensors
|
||||
// in the original. To truly perform a deep copy, Device and Type-specific
|
||||
// code needs to be applied to the underlying tensors as usual.
|
||||
//
|
||||
// The most important implication of RefCounted TLs is that OpKernels
|
||||
// wishing to reuse TensorList inputs as outputs via context->forward_input()
|
||||
// need to perform an additional check on the refcount of the TensorList,
|
||||
// to ensure aliasing can be performed safely. For example:
|
||||
//
|
||||
// bool can_alias = false;
|
||||
// auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
|
||||
// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
|
||||
// auto* tl = fw->scalar<Variant>()().get<TensorList>();
|
||||
// if (tl && tl->RefCountIsOne()) {
|
||||
// can_alias = true;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
class TensorList {
|
||||
public:
|
||||
TensorList() : tensors_(new Tensors) {}
|
||||
~TensorList();
|
||||
|
||||
TensorList(const TensorList& other)
|
||||
: element_shape(other.element_shape),
|
||||
element_dtype(other.element_dtype),
|
||||
max_num_elements(other.max_num_elements),
|
||||
tensors_(other.tensors_) {
|
||||
tensors_->Ref();
|
||||
}
|
||||
|
||||
TensorList(TensorList&& rhs)
|
||||
: element_shape(std::move(rhs.element_shape)),
|
||||
element_dtype(rhs.element_dtype),
|
||||
max_num_elements(rhs.max_num_elements),
|
||||
tensors_(rhs.tensors_) {
|
||||
rhs.tensors_ = nullptr;
|
||||
}
|
||||
|
||||
TensorList& operator=(const TensorList& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
element_shape = rhs.element_shape;
|
||||
element_dtype = rhs.element_dtype;
|
||||
max_num_elements = rhs.max_num_elements;
|
||||
tensors_->Unref();
|
||||
tensors_ = rhs.tensors_;
|
||||
tensors_->Ref();
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorList& operator=(TensorList&& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
element_shape = rhs.element_shape;
|
||||
element_dtype = rhs.element_dtype;
|
||||
max_num_elements = rhs.max_num_elements;
|
||||
std::swap(tensors_, rhs.tensors_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
static const char kTypeName[];
|
||||
|
||||
string TypeName() const { return kTypeName; }
|
||||
|
||||
void Encode(VariantTensorData* data) const;
|
||||
|
||||
bool Decode(const VariantTensorData& data);
|
||||
|
||||
// TODO(apassos) fill this out
|
||||
string DebugString() const { return "TensorList"; }
|
||||
|
||||
PartialTensorShape element_shape;
|
||||
|
||||
DataType element_dtype;
|
||||
|
||||
// The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
|
||||
// of `tensors` is unbounded.
|
||||
int max_num_elements = -1;
|
||||
|
||||
// Access to the underlying tensor container.
|
||||
std::vector<Tensor>& tensors() { return tensors_->values_; }
|
||||
const std::vector<Tensor>& tensors() const { return tensors_->values_; }
|
||||
|
||||
// Get a new TensorList containing a copy of the underlying tensor container.
|
||||
TensorList Copy() const {
|
||||
TensorList out;
|
||||
out.element_shape = element_shape;
|
||||
out.element_dtype = element_dtype;
|
||||
out.max_num_elements = max_num_elements;
|
||||
// This performs a copy of the std::vector.
|
||||
out.tensors_->values_ = tensors_->values_;
|
||||
return out;
|
||||
}
|
||||
|
||||
// Is this TensorList the only one with a reference to the underlying
|
||||
// container?
|
||||
bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
|
||||
|
||||
private:
|
||||
class Tensors : public core::RefCounted {
|
||||
public:
|
||||
std::vector<Tensor> values_;
|
||||
};
|
||||
Tensors* tensors_;
|
||||
};
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
|
||||
static_assert(Variant::CanInlineType<TensorList>(),
|
||||
"Must be able to inline TensorList into a Variant");
|
||||
#endif
|
||||
|
||||
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
|
||||
|
||||
Status GetElementShapeFromInput(OpKernelContext* c,
|
||||
|
127
tensorflow/core/kernels/tensor_list.cc
Normal file
127
tensorflow/core/kernels/tensor_list.cc
Normal file
@ -0,0 +1,127 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/tensor_list.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TensorList::~TensorList() {
|
||||
if (tensors_) tensors_->Unref();
|
||||
}
|
||||
|
||||
void TensorList::Encode(VariantTensorData* data) const {
|
||||
data->set_type_name(TypeName());
|
||||
std::vector<size_t> invalid_indices;
|
||||
for (size_t i = 0; i < tensors().size(); i++) {
|
||||
if (tensors().at(i).dtype() != DT_INVALID) {
|
||||
*data->add_tensors() = tensors().at(i);
|
||||
} else {
|
||||
invalid_indices.push_back(i);
|
||||
}
|
||||
}
|
||||
string metadata;
|
||||
// TODO(b/118838800): Add a proto for storing the metadata.
|
||||
// Metadata format:
|
||||
// <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto>
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
|
||||
for (size_t i : invalid_indices) {
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(i));
|
||||
}
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
|
||||
TensorShapeProto element_shape_proto;
|
||||
element_shape.AsProto(&element_shape_proto);
|
||||
element_shape_proto.AppendToString(&metadata);
|
||||
data->set_metadata(metadata);
|
||||
}
|
||||
|
||||
static Status TensorListDeviceCopy(
|
||||
const TensorList& from, TensorList* to,
|
||||
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
|
||||
to->element_shape = from.element_shape;
|
||||
to->element_dtype = from.element_dtype;
|
||||
to->max_num_elements = from.max_num_elements;
|
||||
to->tensors().reserve(from.tensors().size());
|
||||
for (const Tensor& t : from.tensors()) {
|
||||
to->tensors().emplace_back(t.dtype());
|
||||
if (t.dtype() != DT_INVALID) {
|
||||
TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_LIST_COPY(DIRECTION) \
|
||||
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
|
||||
TensorListDeviceCopy)
|
||||
|
||||
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
|
||||
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
|
||||
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
|
||||
|
||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);
|
||||
|
||||
bool TensorList::Decode(const VariantTensorData& data) {
|
||||
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
|
||||
// that we do not have to copy each tensor individually below. This would
|
||||
// require changing VariantTensorData::tensors() as well.
|
||||
string metadata;
|
||||
data.get_metadata(&metadata);
|
||||
uint64 scratch;
|
||||
StringPiece iter(metadata);
|
||||
std::vector<size_t> invalid_indices;
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
size_t num_invalid_tensors = static_cast<size_t>(scratch);
|
||||
invalid_indices.resize(num_invalid_tensors);
|
||||
for (size_t i = 0; i < num_invalid_tensors; i++) {
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
invalid_indices[i] = static_cast<size_t>(scratch);
|
||||
}
|
||||
|
||||
size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
|
||||
tensors().reserve(total_num_tensors);
|
||||
std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
|
||||
std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
|
||||
for (size_t i = 0; i < total_num_tensors; i++) {
|
||||
if (invalid_indices_it != invalid_indices.end() &&
|
||||
*invalid_indices_it == i) {
|
||||
tensors().emplace_back(Tensor(DT_INVALID));
|
||||
invalid_indices_it++;
|
||||
} else if (tensors_it != data.tensors().end()) {
|
||||
tensors().emplace_back(*tensors_it);
|
||||
tensors_it++;
|
||||
} else {
|
||||
// VariantTensorData is corrupted.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
element_dtype = static_cast<DataType>(scratch);
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
max_num_elements = static_cast<int>(scratch);
|
||||
TensorShapeProto element_shape_proto;
|
||||
element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
|
||||
element_shape = PartialTensorShape(element_shape_proto);
|
||||
return true;
|
||||
}
|
||||
|
||||
const char TensorList::kTypeName[] = "tensorflow::TensorList";
|
||||
|
||||
} // namespace tensorflow
|
159
tensorflow/core/kernels/tensor_list.h
Normal file
159
tensorflow/core/kernels/tensor_list.h
Normal file
@ -0,0 +1,159 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Variant compatible type for a list of tensors. This is mutable but instances
|
||||
// should never be mutated after stored in a variant tensor.
|
||||
//
|
||||
// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
|
||||
// which are accessible via TensorList::tensors(). Because it is refcounted,
|
||||
// straight copies of the form:
|
||||
//
|
||||
// TensorList b = a;
|
||||
// b.tensors().push_back(t); // WARNING: This modifies a.tensors().
|
||||
//
|
||||
// Do not create a true copy of the underlying container - but instead increment
|
||||
// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
|
||||
// TensorList should be considered similar to the tf::Tensor object.
|
||||
//
|
||||
// In order to get a copy of the underlying list, use the Copy method:
|
||||
//
|
||||
// TensorList b = a.Copy();
|
||||
// b.tensors().push_back(t); // This does not modify a.tensors().
|
||||
//
|
||||
// Note that this is not a deep copy: the memory locations of the underlying
|
||||
// tensors will still point to the same locations of the corresponding tensors
|
||||
// in the original. To truly perform a deep copy, Device and Type-specific
|
||||
// code needs to be applied to the underlying tensors as usual.
|
||||
//
|
||||
// The most important implication of RefCounted TLs is that OpKernels
|
||||
// wishing to reuse TensorList inputs as outputs via context->forward_input()
|
||||
// need to perform an additional check on the refcount of the TensorList,
|
||||
// to ensure aliasing can be performed safely. For example:
|
||||
//
|
||||
// bool can_alias = false;
|
||||
// auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
|
||||
// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
|
||||
// auto* tl = fw->scalar<Variant>()().get<TensorList>();
|
||||
// if (tl && tl->RefCountIsOne()) {
|
||||
// can_alias = true;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
class TensorList {
|
||||
public:
|
||||
TensorList() : tensors_(new Tensors) {}
|
||||
~TensorList();
|
||||
|
||||
TensorList(const TensorList& other)
|
||||
: element_shape(other.element_shape),
|
||||
element_dtype(other.element_dtype),
|
||||
max_num_elements(other.max_num_elements),
|
||||
tensors_(other.tensors_) {
|
||||
tensors_->Ref();
|
||||
}
|
||||
|
||||
TensorList(TensorList&& rhs)
|
||||
: element_shape(std::move(rhs.element_shape)),
|
||||
element_dtype(rhs.element_dtype),
|
||||
max_num_elements(rhs.max_num_elements),
|
||||
tensors_(rhs.tensors_) {
|
||||
rhs.tensors_ = nullptr;
|
||||
}
|
||||
|
||||
TensorList& operator=(const TensorList& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
element_shape = rhs.element_shape;
|
||||
element_dtype = rhs.element_dtype;
|
||||
max_num_elements = rhs.max_num_elements;
|
||||
tensors_->Unref();
|
||||
tensors_ = rhs.tensors_;
|
||||
tensors_->Ref();
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorList& operator=(TensorList&& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
element_shape = rhs.element_shape;
|
||||
element_dtype = rhs.element_dtype;
|
||||
max_num_elements = rhs.max_num_elements;
|
||||
std::swap(tensors_, rhs.tensors_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
static const char kTypeName[];
|
||||
|
||||
string TypeName() const { return kTypeName; }
|
||||
|
||||
void Encode(VariantTensorData* data) const;
|
||||
|
||||
bool Decode(const VariantTensorData& data);
|
||||
|
||||
// TODO(apassos) fill this out
|
||||
string DebugString() const { return "TensorList"; }
|
||||
|
||||
PartialTensorShape element_shape;
|
||||
|
||||
DataType element_dtype;
|
||||
|
||||
// The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
|
||||
// of `tensors` is unbounded.
|
||||
int max_num_elements = -1;
|
||||
|
||||
// Access to the underlying tensor container.
|
||||
std::vector<Tensor>& tensors() { return tensors_->values_; }
|
||||
const std::vector<Tensor>& tensors() const { return tensors_->values_; }
|
||||
|
||||
// Get a new TensorList containing a copy of the underlying tensor container.
|
||||
TensorList Copy() const {
|
||||
TensorList out;
|
||||
out.element_shape = element_shape;
|
||||
out.element_dtype = element_dtype;
|
||||
out.max_num_elements = max_num_elements;
|
||||
// This performs a copy of the std::vector.
|
||||
out.tensors_->values_ = tensors_->values_;
|
||||
return out;
|
||||
}
|
||||
|
||||
// Is this TensorList the only one with a reference to the underlying
|
||||
// container?
|
||||
bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
|
||||
|
||||
private:
|
||||
class Tensors : public core::RefCounted {
|
||||
public:
|
||||
std::vector<Tensor> values_;
|
||||
};
|
||||
Tensors* tensors_;
|
||||
};
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
|
||||
static_assert(Variant::CanInlineType<TensorList>(),
|
||||
"Must be able to inline TensorList into a Variant");
|
||||
#endif
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
|
Loading…
Reference in New Issue
Block a user