Refactor DirectedInterleaveDatasetOp

feihugis 2020-03-16 14:31:44 -05:00
parent 34af8d45d0
commit 9b6fd77bd6
3 changed files with 311 additions and 235 deletions

tensorflow/core/kernels/data/experimental/BUILD

@@ -135,14 +135,31 @@ tf_kernel_library(
tf_kernel_library(
name = "directed_interleave_dataset_op",
srcs = ["directed_interleave_dataset_op.cc"],
hdrs = ["directed_interleave_dataset_op.h"],
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels/data:name_utils",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "directed_interleave_dataset_op_test",
size = "small",
srcs = ["directed_interleave_dataset_op_test.cc"],
deps = [
":directed_interleave_dataset_op",
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels/data:dataset_test_base",
"//tensorflow/core/kernels/data:range_dataset_op",
"//tensorflow/core/kernels/data:tensor_slice_dataset_op",
],
)
tf_kernel_library(
name = "group_by_reducer_dataset_op",
srcs = ["group_by_reducer_dataset_op.cc"],

tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc

@@ -12,56 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
public:
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kSelectorInputDataset;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kDataInputDatasets;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kNumDatasets;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
DatasetBase* selector_input;
OP_REQUIRES_OK(ctx,
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
OP_REQUIRES(
ctx,
selector_input->output_dtypes().size() == 1 &&
selector_input->output_dtypes()[0] == DT_INT64 &&
selector_input->output_shapes().size() == 1 &&
selector_input->output_shapes()[0].IsCompatibleWith(
PartialTensorShape({})),
errors::InvalidArgument(
"The selector input must be a dataset of scalar int64 elements."));
std::vector<DatasetBase*> data_inputs;
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
DatasetBase* input;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
data_inputs.push_back(input);
OP_REQUIRES(
ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
errors::InvalidArgument(
"All inputs must have the same output_dtypes. First input "
"has types ",
DataTypeVectorString(data_inputs[0]->output_dtypes()),
", and input ", i - 1, " has types ",
DataTypeVectorString(input->output_dtypes())));
}
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
}
private:
class Dataset : public DatasetBase {
class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
std::vector<DatasetBase*> data_inputs)
@@ -92,7 +67,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, strings::StrCat(prefix, "::DirectedInterleave")});
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
const DataTypeVector& output_dtypes() const override {
@@ -104,7 +79,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
}
string DebugString() const override {
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
return name_utils::DatasetDebugString(kDatasetType);
}
Status CheckExternalState() const override {
@@ -141,7 +116,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
ctx, this, strings::StrCat(prefix()), &selector_input_impl_));
ctx, this, prefix(), &selector_input_impl_));
data_input_impls_.resize(dataset()->data_inputs_.size());
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const DatasetBase* data_input = dataset()->data_inputs_[i];
@@ -164,8 +139,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
while (true) {
std::vector<Tensor> selector_result;
*end_of_sequence = false;
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
ctx, &selector_result, end_of_sequence));
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(ctx, &selector_result,
end_of_sequence));
if (*end_of_sequence) {
selector_input_impl_.reset();
for (auto& data_input_impl : data_input_impls_) {
@@ -175,8 +150,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
}
int64 selected_input = selector_result[0].scalar<int64>()();
if (selected_input < 0 ||
selected_input >= data_input_impls_.size()) {
if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
return errors::InvalidArgument(
"Selector index out of range: ", selected_input,
" >= ", data_input_impls_.size());
@@ -218,8 +192,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
if (selector_input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")), ""));
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const auto& data_input_impl = data_input_impls_[i];
@@ -233,9 +207,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
}
return Status::OK();
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
Status
RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("selector_input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
@@ -243,8 +219,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
selector_input_impl_.reset();
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
if (!reader->Contains(full_name(
strings::StrCat("data_input_impl_empty[", i, "]")))) {
if (!reader->Contains(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
} else {
data_input_impls_[i].reset();
@@ -259,37 +235,73 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
TF_GUARDED_BY(mu_);
int64 num_active_inputs_ TF_GUARDED_BY(mu_);
};
};
static PartialTensorShape MostSpecificCompatibleShape(
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
return output_tensorshape;
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
for (int d = 0; d < ts1.dims(); ++d) {
if (dims1[d] == dims2[d])
output_tensorshape.Concatenate(dims1[d]);
else
output_tensorshape.Concatenate(-1);
}
return output_tensorshape;
}
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
}; // namespace experimental
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
DatasetBase* selector_input;
OP_REQUIRES_OK(ctx,
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
OP_REQUIRES(
ctx,
selector_input->output_dtypes().size() == 1 &&
selector_input->output_dtypes()[0] == DT_INT64 &&
selector_input->output_shapes().size() == 1 &&
selector_input->output_shapes()[0].IsCompatibleWith(
PartialTensorShape({})),
errors::InvalidArgument(
"The selector input must be a dataset of scalar int64 elements."));
std::vector<DatasetBase*> data_inputs;
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
DatasetBase* input;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
data_inputs.push_back(input);
OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
errors::InvalidArgument(
"All inputs must have the same output_dtypes. First input "
"has types ",
DataTypeVectorString(data_inputs[0]->output_dtypes()),
", and input ", i - 1, " has types ",
DataTypeVectorString(input->output_dtypes())));
}
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
}
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
};
};
namespace {
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow
} // namespace tensorflow
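
The iterator logic itself is untouched by this refactor: GetNextInternal reads one scalar index per element from the selector input, rejects out-of-range indices ("Selector index out of range"), and otherwise forwards the call to the selected data input until the selector ends or every data input is drained. Below is a minimal standalone model of that control flow, not TensorFlow code: std::queue stands in for the per-input iterators, and the skip-already-exhausted-input behavior (tracked by num_active_inputs_ in the real iterator) is an assumption about the op rather than something fully visible in these hunks.

    // Illustrative model of the directed-interleave control flow; not TF code.
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <queue>
    #include <string>
    #include <vector>

    // Returns the interleaved elements, or std::nullopt for an out-of-range
    // selector index (the op reports errors::InvalidArgument in that case).
    std::optional<std::vector<std::string>> DirectedInterleave(
        const std::vector<int64_t>& selector,
        std::vector<std::queue<std::string>> inputs) {
      std::vector<std::string> out;
      size_t active = inputs.size();  // plays the role of num_active_inputs_
      for (int64_t selected : selector) {
        if (selected < 0 || selected >= static_cast<int64_t>(inputs.size())) {
          std::cerr << "Selector index out of range: " << selected << "\n";
          return std::nullopt;
        }
        auto& input = inputs[selected];
        if (input.empty()) continue;  // selected input already exhausted: skip it
        out.push_back(input.front());
        input.pop();
        // Only decrement when this input just ran dry; stop once none are active.
        if (input.empty() && --active == 0) break;
      }
      return out;  // selector exhausted (or all inputs drained): end of sequence
    }

    int main() {
      std::vector<std::queue<std::string>> inputs(2);
      for (const char* s : {"a0", "a1"}) inputs[0].push(s);
      for (const char* s : {"b0", "b1", "b2"}) inputs[1].push(s);
      if (auto result = DirectedInterleave({0, 1, 1, 0, 1}, std::move(inputs))) {
        for (const auto& s : *result) std::cout << s << ' ';  // a0 b0 b1 a1 b2
        std::cout << '\n';
      }
    }

In the Python API, tf.data.experimental.sample_from_datasets builds on this kernel by feeding it a dataset of randomly drawn indices as the selector input.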

tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h

@@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
namespace experimental {
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "DirectedInterleave";
static constexpr const char* const kSelectorInputDataset =
"selector_input_dataset";
static constexpr const char* const kDataInputDatasets = "data_input_datasets";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kNumDatasets = "N";
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
private:
class Dataset;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
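
The Dataset in the .cc above keeps a MostSpecificCompatibleShape helper for combining the data inputs' output shapes: shapes of differing or unknown rank collapse to a fully unknown shape, matching dimensions are kept, and mismatched dimensions become unknown. Here is a standalone sketch of that merging rule, assuming a plain representation (std::nullopt for unknown rank, -1 for an unknown dimension) rather than PartialTensorShape.

    // Illustrative model of MostSpecificCompatibleShape; not the TF API.
    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    // A shape is a list of dimension sizes (-1 = unknown), or std::nullopt for
    // unknown rank.
    using Shape = std::optional<std::vector<int64_t>>;

    Shape MostSpecificCompatibleShape(const Shape& a, const Shape& b) {
      // Different or unknown rank: the only compatible shape is fully unknown.
      if (!a || !b || a->size() != b->size()) return std::nullopt;
      std::vector<int64_t> merged(a->size());
      for (size_t d = 0; d < a->size(); ++d) {
        merged[d] = ((*a)[d] == (*b)[d]) ? (*a)[d] : -1;  // keep matches only
      }
      return merged;
    }

    int main() {
      const Shape a = std::vector<int64_t>{4, 10};
      const Shape b = std::vector<int64_t>{4, 20};
      const Shape merged = MostSpecificCompatibleShape(a, b);
      assert(merged && (*merged)[0] == 4 && (*merged)[1] == -1);  // {4, -1}
      assert(!MostSpecificCompatibleShape(a, std::nullopt));      // unknown rank
      return 0;
    }

Note the asymmetry visible in MakeDataset above: output_dtypes must be identical across data inputs, while output shapes only need to be compatible, with this helper keeping the combined shape as specific as the inputs allow.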