[tf.data] Removing an incorrect assumption from the input pipeline optimization logic.
PiperOrigin-RevId: 203546845
This commit is contained in:
parent
af3a5c2230
commit
1caaea99e0
@ -1,5 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "IdentityDataset"
|
||||
graph_op_name: "SinkDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
@ -562,15 +562,6 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "identity_dataset_op",
|
||||
srcs = ["identity_dataset_op.cc"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "optimize_dataset_op",
|
||||
srcs = ["optimize_dataset_op.cc"],
|
||||
@ -619,7 +610,6 @@ tf_kernel_library(
|
||||
":generator_dataset_op",
|
||||
":group_by_reducer_dataset_op",
|
||||
":group_by_window_dataset_op",
|
||||
":identity_dataset_op",
|
||||
":interleave_dataset_op",
|
||||
":iterator_ops",
|
||||
":map_and_batch_dataset_op",
|
||||
|
@ -1,102 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// The purpose of identity dataset is to serve as a placeholder when performing
|
||||
// optimizations. It is not expected to be surfaced in the Python API.
|
||||
class IdentityDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit IdentityDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
*output = new Dataset(ctx, input);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input)
|
||||
: GraphDatasetBase(ctx), input_(input) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::Identity")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return input_->output_dtypes();
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return input_->output_shapes();
|
||||
}
|
||||
|
||||
string DebugString() const override { return "IdentityDatasetOp::Dataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
return errors::Unimplemented(
|
||||
strings::StrCat(prefix(), "::GetNextInternal"));
|
||||
}
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
};
|
||||
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("IdentityDataset").Device(DEVICE_CPU),
|
||||
IdentityDatasetOp);
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -136,16 +136,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Add a fake sink node to allow rewriting the actual sink node.
|
||||
NodeDef* node = graph_def->mutable_node()->Add();
|
||||
node->set_name("FakeSink");
|
||||
node->set_op("IdentityDataset");
|
||||
node->set_op("SinkDataset");
|
||||
node->add_input(*output_node);
|
||||
{
|
||||
grappler::GraphView graph(graph_def);
|
||||
NodeDef* sink = graph.GetNode(*output_node);
|
||||
(*node->mutable_attr())["output_shapes"] =
|
||||
sink->attr().at("output_shapes");
|
||||
(*node->mutable_attr())["output_types"] =
|
||||
sink->attr().at("output_types");
|
||||
}
|
||||
|
||||
// Create metagraph.
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
@ -26275,29 +26275,6 @@ op {
|
||||
type: "type"
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "IdentityDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "IdentityN"
|
||||
input_arg {
|
||||
|
@ -798,11 +798,9 @@ REGISTER_OP("DatasetToGraph")
|
||||
.Output("graph: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("IdentityDataset")
|
||||
REGISTER_OP("SinkDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("OptimizeDataset")
|
||||
|
Loading…
x
Reference in New Issue
Block a user