Adding a RebatchDatasetOp that uses Grappler to modify the dataset graph to mutate the batch size of the Dataset.

Right now this basically does all the wiring through Grappler etc. and the logic is very minimal and basic. Subsequent CL's will make the logic expand to different scenarios.

PiperOrigin-RevId: 231983575
This commit is contained in:
Rohan Jain 2019-02-01 09:25:37 -08:00 committed by TensorFlower Gardener
parent 90a8b12b28
commit de87e628e6
21 changed files with 874 additions and 264 deletions

View File

@ -0,0 +1,23 @@
op {
graph_op_name: "ExperimentalRebatchDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the input dataset.
END
}
in_arg {
name: "num_workers"
description: <<END
A scalar representing the number of workers to distribute this batch across. As
a result of this transformation the current batch size would end up being
divided by this parameter.
END
}
summary: "Creates a dataset that changes the batch size."
description: <<END
Creates a dataset that changes the batch size of the dataset to current batch
size // num_workers.
END
}

View File

@ -6,6 +6,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
package(default_visibility = [
"//tensorflow/core/grappler/optimizers/data:__subpackages__",
"//tensorflow/core/kernels/data:__pkg__",
"//tensorflow/core/kernels/data/experimental:__pkg__",
])
cc_library(
@ -540,6 +541,20 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "rebatch",
srcs = ["rebatch.cc"],
hdrs = ["rebatch.h"],
deps = [
":graph_utils",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
],
alwayslink = 1,
)
cc_library(
name = "noop_elimination",
srcs = ["noop_elimination.cc"],

View File

@ -232,6 +232,13 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
return graph.GetRegularFanin(input_port).node;
}
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
int64 i) {
if (node.input_size() <= i) return nullptr;
MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), i);
return graph.GetRegularFanin(input_port).node;
}
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
NodeDef* node) {
string name = string(prefix);

View File

@ -108,6 +108,10 @@ int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
// Gets the 0th input to a node in the graph.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
// Gets the ith input to a node in the graph.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
int64 i);
// Returns the list of indices of all nodes with the given op or empty list if
// no such node exists.
std::vector<int> FindAllGraphNodesWithOp(const string& op,

View File

@ -228,6 +228,21 @@ TEST(GraphUtilsTest, GetInputNode) {
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
}
TEST(GraphUtilsTest, GetIthInputNode) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
NodeDef* node2 = AddNode("", "A", {}, {}, &graph);
NodeDef* node3 = AddNode("", "A", {node1->name(), node2->name()}, {}, &graph);
EXPECT_EQ(GetInputNode(*node3, graph), node1);
EXPECT_EQ(GetInputNode(*node3, graph, 1), node2);
EXPECT_EQ(GetInputNode(*node3, graph, 0), node1);
EXPECT_EQ(GetInputNode(*node3, graph, 2), nullptr);
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
}
TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
Graph g(OpRegistry::Global());

View File

@ -0,0 +1,115 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/rebatch.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
namespace tensorflow {
namespace grappler {
Status RebatchOptimizer::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
if (!config) return Status::OK();
num_workers_ = config->parameter_map().at("num_workers").i();
return Status::OK();
}
namespace {
constexpr char kCastOp[] = "Cast";
constexpr char kRealDivOp[] = "RealDiv";
constexpr char kBatchDatasetOp[] = "BatchDatasetV2";
NodeDef* AddCastNode(const string& input, DataType src_t, DataType dst_t,
MutableGraphView* graph) {
NodeDef cast_node;
cast_node.set_op(kCastOp);
cast_node.add_input(input);
graph_utils::SetUniqueGraphNodeName(cast_node.op(), graph->graph(),
&cast_node);
AddNodeAttr("SrcT", src_t, &cast_node);
AddNodeAttr("DstT", dst_t, &cast_node);
return graph->AddNode(std::move(cast_node));
}
NodeDef* AddBinaryNode(const string& input_x, const string& input_y,
const string& op, DataType type,
MutableGraphView* graph) {
NodeDef node;
node.set_op(op);
node.add_input(input_x);
node.add_input(input_y);
graph_utils::SetUniqueGraphNodeName(op, graph->graph(), &node);
AddNodeAttr("T", type, &node);
return graph->AddNode(std::move(node));
}
NodeDef* AddFloatDivNode(const string& input_x, const string& input_y,
MutableGraphView* graph) {
return AddBinaryNode(input_x, input_y, kRealDivOp, DT_FLOAT, graph);
}
} // anonymous namespace
Status RebatchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
*output = item.graph;
MutableGraphView graph(output);
absl::flat_hash_set<string> nodes_to_delete;
for (const NodeDef& node : item.graph.node()) {
if (node.op() == kBatchDatasetOp) {
NodeDef* batch_size_node = graph_utils::GetInputNode(node, graph, 1);
NodeDef tmp_node;
tmp_node = *batch_size_node;
graph_utils::SetUniqueGraphNodeName(tmp_node.op(), graph.graph(),
&tmp_node);
NodeDef* copy_batch_size_node = graph.AddNode(std::move(tmp_node));
NodeDef* float_copy_batch_size_node =
AddCastNode(copy_batch_size_node->name(), DT_INT64, DT_FLOAT, &graph);
NodeDef* num_worker_node =
graph_utils::AddScalarConstNode<int64>(num_workers_, &graph);
NodeDef* float_num_worker_node =
AddCastNode(num_worker_node->name(), DT_INT64, DT_FLOAT, &graph);
NodeDef* divided_batch_size_node =
AddFloatDivNode(float_copy_batch_size_node->name(),
float_num_worker_node->name(), &graph);
NodeDef* cast_new_batch_size_node = AddCastNode(
divided_batch_size_node->name(), DT_FLOAT, DT_INT64, &graph);
TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_size_node->name(),
cast_new_batch_size_node->name()));
nodes_to_delete.insert(batch_size_node->name());
break;
}
}
TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
return Status::OK();
}
void RebatchOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output,
double result) {}
REGISTER_GRAPH_OPTIMIZER_AS(RebatchOptimizer, "tf_data_rebatcher");
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
namespace grappler {
// This optimizer changes the batch size of the output dataset by dividing the
// current batch size by parameter `num_workers`. Currently, this works only
// for very simple pipelines with a single BatchDatasetV2 transformation.
//
// TODO(rohanj): Extend this logic to correctly handle any input pipeline that
// uses core tf.data APIs + MapAndBatch.
class RebatchOptimizer : public CustomGraphOptimizer {
public:
RebatchOptimizer() = default;
~RebatchOptimizer() override = default;
string name() const override { return "tf_data_rebatcher"; }
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
private:
int64 num_workers_;
};
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_

View File

@ -548,17 +548,14 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "optimize_dataset_op",
srcs = ["optimize_dataset_op.cc"],
cc_library(
name = "graph_rewrite_dataset",
srcs = ["graph_rewrite_dataset.cc"],
hdrs = ["graph_rewrite_dataset.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:virtual_cluster",
@ -569,6 +566,19 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "optimize_dataset_op",
srcs = ["optimize_dataset_op.cc"],
deps = [
":graph_rewrite_dataset",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
tf_kernel_library(
name = "model_dataset_op",
srcs = ["model_dataset_op.cc"],

View File

@ -232,6 +232,21 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "rebatch_dataset_op",
srcs = ["rebatch_dataset_op.cc"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/optimizers/data:rebatch",
"//tensorflow/core/kernels/data:graph_rewrite_dataset",
],
)
tf_kernel_library(
name = "scan_dataset_op",
srcs = ["scan_dataset_op.cc"],
@ -391,6 +406,7 @@ tf_kernel_library(
":parse_example_dataset_op",
":prefetching_kernels",
":random_dataset_op",
":rebatch_dataset_op",
":scan_dataset_op",
":set_stats_aggregator_dataset_op",
":sleep_dataset_op",

View File

@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kOptimizerName[] = "tf_data_rebatcher";
class RebatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit RebatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 num_workers;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
OP_REQUIRES(ctx, num_workers > 0,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
Dataset* dataset =
new Dataset(ctx, input, num_workers, output_types_, output_shapes_);
Status s = dataset->Optimize(ctx);
if (s.ok()) {
*output = dataset;
} else {
dataset->Unref();
OP_REQUIRES_OK(ctx, s);
}
}
private:
class Dataset : public GraphRewriteDataset {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const int64 num_workers, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: GraphRewriteDataset(ctx, input, output_types, output_shapes),
num_workers_(num_workers) {}
string DebugString() const override { return "RebatchDatasetOp::Dataset"; }
private:
RewriterConfig CreateGrapplerRewriteConfig() override {
RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
AttrValue num_workers_attr;
num_workers_attr.set_i(num_workers_);
(*custom_optimizer->mutable_parameter_map())["num_workers"] =
num_workers_attr;
return rewriter_config;
}
const int64 num_workers_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
RebatchDatasetOp);
} // anonymous namespace
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,239 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace data {
GraphRewriteDataset::~GraphRewriteDataset() {
input_->Unref();
if (optimized_input_) {
optimized_input_->Unref();
}
}
Status GraphRewriteDataset::Optimize(OpKernelContext* ctx) {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
params.input_list = &input_list;
params.optimization_only = true;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, input_, &input_node));
string output_node = input_node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
VLOG(3) << "Before optimization: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
VLOG(3) << "After optimization: " << graph_def.DebugString();
// Instantiate the optimized input pipeline by running the optimized graph
// using the optimized function library.
TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
// Create a FunctionHandleCache.
function_handle_cache_ = absl::make_unique<FunctionHandleCache>(lib_);
// Some functions may have been modified without having their names
// changed (for example, nested dataset graphs from FlatMap or
// Interleave). To avoid name conflicts, we remove these functions from
// flib_def_ before adding the optimized function library.
for (const FunctionDef& fd : graph_def.library().function()) {
if (flib_def_->Find(fd.signature().name()) != nullptr) {
TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name()));
}
}
TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(
graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
optimized_input_->Ref();
return Status::OK();
}
Status GraphRewriteDataset::AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const {
// We only serialize the optimized dataset to avoid re-running
// optimizations when the input pipeline is restored from a checkpoint.
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output));
return Status::OK();
}
namespace {
void AddFakeSinks(FunctionDef* function_def) {
int counter = 0;
for (const auto& output : function_def->signature().output_arg()) {
NodeDef* node = function_def->add_node_def();
tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
strings::StrCat("FakeSink", counter++), function_def, node);
node->set_op("Identity");
node->add_input(function_def->ret().at(output.name()));
(*node->mutable_attr())["T"].set_type(output.type());
(*function_def->mutable_ret())[output.name()] =
strings::StrCat(node->name(), ":output:0");
}
}
void RemoveFakeSinks(FunctionDef* function_def) {
// Map from identity node names to their input tensor strings
std::map<string, string> identity_map;
for (const auto& node : function_def->node_def()) {
if (node.op() == "Identity" && node.input_size() == 1) {
identity_map[node.name()] = node.input(0);
}
}
for (const auto& output_arg : function_def->signature().output_arg()) {
const string& tensor = function_def->ret().at(output_arg.name());
const string& output_node = tensor.substr(0, tensor.find(':'));
if (identity_map.find(output_node) != identity_map.end()) {
(*function_def->mutable_ret())[output_arg.name()] =
identity_map.at(output_node);
}
}
}
} // anonymous namespace
Status GraphRewriteDataset::ApplyOptimizations(OpKernelContext* ctx,
GraphDef* graph_def,
string* output_node) {
// Add an identity node as the fetch node, otherwise we might get
// 'placeholder is both fed and fetched' errors in some cases when using
// input list with placeholder dataset nodes.
NodeDef* node = graph_def->mutable_node()->Add();
tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
node);
node->set_op("Identity");
node->add_input(*output_node);
(*node->mutable_attr())["T"].set_type(DT_VARIANT);
*output_node = node->name();
// Add fake sink node to graph and functions to allow rewriting the actual
// sink nodes.
// TODO(b/118820916): When MetaOptimizer adds provisions for function
// retvals to be optimizable, we will no longer need this.
for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
AddFakeSinks(&function_def);
}
// Create metagraph.
MetaGraphDef meta_graph_def;
(*meta_graph_def.mutable_graph_def()) = *graph_def;
// Grappler determines fetch ops from collection 'train_op'.
CollectionDef collection_def;
auto node_list = collection_def.mutable_node_list();
node_list->add_value(*output_node);
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
// Create Grappler item.
tensorflow::grappler::ItemConfig item_config;
item_config.apply_optimizations = true;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
"graph", meta_graph_def, item_config);
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run data optimizer using grappler's meta optimizer.
tensorflow::ConfigProto config;
*config.mutable_graph_options()->mutable_rewrite_options() =
CreateGrapplerRewriteConfig();
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
*grappler_item, config, ctx->device(), &cluster, graph_def));
// Remove fake sinks after optimizations are done.
// TODO(b/118820916): When MetaOptimizer adds provisions for function
// retvals to be optimizable, we will no longer need this.
for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
RemoveFakeSinks(&function_def);
}
return Status::OK();
}
class GraphRewriteDataset::Iterator
: public DatasetIterator<GraphRewriteDataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<GraphRewriteDataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
IteratorContext::Params params(ctx);
params.lib = dataset()->lib_;
params.function_handle_cache = dataset()->function_handle_cache_.get();
return dataset()->optimized_input_->MakeIterator(
IteratorContext(std::move(params)), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
IteratorContext::Params params(ctx);
params.lib = dataset()->lib_;
params.function_handle_cache = dataset()->function_handle_cache_.get();
return input_impl_->GetNext(IteratorContext(std::move(params)), out_tensors,
end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
std::unique_ptr<IteratorBase> GraphRewriteDataset::MakeIteratorInternal(
const string& prefix) const {
// We do not add a token for this dataset to the prefix. The
// prefix is used to identify checkpoint elements and since this
// dataset is excluded from the checkpoint, adding a token
// here would result in invalid checkpoint identifiers.
return absl::make_unique<Iterator>(Iterator::Params{this, prefix});
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
#define TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
namespace tensorflow {
namespace data {
class GraphRewriteDataset : public DatasetBase {
public:
GraphRewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
optimized_input_(nullptr),
input_(input),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
~GraphRewriteDataset() override;
// Runs Grappler to transform the input dataset into optimized_input_
// dataset.
Status Optimize(OpKernelContext* ctx);
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override;
const DataTypeVector& output_dtypes() const override { return output_types_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
int64 Cardinality() const override { return input_->Cardinality(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override;
private:
class Iterator;
// Create a Grappler RewriteConfig proto that defines the list of
// optimizations to be run by the Grappler Meta Optimizer.
virtual RewriterConfig CreateGrapplerRewriteConfig() = 0;
Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
string* output_node);
DatasetBase* optimized_input_;
FunctionLibraryRuntime* lib_ = nullptr;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
std::unique_ptr<FunctionHandleCache> function_handle_cache_ = nullptr;
const DatasetBase* input_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_

View File

@ -14,26 +14,11 @@ limitations under the License.
==============================================================================*/
#include <map>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@ -71,235 +56,20 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphRewriteDataset {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
optimized_input_(nullptr),
input_(input),
optimizations_(optimizations),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
~Dataset() override {
input_->Unref();
if (optimized_input_) {
optimized_input_->Unref();
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
// We do not add a token for the optimization dataset to the prefix. The
// prefix is used to identify checkpoint elements and since the
// optimization dataset is excluded from the checkpoint, adding a token
// here would result in invalid checkpoint identifiers.
return absl::make_unique<Iterator>(Iterator::Params{this, prefix});
}
Status Optimize(OpKernelContext* ctx) {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
params.input_list = &input_list;
params.optimization_only = true;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, input_, &input_node));
string output_node = input_node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
VLOG(3) << "Before optimization: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
VLOG(3) << "After optimization: " << graph_def.DebugString();
// Instantiate the optimized input pipeline by running the optimized graph
// using the optimized function library.
TF_RETURN_IF_ERROR(
ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
// Create a FunctionHandleCache.
function_handle_cache_ = absl::make_unique<FunctionHandleCache>(lib_);
// Some functions may have been modified without having their names
// changed (for example, nested dataset graphs from FlatMap or
// Interleave). To avoid name conflicts, we remove these functions from
// flib_def_ before adding the optimized function library.
for (const FunctionDef& fd : graph_def.library().function()) {
if (flib_def_->Find(fd.signature().name()) != nullptr) {
TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name()));
}
}
TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(
graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
optimized_input_->Ref();
return Status::OK();
}
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
: GraphRewriteDataset(ctx, input, output_types, output_shapes),
optimizations_(optimizations) {}
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
int64 Cardinality() const override { return input_->Cardinality(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
// We only serialize the optimized dataset to avoid re-running
// optimizations when the input pipeline is restored from a checkpoint.
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
IteratorContext::Params params(ctx);
params.lib = dataset()->lib_;
params.function_handle_cache = dataset()->function_handle_cache_.get();
return dataset()->optimized_input_->MakeIterator(
IteratorContext(std::move(params)), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
IteratorContext::Params params(ctx);
params.lib = dataset()->lib_;
params.function_handle_cache = dataset()->function_handle_cache_.get();
return input_impl_->GetNext(IteratorContext(std::move(params)),
out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
void AddFakeSinks(FunctionDef* function_def) {
int counter = 0;
for (const auto& output : function_def->signature().output_arg()) {
NodeDef* node = function_def->add_node_def();
tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
strings::StrCat("FakeSink", counter++), function_def, node);
node->set_op("Identity");
node->add_input(function_def->ret().at(output.name()));
(*node->mutable_attr())["T"].set_type(output.type());
(*function_def->mutable_ret())[output.name()] =
strings::StrCat(node->name(), ":output:0");
}
}
void RemoveFakeSinks(FunctionDef* function_def) {
// Map from identity node names to their input tensor strings
std::map<string, string> identity_map;
for (const auto& node : function_def->node_def()) {
if (node.op() == "Identity" && node.input_size() == 1) {
identity_map[node.name()] = node.input(0);
}
}
for (const auto& output_arg : function_def->signature().output_arg()) {
const string& tensor = function_def->ret().at(output_arg.name());
const string& output_node = tensor.substr(0, tensor.find(':'));
if (identity_map.find(output_node) != identity_map.end()) {
(*function_def->mutable_ret())[output_arg.name()] =
identity_map.at(output_node);
}
}
}
Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
string* output_node) {
// Add an identity node as the fetch node, otherwise we might get
// 'placeholder is both fed and fetched' errors in some cases when using
// input list with placeholder dataset nodes.
NodeDef* node = graph_def->mutable_node()->Add();
tensorflow::grappler::graph_utils::SetUniqueGraphNodeName(
"Sink", graph_def, node);
node->set_op("Identity");
node->add_input(*output_node);
(*node->mutable_attr())["T"].set_type(DT_VARIANT);
*output_node = node->name();
// Add fake sink node to graph and functions to allow rewriting the actual
// sink nodes.
// TODO(b/118820916): When MetaOptimizer adds provisions for function
// retvals to be optimizable, we will no longer need this.
for (auto& function_def :
*graph_def->mutable_library()->mutable_function()) {
AddFakeSinks(&function_def);
}
// Create metagraph.
MetaGraphDef meta_graph_def;
(*meta_graph_def.mutable_graph_def()) = *graph_def;
// Grappler determines fetch ops from collection 'train_op'.
CollectionDef collection_def;
auto node_list = collection_def.mutable_node_list();
node_list->add_value(*output_node);
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
// Create Grappler item.
tensorflow::grappler::ItemConfig item_config;
item_config.apply_optimizations = true;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
"graph", meta_graph_def, item_config);
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run data optimizer using grappler's meta optimizer.
tensorflow::ConfigProto config;
RewriterConfig& rewriter_config =
*config.mutable_graph_options()->mutable_rewrite_options();
RewriterConfig CreateGrapplerRewriteConfig() override {
RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
@ -311,30 +81,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
for (const auto& opt : optimizations_) {
custom_optimizations_list->add_s(opt);
}
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
*grappler_item, config, ctx->device(), &cluster, graph_def));
// Remove fake sinks after optimizations are done.
// TODO(b/118820916): When MetaOptimizer adds provisions for function
// retvals to be optimizable, we will no longer need this.
for (auto& function_def :
*graph_def->mutable_library()->mutable_function()) {
RemoveFakeSinks(&function_def);
}
return Status::OK();
return rewriter_config;
}
DatasetBase* optimized_input_;
FunctionLibraryRuntime* lib_ = nullptr;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
std::unique_ptr<FunctionHandleCache> function_handle_cache_ = nullptr;
const DatasetBase* input_;
const std::vector<string> optimizations_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
const int graph_def_version_;

View File

@ -190,6 +190,14 @@ REGISTER_OP("ExperimentalMapAndBatchDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalRebatchDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalMapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")

View File

@ -470,6 +470,20 @@ py_library(
],
)
py_test(
name = "rebatch_dataset_test",
size = "small",
srcs = ["rebatch_dataset_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
],
)
py_test(
name = "rejection_resample_test",
size = "medium",

View File

@ -0,0 +1,60 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the private `_RebatchDataset` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class RebatchDatasetTest(test_base.DatasetTestBase):
def testBasic(self):
dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
self.assertEqual(
[[32]], [ts.as_list() for ts in nest.flatten(dataset.output_shapes)])
self.assertEqual(
[[8]],
[ts.as_list() for ts in nest.flatten(rebatched_dataset.output_shapes)])
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testScalarInputError(self):
dataset = dataset_ops.Dataset.range(1024)
with self.assertRaisesRegexp(ValueError, "at least one dimension"):
batching._RebatchDataset(dataset, num_workers=4)
def testUnknownBatchSizeError(self):
dataset = dataset_ops.Dataset.range(1024).batch(32)
with self.assertRaisesRegexp(ValueError, "unknown batch size datasets"):
batching._RebatchDataset(dataset, num_workers=4)
def testNotDivisibleError(self):
dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
with self.assertRaisesRegexp(ValueError, "not divisible by"):
batching._RebatchDataset(dataset, num_workers=5)
if __name__ == "__main__":
test.main()

View File

@ -408,6 +408,24 @@ py_test(
],
)
py_test(
name = "rebatch_dataset_serialization_test",
size = "small",
srcs = ["rebatch_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_test(
name = "padded_batch_dataset_serialization_test",
size = "medium",

View File

@ -0,0 +1,41 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the _RebatchDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
class RebatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def testCore(self):
def build_dataset(num_elements, batch_size):
return batching._RebatchDataset(
dataset_ops.Dataset.range(num_elements).batch(
4 * batch_size, drop_remainder=True),
num_workers=4)
self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
if __name__ == "__main__":
test.main()

View File

@ -645,3 +645,34 @@ def map_and_batch(map_func,
num_parallel_calls, drop_remainder)
return _apply_fn
class _RebatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that divides the batch size by `num_workers`."""
def __init__(self, input_dataset, num_workers):
self._input_dataset = input_dataset
output_shapes = input_dataset.output_shapes
if len(output_shapes) < 1:
raise ValueError("Input shape should have at least one dimension.")
if not output_shapes.dims[0].value:
raise ValueError("Cannot rebatch unknown batch size datasets.")
if output_shapes.dims[0].value % num_workers != 0:
raise ValueError(
"First dim of input shape: %d is not divisible by num_workers: %d" %
(output_shapes[0], num_workers))
output_dims = [d for d in output_shapes.dims]
output_dims[0] = output_dims[0] // num_workers
output_shapes = tensor_shape.TensorShapeV1(output_dims)
self._structure = structure.convert_legacy_structure(
self._input_dataset.output_types, output_shapes,
self._input_dataset.output_classes)
variant_tensor = ged_ops.experimental_rebatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
**dataset_ops.flat_structure(self))
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._structure

View File

@ -1096,6 +1096,10 @@ tf_module {
name: "ExperimentalRandomDataset"
argspec: "args=[\'seed\', \'seed2\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ExperimentalRebatchDataset"
argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ExperimentalScanDataset"
argspec: "args=[\'input_dataset\', \'initial_state\', \'other_arguments\', \'f\', \'output_types\', \'output_shapes\', \'preserve_cardinality\'], varargs=None, keywords=None, defaults=None"

View File

@ -1096,6 +1096,10 @@ tf_module {
name: "ExperimentalRandomDataset"
argspec: "args=[\'seed\', \'seed2\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ExperimentalRebatchDataset"
argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ExperimentalScanDataset"
argspec: "args=[\'input_dataset\', \'initial_state\', \'other_arguments\', \'f\', \'output_types\', \'output_shapes\', \'preserve_cardinality\'], varargs=None, keywords=None, defaults=None"