Make TensorList objects Refcounted.

This drastically reduces the amount of refcounting of individual tensors inside a TensorList when a TensorList variant is copied to a Variable or MutableDenseHashTable (and back), and likewise for operations such as tf.stack that operate on Variant tensors and perform Variant copies implicitly.

While this change adds a level of indirection to the TensorList object, by introducing a heap-allocated RefCounted object that holds the tensor vector, it also shrinks TensorList below the tf::Variant inlining threshold. That in turn removes a level of heap indirection and should cancel out any performance regressions for existing TensorList operations and small lists.

PiperOrigin-RevId: 259464769
parent 4f910ac64b
commit 91425cf597
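The heart of the change is easiest to see in isolation. Below is a self-contained miniature of the pattern the commit adopts — the names (MiniList, a toy RefCounted) are invented for illustration and this is not TensorFlow code; the real class appears in the list_kernels.h hunks further down. It shows why plain copies become cheap (they only bump a reference count), why mutation through the accessor aliases all copies, and why RefCountIsOne() is the gate for safely reusing a forwarded input:

#include <cassert>
#include <string>
#include <vector>

// Minimal intrusive refcount, standing in for tensorflow::core::RefCounted.
// Not thread-safe; a sketch only.
class RefCounted {
 public:
  void Ref() const { ++count_; }
  void Unref() const { if (--count_ == 0) delete this; }
  bool RefCountIsOne() const { return count_ == 1; }
 protected:
  virtual ~RefCounted() = default;
 private:
  mutable int count_ = 1;
};

// MiniList mirrors the commit's TensorList: one pointer to a shared,
// refcounted vector, so copying the list never copies the elements.
class MiniList {
 public:
  MiniList() : items_(new Items) {}
  ~MiniList() { if (items_) items_->Unref(); }
  MiniList(const MiniList& other) : items_(other.items_) { items_->Ref(); }
  MiniList(MiniList&& other) : items_(other.items_) { other.items_ = nullptr; }
  MiniList& operator=(const MiniList&) = delete;  // kept small for the sketch

  std::vector<std::string>& items() { return items_->values; }
  bool RefCountIsOne() const { return items_->RefCountIsOne(); }

  MiniList Copy() const {  // shallow copy: duplicates the vector, not the data
    MiniList out;
    out.items_->values = items_->values;
    return out;
  }

 private:
  struct Items : public RefCounted { std::vector<std::string> values; };
  Items* items_;
};

int main() {
  MiniList a;
  a.items().push_back("t0");

  MiniList b = a;             // refcount bump, no element copy
  b.items().push_back("t1");  // visible through `a` as well!
  assert(a.items().size() == 2);
  assert(!a.RefCountIsOne());  // so in-place reuse would be unsafe here

  MiniList c = a.Copy();  // detached container
  c.items().push_back("t2");
  assert(a.items().size() == 2 && c.items().size() == 3);
  return 0;
}

In exchange for the cheap copies, every kernel that used to mutate the tensor vector in place must first prove exclusivity — which is exactly what the RefCountIsOne() checks threaded through the kernels in the diff below do.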
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <limits>
+
 #include "tensorflow/core/framework/allocator.h"
 
 #define EIGEN_USE_THREADS
@@ -21,8 +22,6 @@ limitations under the License.
 #define EIGEN_USE_GPU
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-#include "tensorflow/core/kernels/list_kernels.h"
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -30,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/list_kernels.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/util.h"
@@ -38,20 +38,16 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-// Variant compatible type for a list of tensors. This is mutable but instances
-// should never be mutated after stored in a variant tensor.
-TensorList::TensorList(const TensorList& other)
-    : tensors(other.tensors),
-      element_shape(other.element_shape),
-      element_dtype(other.element_dtype),
-      max_num_elements(other.max_num_elements) {}
+TensorList::~TensorList() {
+  if (tensors_) tensors_->Unref();
+}
 
 void TensorList::Encode(VariantTensorData* data) const {
   data->set_type_name(TypeName());
   std::vector<size_t> invalid_indices;
-  for (size_t i = 0; i < tensors.size(); i++) {
-    if (tensors.at(i).dtype() != DT_INVALID) {
-      *data->add_tensors() = tensors.at(i);
+  for (size_t i = 0; i < tensors().size(); i++) {
+    if (tensors().at(i).dtype() != DT_INVALID) {
+      *data->add_tensors() = tensors().at(i);
     } else {
       invalid_indices.push_back(i);
     }
@@ -78,11 +74,11 @@ static Status TensorListDeviceCopy(
   to->element_shape = from.element_shape;
   to->element_dtype = from.element_dtype;
   to->max_num_elements = from.max_num_elements;
-  to->tensors.reserve(from.tensors.size());
-  for (const Tensor& t : from.tensors) {
-    to->tensors.emplace_back(t.dtype());
+  to->tensors().reserve(from.tensors().size());
+  for (const Tensor& t : from.tensors()) {
+    to->tensors().emplace_back(t.dtype());
     if (t.dtype() != DT_INVALID) {
-      TF_RETURN_IF_ERROR(copy(t, &to->tensors.back()));
+      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
     }
   }
   return Status::OK();
@@ -116,16 +112,16 @@ bool TensorList::Decode(const VariantTensorData& data) {
   }
 
   size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
-  tensors.reserve(total_num_tensors);
+  tensors().reserve(total_num_tensors);
   std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
   std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
   for (size_t i = 0; i < total_num_tensors; i++) {
     if (invalid_indices_it != invalid_indices.end() &&
         *invalid_indices_it == i) {
-      tensors.emplace_back(Tensor(DT_INVALID));
+      tensors().emplace_back(Tensor(DT_INVALID));
       invalid_indices_it++;
     } else if (tensors_it != data.tensors().end()) {
-      tensors.emplace_back(*tensors_it);
+      tensors().emplace_back(*tensors_it);
       tensors_it++;
     } else {
       // VariantTensorData is corrupted.
@@ -201,19 +197,31 @@ Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
       input_index, output_index, DT_VARIANT, TensorShape{},
       c->input_memory_type(input_index), AllocatorAttributes());
   Tensor* output_tensor;
-  if (maybe_output != nullptr) {
-    // Woohoo, forwarding succeeded!
-    output_tensor = maybe_output.get();
-    c->set_output(output_index, *output_tensor);
-  } else {
-    // If forwarding is not possible allocate a new output tensor and copy
-    // the `input_list` to it.
-    AllocatorAttributes attr;
-    attr.set_on_host(true);
-    TF_RETURN_IF_ERROR(
-        c->allocate_output(output_index, {}, &output_tensor, attr));
-    output_tensor->scalar<Variant>()() = input_list;
+  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
+      maybe_output->NumElements() == 1) {
+    output_tensor = maybe_output.get();
+    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
+    if (tmp_out == nullptr) {
+      return errors::InvalidArgument(
+          "Expected input ", input_index, " to be a TensorList but saw ",
+          output_tensor->scalar<Variant>()().TypeName());
+    }
+    if (tmp_out->RefCountIsOne()) {
+      // Woohoo, forwarding succeeded!
+      c->set_output(output_index, *output_tensor);
+      *output_list = tmp_out;
+      return Status::OK();
+    }
   }
 
+  // If forwarding is not possible allocate a new output tensor and copy
+  // the `input_list` to it.
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
+  TF_RETURN_IF_ERROR(
+      c->allocate_output(output_index, {}, &output_tensor, attr));
+  output_tensor->scalar<Variant>()() = input_list.Copy();
+
   *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
   return Status::OK();
 }
@@ -295,15 +303,15 @@ class TensorListPushBack : public OpKernel {
 
     if (l->max_num_elements != -1) {
       OP_REQUIRES(
-          c, l->tensors.size() < l->max_num_elements,
+          c, l->tensors().size() < l->max_num_elements,
           errors::InvalidArgument("Tried to push item into a full list",
-                                  " list size: ", l->tensors.size(),
+                                  " list size: ", l->tensors().size(),
                                   " max_num_elements: ", l->max_num_elements));
     }
 
     TensorList* output_list = nullptr;
     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
-    output_list->tensors.push_back(input);
+    output_list->tensors().push_back(input);
   }
 
  private:
@@ -330,7 +338,7 @@ class TensorListLength : public OpKernel {
     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
     Tensor* result;
     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<int32>()() = l->tensors.size();
+    result->scalar<int32>()() = l->tensors().size();
   }
 };
 
@@ -399,7 +407,7 @@ class TensorListReserve : public OpKernel {
     TensorList output;
     output.element_shape = element_shape;
     output.element_dtype = element_dtype_;
-    output.tensors.resize(num_elements, Tensor(DT_INVALID));
+    output.tensors().resize(num_elements, Tensor(DT_INVALID));
     Tensor* result;
     AllocatorAttributes attr;
     attr.set_on_host(true);
@@ -440,32 +448,37 @@ class TensorListResize : public OpKernel {
         c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                          c->input_memory_type(0), AllocatorAttributes());
     if (maybe_result != nullptr) {
-      maybe_result->scalar<Variant>()().get<TensorList>()->tensors.resize(
-          size, Tensor(DT_INVALID));
-      c->set_output(0, *maybe_result);
-    } else {
-      Tensor* result;
-      AllocatorAttributes attr;
-      attr.set_on_host(true);
-      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
-      TensorList output_list;
-      output_list.element_shape = input_list->element_shape;
-      output_list.element_dtype = input_list->element_dtype;
-      output_list.max_num_elements = input_list->max_num_elements;
-      if (size > input_list->tensors.size()) {
-        output_list.tensors.insert(output_list.tensors.begin(),
-                                   input_list->tensors.begin(),
-                                   input_list->tensors.end());
-        // Add DT_INVALID tensors to the end of the list if the requested size
-        // is larger than the list length.
-        output_list.tensors.resize(size, Tensor(DT_INVALID));
-      } else {
-        output_list.tensors.insert(output_list.tensors.begin(),
-                                   input_list->tensors.begin(),
-                                   input_list->tensors.begin() + size);
+      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
+      if (out->RefCountIsOne()) {
+        // We are able to forward the input.
+        out->tensors().resize(size, Tensor(DT_INVALID));
+        c->set_output(0, *maybe_result);
+        return;
       }
-      result->scalar<Variant>()() = std::move(output_list);
     }
+
+    // We were not able to forward the input. Will have to resize from scratch.
+    Tensor* result;
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+    TensorList output_list;
+    output_list.element_shape = input_list->element_shape;
+    output_list.element_dtype = input_list->element_dtype;
+    output_list.max_num_elements = input_list->max_num_elements;
+    if (size > input_list->tensors().size()) {
+      output_list.tensors().insert(output_list.tensors().begin(),
+                                   input_list->tensors().begin(),
+                                   input_list->tensors().end());
+      // Add DT_INVALID tensors to the end of the list if the requested size
+      // is larger than the list length.
+      output_list.tensors().resize(size, Tensor(DT_INVALID));
+    } else {
+      output_list.tensors().insert(output_list.tensors().begin(),
+                                   input_list->tensors().begin(),
+                                   input_list->tensors().begin() + size);
+    }
+    result->scalar<Variant>()() = std::move(output_list);
   }
 };
 
@@ -495,9 +508,9 @@ class TensorListSetItem : public OpKernel {
                     " but list elements ",
                     DataTypeString(l->element_dtype)));
     int32 index = c->input(1).scalar<int32>()();
-    OP_REQUIRES(c, index < l->tensors.size(),
+    OP_REQUIRES(c, index < l->tensors().size(),
                 errors::InvalidArgument("Trying to modify element ", index,
-                                        " in a list with ", l->tensors.size(),
+                                        " in a list with ", l->tensors().size(),
                                         " elements."));
     const Tensor& value = c->input(2);
     OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
@@ -508,7 +521,7 @@ class TensorListSetItem : public OpKernel {
                     " list shape: ", l->element_shape.DebugString()));
     TensorList* output_list = nullptr;
     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
-    output_list->tensors[index] = value;
+    output_list->tensors()[index] = value;
   }
 
  private:
@@ -560,11 +573,26 @@ class TensorListConcatLists : public OpKernel {
     const Tensor& tl_a = c->input(0);
     const Tensor& tl_b = c->input(1);
 
-    Tensor* output;
-    if (tl_alias) {
-      c->set_output(0, *tl_alias);
-      output = tl_alias.get();
-    } else {
+    Tensor* output = nullptr;
+    bool ok_to_alias = tl_alias != nullptr;
+    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
+        tl_alias->NumElements() > 0) {
+      auto tl_a_t = tl_alias->flat<Variant>();
+      for (int64 i = 0; i < tl_alias->NumElements(); ++i) {
+        TensorList* aliased = tl_a_t(i).get<TensorList>();
+        if (aliased == nullptr || !aliased->RefCountIsOne()) {
+          ok_to_alias = false;
+          break;
+        }
+      }
+      if (ok_to_alias) {
+        c->set_output(0, *tl_alias);
+        output = tl_alias.get();
+      }
+    }
+    if (!ok_to_alias) {
+      // Couldn't alias the entire Tensor. We'll be conservative and not try
+      // to alias individual batch entries.
       attr.set_on_host(true);
       OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
     }
@@ -573,45 +601,42 @@ class TensorListConcatLists : public OpKernel {
     auto tl_a_t = tl_a.flat<Variant>();
     auto tl_b_t = tl_b.flat<Variant>();
 
-    for (int64 b = 0; b < tl_a.NumElements(); ++b) {
-      const TensorList* l_a = tl_a_t(b).get<TensorList>();
-      const TensorList* l_b = tl_b_t(b).get<TensorList>();
+    for (int64 i = 0; i < tl_a.NumElements(); ++i) {
+      const TensorList* l_a = tl_a_t(i).get<TensorList>();
+      const TensorList* l_b = tl_b_t(i).get<TensorList>();
       OP_REQUIRES(
           c, l_a != nullptr,
-          errors::InvalidArgument("input_a is not a TensorList at index ", b,
-                                  ". Saw: '", tl_a_t(b).DebugString(), "'"));
+          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
       OP_REQUIRES(
           c, l_b != nullptr,
-          errors::InvalidArgument("input_b is not a TensorList at index ", b,
-                                  ". Saw: '", tl_b_t(b).DebugString(), "'"));
+          errors::InvalidArgument("input_b is not a TensorList at index ", i,
+                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
       OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                   errors::InvalidArgument(
-                      "input_a[", b, "].dtype != element_dtype. Saw: ",
+                      "input_a[", i, "].dtype != element_dtype. Saw: ",
                       DataTypeString(l_a->element_dtype), " vs. ",
                       DataTypeString(element_dtype_)));
       OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                   errors::InvalidArgument(
-                      "input_b[", b, "].dtype != element_dtype. Saw: ",
+                      "input_b[", i, "].dtype != element_dtype. Saw: ",
                       DataTypeString(l_b->element_dtype), " vs. ",
                       DataTypeString(element_dtype_)));
       OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                   errors::InvalidArgument(
                       "input_a and input_b TensorList element shapes are not "
                       "identical at index ",
-                      b, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
+                      i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                       l_b->element_shape.DebugString()));
-      if (tl_alias) {
-        TensorList* out = output_t(b).get<TensorList>();
-        DCHECK(out != nullptr) << "Expected output to alias input_a, but it "
-                                  "doesn't contain a TensorList at index "
-                               << b;
-        std::copy(l_b->tensors.begin(), l_b->tensors.end(),
-                  std::back_inserter(out->tensors));
+      if (ok_to_alias) {
+        TensorList* out = output_t(i).get<TensorList>();
+        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
+                  std::back_inserter(out->tensors()));
       } else {
-        TensorList out = *l_a;
-        std::copy(l_b->tensors.begin(), l_b->tensors.end(),
-                  std::back_inserter(out.tensors));
-        output_t(b) = std::move(out);
+        TensorList out = l_a->Copy();
+        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
+                  std::back_inserter(out.tensors()));
+        output_t(i) = std::move(out);
       }
     }
   }
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -31,7 +31,9 @@ limitations under the License.
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/util/tensor_ops_util.h"
 #include "tensorflow/core/util/util.h"
 
@@ -41,12 +43,85 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 
 // Variant compatible type for a list of tensors. This is mutable but instances
 // should never be mutated after stored in a variant tensor.
-struct TensorList {
+//
+// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
+// which are accessible via TensorList::tensors(). Because it is refcounted,
+// straight copies of the form:
+//
+//    TensorList b = a;
+//    b.tensors().push_back(t);  // WARNING: This modifies a.tensors().
+//
+// Do not create a true copy of the underlying container - but instead increment
+// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
+// TensorList should be considered similar to the tf::Tensor object.
+//
+// In order to get a copy of the underlying list, use the Copy method:
+//
+//    TensorList b = a.Copy();
+//    b.tensors().push_back(t);  // This does not modify a.tensors().
+//
+// Note that this is not a deep copy: the memory locations of the underlying
+// tensors will still point to the same locations of the corresponding tensors
+// in the original. To truly perform a deep copy, Device and Type-specific
+// code needs to be applied to the underlying tensors as usual.
+//
+// The most important implication of RefCounted TLs is that OpKernels
+// wishing to reuse TensorList inputs as outputs via context->forward_input()
+// need to perform an additional check on the refcount of the TensorList,
+// to ensure aliasing can be performed safely. For example:
+//
+//     bool can_alias = false;
+//     auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
+//     if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
+//       auto* tl = fw->scalar<Variant>()().get<TensorList>();
+//       if (tl && tl->RefCountIsOne()) {
+//         can_alias = true;
+//       }
+//     }
+//
+class TensorList {
  public:
-  TensorList() {}
-  TensorList(const TensorList& other);
+  TensorList() : tensors_(new Tensors) {}
+  ~TensorList();
+
+  TensorList(const TensorList& other)
+      : element_shape(other.element_shape),
+        element_dtype(other.element_dtype),
+        max_num_elements(other.max_num_elements),
+        tensors_(other.tensors_) {
+    tensors_->Ref();
+  }
+
+  TensorList(TensorList&& rhs)
+      : element_shape(std::move(rhs.element_shape)),
+        element_dtype(rhs.element_dtype),
+        max_num_elements(rhs.max_num_elements),
+        tensors_(rhs.tensors_) {
+    rhs.tensors_ = nullptr;
+  }
+
+  TensorList& operator=(const TensorList& rhs) {
+    if (this == &rhs) return *this;
+    element_shape = rhs.element_shape;
+    element_dtype = rhs.element_dtype;
+    max_num_elements = rhs.max_num_elements;
+    tensors_->Unref();
+    tensors_ = rhs.tensors_;
+    tensors_->Ref();
+    return *this;
+  }
+
+  TensorList& operator=(TensorList&& rhs) {
+    if (this == &rhs) return *this;
+    element_shape = rhs.element_shape;
+    element_dtype = rhs.element_dtype;
+    max_num_elements = rhs.max_num_elements;
+    std::swap(tensors_, rhs.tensors_);
+    return *this;
+  }
 
   static const char kTypeName[];
 
   string TypeName() const { return kTypeName; }
 
   void Encode(VariantTensorData* data) const;
@@ -56,14 +131,47 @@ struct TensorList {
   // TODO(apassos) fill this out
   string DebugString() const { return "TensorList"; }
 
-  std::vector<Tensor> tensors;
   PartialTensorShape element_shape;
 
   DataType element_dtype;
 
   // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
   // of `tensors` is unbounded.
   int max_num_elements = -1;
 
+  // Access to the underlying tensor container.
+  std::vector<Tensor>& tensors() { return tensors_->values_; }
+  const std::vector<Tensor>& tensors() const { return tensors_->values_; }
+
+  // Get a new TensorList containing a copy of the underlying tensor container.
+  TensorList Copy() const {
+    TensorList out;
+    out.element_shape = element_shape;
+    out.element_dtype = element_dtype;
+    out.max_num_elements = max_num_elements;
+    // This performs a copy of the std::vector.
+    out.tensors_->values_ = tensors_->values_;
+    return out;
+  }
+
+  // Is this TensorList the only one with a reference to the underlying
+  // container?
+  bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
+
+ private:
+  class Tensors : public core::RefCounted {
+   public:
+    std::vector<Tensor> values_;
+  };
+  Tensors* tensors_;
 };
 
+#if defined(PLATFORM_GOOGLE)
+// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
+static_assert(Variant::CanInlineType<TensorList>(),
+              "Must be able to inline TensorList into a Variant");
+#endif
+
 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
 
 Status GetElementShapeFromInput(OpKernelContext* c,
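The static_assert just above is what ties the new class layout to the commit message's performance claim: tf::Variant stores a value inline only when it fits its small-object buffer, and shrinking TensorList to a single container pointer is what brings it back under that limit. Below is a standalone sketch of that kind of size gate — the 24-byte budget and the byte-array field layouts are invented stand-ins for illustration, not TensorFlow's real numbers:

#include <cstdio>

// Illustrative only: a made-up inline budget for a variant-style container.
constexpr unsigned kInlineBufferBytes = 24;

template <typename T>
constexpr bool CanInlineType() { return sizeof(T) <= kInlineBufferBytes; }

// "Before": a std::vector<Tensor> embedded by value (begin/end/capacity
// pointers, 24 bytes on a typical 64-bit build) plus the other fields.
struct FatList  { char embedded_vector[24]; char shape_dtype_max[16]; };

// "After": a single pointer to a heap-allocated refcounted container.
struct SlimList { char container_pointer[8]; char shape_dtype_max[16]; };

static_assert(!CanInlineType<FatList>(),
              "embedding the vector pushes the list past the inline budget");
static_assert(CanInlineType<SlimList>(),
              "one indirection brings it back under the budget");

int main() {
  std::printf("FatList: %zu bytes, SlimList: %zu bytes, budget: %u\n",
              sizeof(FatList), sizeof(SlimList), kInlineBufferBytes);
  return 0;
}

An inlined variant avoids one heap allocation per stored value, which is how the added indirection inside TensorList can still come out performance-neutral for small lists.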
@@ -96,18 +204,19 @@ class TensorListStack : public OpKernel {
             "Invalid data types; op elements ", DataTypeString(element_dtype_),
             " but list elements ", DataTypeString(tensor_list->element_dtype)));
     if (num_elements_ != -1) {
-      OP_REQUIRES(c, tensor_list->tensors.size() == num_elements_,
+      OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
                   errors::InvalidArgument(
                       "Operation expected a list with ", num_elements_,
                       " elements but got a list with ",
-                      tensor_list->tensors.size(), " elements."));
+                      tensor_list->tensors().size(), " elements."));
     }
     PartialTensorShape partial_element_shape;
     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
                                                &partial_element_shape));
     OP_REQUIRES(
         c,
-        partial_element_shape.IsFullyDefined() || !tensor_list->tensors.empty(),
+        partial_element_shape.IsFullyDefined() ||
+            !tensor_list->tensors().empty(),
         errors::InvalidArgument("Tried to stack elements of an empty ",
                                 "list with non-fully-defined element_shape: ",
                                 partial_element_shape.DebugString()));
@@ -115,8 +224,8 @@ class TensorListStack : public OpKernel {
     // Check that `element_shape` input tensor is compatible with the shapes of
     // element tensors.
     if (!tensor_list->element_shape.IsFullyDefined()) {
-      for (int i = 0; i < tensor_list->tensors.size(); ++i) {
-        const Tensor& t = tensor_list->tensors[i];
+      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
+        const Tensor& t = tensor_list->tensors()[i];
         if (t.dtype() != DT_INVALID) {
           PartialTensorShape tmp = partial_element_shape;
           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
@@ -133,7 +242,7 @@ class TensorListStack : public OpKernel {
                     "tensors and has a non-fully-defined element_shape: ",
                     partial_element_shape.DebugString()));
     TensorShape output_shape = element_shape;
-    output_shape.InsertDim(0, tensor_list->tensors.size());
+    output_shape.InsertDim(0, tensor_list->tensors().size());
     Tensor* output;
     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
     if (output->NumElements() == 0) {
@@ -141,9 +250,9 @@ class TensorListStack : public OpKernel {
     }
 
     ConstMatrixVector inputs_flat;
-    inputs_flat.reserve(tensor_list->tensors.size());
+    inputs_flat.reserve(tensor_list->tensors().size());
     Tensor zeros;
-    for (const auto& t : tensor_list->tensors) {
+    for (const auto& t : tensor_list->tensors()) {
       if (t.dtype() != DT_INVALID) {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             t.shaped<T, 2>({1, t.NumElements()})));
@@ -195,12 +304,12 @@ class TensorListGetItem : public OpKernel {
                     " but list elements ",
                     DataTypeString(l->element_dtype)));
     int32 index = c->input(1).scalar<int32>()();
-    OP_REQUIRES(c, index < l->tensors.size(),
+    OP_REQUIRES(c, index < l->tensors().size(),
                 errors::InvalidArgument("Trying to access element ", index,
-                                        " in a list with ", l->tensors.size(),
+                                        " in a list with ", l->tensors().size(),
                                         " elements."));
-    if (l->tensors[index].dtype() != DT_INVALID) {
-      c->set_output(0, l->tensors[index]);
+    if (l->tensors()[index].dtype() != DT_INVALID) {
+      c->set_output(0, l->tensors()[index]);
     } else {
       PartialTensorShape partial_element_shape;
       OP_REQUIRES_OK(
@@ -216,7 +325,7 @@ class TensorListGetItem : public OpKernel {
       // In that mode TensorArray sets the array's element_shape on the first
       // write call. We could do something similar here if needed.
       if (!partial_element_shape.IsFullyDefined()) {
-        for (const Tensor& t : l->tensors) {
+        for (const Tensor& t : l->tensors()) {
           if (t.dtype() != DT_INVALID) {
             PartialTensorShape tmp = partial_element_shape;
             OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
@@ -260,10 +369,10 @@ class TensorListPopBack : public OpKernel {
                     " but list elements ",
                     DataTypeString(l->element_dtype)));
 
-    OP_REQUIRES(c, !l->tensors.empty(),
+    OP_REQUIRES(c, !l->tensors().empty(),
                 errors::InvalidArgument("Trying to pop from an empty list."));
 
-    const Tensor& t = l->tensors.back();
+    const Tensor& t = l->tensors().back();
     if (t.dtype() != DT_INVALID) {
       c->set_output(1, t);
     } else {
@@ -288,7 +397,7 @@ class TensorListPopBack : public OpKernel {
 
     TensorList* output_list = nullptr;
     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
-    output_list->tensors.pop_back();
+    output_list->tensors().pop_back();
   }
 
  private:
@@ -347,7 +456,7 @@ class TensorListConcat : public OpKernel {
     // If the TensorList is empty, element_shape_except_first_dim_ must be fully
     // defined.
     OP_REQUIRES(c,
-                !tensor_list->tensors.empty() ||
+                !tensor_list->tensors().empty() ||
                     element_shape_except_first_dim_.IsFullyDefined(),
                 errors::InvalidArgument(
                     "All except the first dimension must be fully defined ",
@@ -364,8 +473,8 @@ class TensorListConcat : public OpKernel {
     if (!tensor_list->element_shape.IsFullyDefined()) {
       bool check_dim = (first_dim == -1);
       int64 inferred_first_dim = first_dim;
-      for (int i = 0; i < tensor_list->tensors.size(); ++i) {
-        const Tensor& t = tensor_list->tensors[i];
+      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
+        const Tensor& t = tensor_list->tensors()[i];
         if (t.dtype() != DT_INVALID) {
           PartialTensorShape tmp = element_shape_except_first_dim_;
           OP_REQUIRES(
@@ -407,14 +516,14 @@ class TensorListConcat : public OpKernel {
       OP_REQUIRES_OK(
           c,
           c->allocate_output(
-              1, TensorShape({static_cast<int64>(tensor_list->tensors.size())}),
+              1, TensorShape({static_cast<int64>(tensor_list->tensors().size())}),
               &lengths_tensor));
       auto lengths_tensor_vec = lengths_tensor->vec<int64>();
       int64 leading_dim = 0;
-      for (size_t i = 0; i < tensor_list->tensors.size(); i++) {
+      for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
         int64 dim;
-        if (tensor_list->tensors[i].dtype() != DT_INVALID) {
-          dim = tensor_list->tensors[i].shape().dim_size(0);
+        if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
+          dim = tensor_list->tensors()[i].shape().dim_size(0);
         } else {
           // If leading_dims is not provided or does not contain an entry for
           // index i use the inferred `first_dim` if set.
@@ -449,12 +558,12 @@ class TensorListConcat : public OpKernel {
     }
 
     ConstMatrixVector inputs_flat;
-    inputs_flat.reserve(tensor_list->tensors.size());
+    inputs_flat.reserve(tensor_list->tensors().size());
     // Store the zeros tensors in a vector to prevent them from being GC'ed till
     // concat is complete.
     std::vector<Tensor> zeros_vec;
-    for (int i = 0; i < tensor_list->tensors.size(); i++) {
-      const Tensor& element_tensor = tensor_list->tensors[i];
+    for (int i = 0; i < tensor_list->tensors().size(); i++) {
+      const Tensor& element_tensor = tensor_list->tensors()[i];
       if (element_tensor.dtype() != DT_INVALID) {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
@@ -536,7 +645,7 @@ class TensorListSplit : public OpKernel {
                 errors::InvalidArgument(
                     "Expected lengths to be a vector, received shape: ",
                     lengths.shape().DebugString()));
-    output_list.tensors.reserve(lengths.shape().dim_size(0));
+    output_list.tensors().reserve(lengths.shape().dim_size(0));
     int64 start = 0;
     int64 end = 0;
     for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
@@ -557,7 +666,7 @@ class TensorListSplit : public OpKernel {
       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
       aligned.flat<T>().device(c->eigen_device<Device>()) =
          tmp.unaligned_flat<T>();
-      output_list.tensors.emplace_back(aligned);
+      output_list.tensors().emplace_back(aligned);
     }
     OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
                 errors::InvalidArgument(
@@ -599,7 +708,7 @@ class TensorListGather : public OpKernel {
     if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int index = 0; index < indices.NumElements(); ++index) {
         const int i = indices.flat<int32>()(index);
-        const Tensor& t = tensor_list->tensors[i];
+        const Tensor& t = tensor_list->tensors()[i];
         if (t.dtype() != DT_INVALID) {
           PartialTensorShape tmp = partial_element_shape;
           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
@@ -629,10 +738,10 @@ class TensorListGather : public OpKernel {
     for (int index = 0; index < indices.NumElements(); ++index) {
       const int i = indices.flat<int32>()(index);
       OP_REQUIRES(
-          c, i < tensor_list->tensors.size(),
+          c, i < tensor_list->tensors().size(),
           errors::InvalidArgument("Index ", i, " out o range; list only has ",
-                                  tensor_list->tensors.size(), " elements."));
-      const Tensor& t = tensor_list->tensors[i];
+                                  tensor_list->tensors().size(), " elements."));
+      const Tensor& t = tensor_list->tensors()[i];
       if (t.dtype() != DT_INVALID) {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             t.shaped<T, 2>({1, t.NumElements()})));
@@ -693,7 +802,7 @@ class TensorListFromTensor : public OpKernel {
                     "Specified a list with shape ", element_shape.DebugString(),
                     " from a tensor with shape ", output_shape.DebugString()));
     output_list.element_shape = element_shape;
-    output_list.tensors.reserve(t.shape().dim_size(0));
+    output_list.tensors().reserve(t.shape().dim_size(0));
     for (int i = 0; i < t.shape().dim_size(0); ++i) {
       Tensor tmp = t.Slice(i, i + 1);
       TensorShape tmp_shape = tmp.shape();
@@ -706,7 +815,7 @@ class TensorListFromTensor : public OpKernel {
       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
       aligned.flat<T>().device(c->eigen_device<Device>()) =
           tmp.unaligned_flat<T>();
-      output_list.tensors.push_back(aligned);
+      output_list.tensors().push_back(aligned);
     }
     output_tensor->scalar<Variant>()() = std::move(output_list);
   }
@@ -732,7 +841,7 @@ Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
     // many small ones.
     aligned.flat<T>().device(c->eigen_device<Device>()) =
         tmp.unaligned_flat<T>();
-    std::swap(list->tensors[i], aligned);
+    std::swap(list->tensors()[i], aligned);
   }
   return Status::OK();
 }
@@ -777,8 +886,8 @@ class TensorListScatterIntoExistingList : public OpKernel {
             ? -1
             : *std::max_element(indices_vec.data(),
                                 indices_vec.data() + indices.NumElements());
-    if (max_index + 1 > output_list->tensors.size()) {
-      output_list->tensors.resize(max_index + 1);
+    if (max_index + 1 > output_list->tensors().size()) {
+      output_list->tensors().resize(max_index + 1);
     }
 
     // Scatter the values.
@@ -845,8 +954,8 @@ class TensorListScatter : public OpKernel {
         highest_index = i;
       }
     }
-    output_list.tensors.resize(std::max(highest_index + 1, num_elements),
-                               Tensor(DT_INVALID));
+    output_list.tensors().resize(std::max(highest_index + 1, num_elements),
+                                 Tensor(DT_INVALID));
    }
 
     OP_REQUIRES_OK(c,
@@ -875,19 +984,19 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
 
   TF_RETURN_IF_ERROR(
       a.element_shape.MergeWith(b.element_shape, &out->element_shape));
-  if (a.tensors.size() != b.tensors.size()) {
+  if (a.tensors().size() != b.tensors().size()) {
     return errors::InvalidArgument(
         "Trying to add two lists of tensors with different lengths. One is ",
-        a.tensors.size(), " and the other is ", b.tensors.size());
+        a.tensors().size(), " and the other is ", b.tensors().size());
   }
-  out->tensors.reserve(a.tensors.size());
-  for (int i = 0; i < a.tensors.size(); ++i) {
-    const Tensor& a_tensor = a.tensors[i];
-    const Tensor& b_tensor = b.tensors[i];
+  out->tensors().reserve(a.tensors().size());
+  for (int i = 0; i < a.tensors().size(); ++i) {
+    const Tensor& a_tensor = a.tensors()[i];
+    const Tensor& b_tensor = b.tensors()[i];
     Tensor out_tensor;
     TF_RETURN_IF_ERROR(
         BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
-    out->tensors.push_back(out_tensor);
+    out->tensors().push_back(out_tensor);
   }
   return Status::OK();
 }
@@ -897,11 +1006,11 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                            TensorList* y) {
   y->element_dtype = x.element_dtype;
   y->element_shape = x.element_shape;
-  y->tensors.reserve(x.tensors.size());
-  for (const Tensor& t : x.tensors) {
+  y->tensors().reserve(x.tensors().size());
+  for (const Tensor& t : x.tensors()) {
     Tensor out_tensor;
     TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
-    y->tensors.emplace_back(out_tensor);
+    y->tensors().emplace_back(out_tensor);
   }
   return Status::OK();
 }
@@ -936,7 +1045,19 @@ class TensorListPushBackBatch : public OpKernel {
             0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
             DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
 
-    const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
+    bool ok_to_alias = tls_alias != nullptr;
+    if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
+        tls_alias->NumElements() > 0) {
+      auto alias_t = tls_alias->flat<Variant>();
+      for (int i = 0; i < tls_alias->NumElements(); ++i) {
+        TensorList* tl_i = alias_t(i).get<TensorList>();
+        if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
+          ok_to_alias = false;
+          break;
+        }
+      }
+    }
+    const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);
 
     OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
                 errors::InvalidArgument(
@@ -979,7 +1100,7 @@ class TensorListPushBackBatch : public OpKernel {
 
     Tensor* result;
 
-    if (tls_alias) {
+    if (ok_to_alias) {
       result = tls_alias.get();
       c->set_output(0, *result);
     } else {
@@ -998,8 +1119,8 @@ class TensorListPushBackBatch : public OpKernel {
     auto result_t = result->vec<Variant>();
 
     for (int64 b = 0; b < batch_size; ++b) {
-      if (!tls_alias) {
-        result_t(b) = *tl_batch[b];
+      if (!ok_to_alias) {
+        result_t(b) = tl_batch[b]->Copy();
       }
       TensorList* output = result_t(b).get<TensorList>();
       DCHECK(output != nullptr);
@@ -1011,7 +1132,7 @@ class TensorListPushBackBatch : public OpKernel {
         auto frame_t = frame->flat<T>();
         frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
       }
-      output->tensors.push_back(std::move(*frame));
+      output->tensors().push_back(std::move(*frame));
     }
   }
 
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -53,7 +53,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         max_num_elements=max_num_elements)
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(self.evaluate(e), 1.0)
+    l = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
+    l, e = self.evaluate((l, e))
+    self.assertAllEqual(l, [])
+    self.assertAllEqual(e, 1.0)
 
   @parameterized.named_parameters(("NoMaxNumElements", None),
                                   ("WithMaxNumElements", 2))
@@ -94,7 +97,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     l = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)
     _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+    l = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
+    l, e = self.evaluate((l, e))
     self.assertAllEqual(e, np.zeros((2, 3)))
+    self.assertAllEqual(l, np.zeros((3, 2, 3)))
 
   def testPopUninitializedTensorUseSpecifiedElementShape(self):
     l = list_ops.tensor_list_reserve(
@@ -954,14 +960,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     l_concat_11 = list_ops.tensor_list_concat_lists(
         l_batch_1, l_batch_1, element_dtype=dtypes.float32)
 
+    expected_0 = [[1.0, 2.0], [-1.0]]
+    expected_1 = [[-1.0], [1.0, 2.0]]
     expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
     expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
     expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
     expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]
 
     for i, (concat, expected) in enumerate(zip(
-        [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
-        [expected_00, expected_01, expected_10, expected_11])):
+        [l_batch_0, l_batch_1,
+         l_concat_00, l_concat_01, l_concat_10, l_concat_11],
+        [expected_0, expected_1,
+         expected_00, expected_01, expected_10, expected_11])):
       splitted = array_ops.unstack(concat)
       splitted_stacked_ret = self.evaluate(
           (list_ops.tensor_list_stack(splitted[0], dtypes.float32),