Refactor DirectedInterleaveDatasetOp

This commit is contained in:
feihugis 2020-03-16 14:31:44 -05:00
parent 34af8d45d0
commit 9b6fd77bd6
3 changed files with 311 additions and 235 deletions

View File

@ -135,14 +135,31 @@ tf_kernel_library(
tf_kernel_library(
name = "directed_interleave_dataset_op",
srcs = ["directed_interleave_dataset_op.cc"],
hdrs = ["directed_interleave_dataset_op.h"],
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels/data:name_utils",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "directed_interleave_dataset_op_test",
size = "small",
srcs = ["directed_interleave_dataset_op_test.cc"],
deps = [
":directed_interleave_dataset_op",
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels/data:dataset_test_base",
"//tensorflow/core/kernels/data:range_dataset_op",
"//tensorflow/core/kernels/data:tensor_slice_dataset_op",
],
)
tf_kernel_library(
name = "group_by_reducer_dataset_op",
srcs = ["group_by_reducer_dataset_op.cc"],

View File

@ -12,284 +12,296 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kSelectorInputDataset;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kDataInputDatasets;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kNumDatasets;
class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
public:
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
std::vector<DatasetBase*> data_inputs)
: DatasetBase(DatasetContext(ctx)),
selector_input_(selector_input),
data_inputs_(std::move(data_inputs)) {
selector_input_->Ref();
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
DatasetBase* selector_input;
OP_REQUIRES_OK(ctx,
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
OP_REQUIRES(
ctx,
selector_input->output_dtypes().size() == 1 &&
selector_input->output_dtypes()[0] == DT_INT64 &&
selector_input->output_shapes().size() == 1 &&
selector_input->output_shapes()[0].IsCompatibleWith(
PartialTensorShape({})),
errors::InvalidArgument(
"The selector input must be a dataset of scalar int64 elements."));
std::vector<DatasetBase*> data_inputs;
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
DatasetBase* input;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
data_inputs.push_back(input);
OP_REQUIRES(
ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
errors::InvalidArgument(
"All inputs must have the same output_dtypes. First input "
"has types ",
DataTypeVectorString(data_inputs[0]->output_dtypes()),
", and input ", i - 1, " has types ",
DataTypeVectorString(input->output_dtypes())));
output_shapes_ = data_inputs_[0]->output_shapes();
data_inputs_[0]->Ref();
for (size_t i = 1; i < data_inputs_.size(); ++i) {
const DatasetBase* data_input = data_inputs_[i];
data_input->Ref();
for (size_t j = 0; j < output_shapes_.size(); ++j) {
output_shapes_[j] = MostSpecificCompatibleShape(
output_shapes_[j], data_input->output_shapes()[j]);
}
}
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
}
~Dataset() override {
selector_input_->Unref();
for (DatasetBase* data_input : data_inputs_) {
data_input->Unref();
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
const DataTypeVector& output_dtypes() const override {
return data_inputs_[0]->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
}
Status CheckExternalState() const override {
for (const auto& input : data_inputs_) {
TF_RETURN_IF_ERROR(input->CheckExternalState());
}
return selector_input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* selector_input_node;
TF_RETURN_IF_ERROR(
b->AddInputDataset(ctx, selector_input_, &selector_input_node));
std::vector<Node*> data_input_nodes(data_inputs_.size());
for (size_t i = 0; i < data_inputs_.size(); ++i) {
TF_RETURN_IF_ERROR(
b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
}
TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
{{1, data_input_nodes}}, {}, output));
return Status::OK();
}
private:
class Dataset : public DatasetBase {
class Iterator : public DatasetIterator<Dataset> {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
std::vector<DatasetBase*> data_inputs)
: DatasetBase(DatasetContext(ctx)),
selector_input_(selector_input),
data_inputs_(std::move(data_inputs)) {
selector_input_->Ref();
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
num_active_inputs_(params.dataset->data_inputs_.size()) {}
output_shapes_ = data_inputs_[0]->output_shapes();
data_inputs_[0]->Ref();
for (size_t i = 1; i < data_inputs_.size(); ++i) {
const DatasetBase* data_input = data_inputs_[i];
data_input->Ref();
for (size_t j = 0; j < output_shapes_.size(); ++j) {
output_shapes_[j] = MostSpecificCompatibleShape(
output_shapes_[j], data_input->output_shapes()[j]);
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
ctx, this, prefix(), &selector_input_impl_));
data_input_impls_.resize(dataset()->data_inputs_.size());
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const DatasetBase* data_input = dataset()->data_inputs_[i];
TF_RETURN_IF_ERROR(data_input->MakeIterator(
ctx, this, strings::StrCat(prefix(), "[", i, "]"),
&data_input_impls_[i]));
}
}
~Dataset() override {
selector_input_->Unref();
for (DatasetBase* data_input : data_inputs_) {
data_input->Unref();
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, strings::StrCat(prefix, "::DirectedInterleave")});
}
const DataTypeVector& output_dtypes() const override {
return data_inputs_[0]->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
}
Status CheckExternalState() const override {
for (const auto& input : data_inputs_) {
TF_RETURN_IF_ERROR(input->CheckExternalState());
}
return selector_input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* selector_input_node;
TF_RETURN_IF_ERROR(
b->AddInputDataset(ctx, selector_input_, &selector_input_node));
std::vector<Node*> data_input_nodes(data_inputs_.size());
for (size_t i = 0; i < data_inputs_.size(); ++i) {
TF_RETURN_IF_ERROR(
b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
}
TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
{{1, data_input_nodes}}, {}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
num_active_inputs_(params.dataset->data_inputs_.size()) {}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
ctx, this, strings::StrCat(prefix()), &selector_input_impl_));
data_input_impls_.resize(dataset()->data_inputs_.size());
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const DatasetBase* data_input = dataset()->data_inputs_[i];
TF_RETURN_IF_ERROR(data_input->MakeIterator(
ctx, this, strings::StrCat(prefix(), "[", i, "]"),
&data_input_impls_[i]));
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (!selector_input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (!selector_input_impl_) {
*end_of_sequence = true;
while (true) {
std::vector<Tensor> selector_result;
*end_of_sequence = false;
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(ctx, &selector_result,
end_of_sequence));
if (*end_of_sequence) {
selector_input_impl_.reset();
for (auto& data_input_impl : data_input_impls_) {
data_input_impl.reset();
}
return Status::OK();
}
while (true) {
std::vector<Tensor> selector_result;
*end_of_sequence = false;
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
ctx, &selector_result, end_of_sequence));
if (*end_of_sequence) {
selector_input_impl_.reset();
for (auto& data_input_impl : data_input_impls_) {
data_input_impl.reset();
}
int64 selected_input = selector_result[0].scalar<int64>()();
if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
return errors::InvalidArgument(
"Selector index out of range: ", selected_input,
" >= ", data_input_impls_.size());
}
if (data_input_impls_[selected_input]) {
bool end_of_selected_input = false;
TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
ctx, out_tensors, &end_of_selected_input));
if (!end_of_selected_input) {
return Status::OK();
}
int64 selected_input = selector_result[0].scalar<int64>()();
if (selected_input < 0 ||
selected_input >= data_input_impls_.size()) {
return errors::InvalidArgument(
"Selector index out of range: ", selected_input,
" >= ", data_input_impls_.size());
}
data_input_impls_[selected_input].reset();
--num_active_inputs_;
if (data_input_impls_[selected_input]) {
bool end_of_selected_input = false;
TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
ctx, out_tensors, &end_of_selected_input));
if (!end_of_selected_input) {
return Status::OK();
}
data_input_impls_[selected_input].reset();
--num_active_inputs_;
if (num_active_inputs_ == 0) {
selector_input_impl_.reset();
*end_of_sequence = true;
return Status::OK();
}
}
VLOG(2) << "DirectedInterleave selected an exhausted input: "
<< selected_input;
}
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeInterleaveManyNode(std::move(args));
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (selector_input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const auto& data_input_impl = data_input_impls_[i];
if (data_input_impl) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
} else {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
""));
if (num_active_inputs_ == 0) {
selector_input_impl_.reset();
*end_of_sequence = true;
return Status::OK();
}
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("selector_input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
} else {
selector_input_impl_.reset();
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
if (!reader->Contains(full_name(
strings::StrCat("data_input_impl_empty[", i, "]")))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
} else {
data_input_impls_[i].reset();
}
}
return Status::OK();
VLOG(2) << "DirectedInterleave selected an exhausted input: "
<< selected_input;
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
TF_GUARDED_BY(mu_);
int64 num_active_inputs_ TF_GUARDED_BY(mu_);
};
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
return output_tensorshape;
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
output_tensorshape.Concatenate(dims1[d]);
else
output_tensorshape.Concatenate(-1);
}
return output_tensorshape;
}
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
};
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeInterleaveManyNode(std::move(args));
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (selector_input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
} else {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")), ""));
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const auto& data_input_impl = data_input_impls_[i];
if (data_input_impl) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
} else {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
""));
}
}
return Status::OK();
}
return Status::OK();
}
Status
RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("selector_input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
} else {
selector_input_impl_.reset();
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
if (!reader->Contains(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
} else {
data_input_impls_[i].reset();
}
}
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
TF_GUARDED_BY(mu_);
int64 num_active_inputs_ TF_GUARDED_BY(mu_);
};
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
return output_tensorshape;
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); ++d) {
if (dims1[d] == dims2[d])
output_tensorshape.Concatenate(dims1[d]);
else
output_tensorshape.Concatenate(-1);
}
return output_tensorshape;
}
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
}; // namespace experimental
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
DatasetBase* selector_input;
OP_REQUIRES_OK(ctx,
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
OP_REQUIRES(
ctx,
selector_input->output_dtypes().size() == 1 &&
selector_input->output_dtypes()[0] == DT_INT64 &&
selector_input->output_shapes().size() == 1 &&
selector_input->output_shapes()[0].IsCompatibleWith(
PartialTensorShape({})),
errors::InvalidArgument(
"The selector input must be a dataset of scalar int64 elements."));
std::vector<DatasetBase*> data_inputs;
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
DatasetBase* input;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
data_inputs.push_back(input);
OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
errors::InvalidArgument(
"All inputs must have the same output_dtypes. First input "
"has types ",
DataTypeVectorString(data_inputs[0]->output_dtypes()),
", and input ", i - 1, " has types ",
DataTypeVectorString(input->output_dtypes())));
}
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
}
namespace {
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow
} // namespace tensorflow

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
namespace experimental {
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "DirectedInterleave";
static constexpr const char* const kSelectorInputDataset =
"selector_input_dataset";
static constexpr const char* const kDataInputDatasets = "data_input_datasets";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kNumDatasets = "N";
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
private:
class Dataset;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_