[tf.data] Extending the TF 2.0 support for shuffle(..., reshuffle_each_iteration=True)
to work across different Python iterators for the same dataset.
To achieve this, the CL creates a `RandomSeedGenerator` resource, along with ops for creating and deleting it, which manages the state used to seed the shuffle order for different Python iterators over the same dataset.

Note that the new functionality is not yet supported for (multi-worker) distribution strategies that clone the input pipeline graph created by user programs. Supporting that use case will require a mechanism for cloning the `RandomSeedGenerator` resource (and, more generally, other resources).

Fixes: #27680
PiperOrigin-RevId: 260765646
parent da247af90f
commit e7206a7d8e
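For context, a minimal Python sketch (not part of this change) of the user-visible behavior the commit enables in TF 2.0 eager mode; with `reshuffle_each_iteration=True`, each new Python iterator over the same dataset should now observe an independently seeded shuffle order:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).shuffle(
    buffer_size=10, reshuffle_each_iteration=True)

# Each `for` loop below creates a fresh Python iterator. Prior to this
# change, two epochs could repeat the same shuffle order; with the
# RandomSeedGenerator resource, each iteration is seeded independently,
# so the permutations (almost surely) differ.
first_epoch = [int(x) for x in dataset]
second_epoch = [int(x) for x in dataset]
print(first_epoch)
print(second_epoch)  # expected: a different permutation of 0..9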
@@ -0,0 +1,4 @@
op {
  graph_op_name: "AnonymousRandomSeedGenerator"
  visibility: HIDDEN
}
@@ -0,0 +1,4 @@
op {
  graph_op_name: "DeleteRandomSeedGenerator"
  visibility: HIDDEN
}
@@ -0,0 +1,4 @@
op {
  graph_op_name: "ShuffleDatasetV2"
  visibility: HIDDEN
}
@@ -20,6 +20,7 @@ cc_library(
        ":inject_prefetch",
        ":latency_all_edges",
        ":make_sloppy",
        ":make_stateless",
        ":map_and_batch_fusion",
        ":map_and_filter_fusion",
        ":map_fusion",
@@ -374,6 +375,37 @@ tf_cc_test(
    ],
)

cc_library(
    name = "make_stateless",
    srcs = ["make_stateless.cc"],
    hdrs = ["make_stateless.h"],
    deps = [
        ":graph_utils",
        ":optimizer_base",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:mutable_graph_view",
        "//tensorflow/core/grappler:op_types",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
    ] + tf_protos_all(),
    alwayslink = 1,
)

tf_cc_test(
    name = "make_stateless_test",
    srcs = ["make_stateless_test.cc"],
    deps = [
        ":graph_test_utils",
        ":graph_utils",
        ":make_stateless",
        "//tensorflow/core:framework",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/grappler:grappler_item",
    ],
)

cc_library(
    name = "map_and_batch_fusion",
    srcs = ["map_and_batch_fusion.cc"],
@@ -129,8 +129,8 @@ Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
  // Add shapes and other attributes
  NodeDef* add_after = graph->GetNode(add_before.input(0));

  if (str_util::EndsWith(add_after->op(), "Dataset") ||
      str_util::EndsWith(add_after->op(), "DatasetV2")) {
  if (absl::EndsWith(add_after->op(), "Dataset") ||
      absl::EndsWith(add_after->op(), "DatasetV2")) {
    // We still may or may not have the right attributes because datasets like
    // TFRecordDataset don't have an output type or shape, and by default we
    // set them to DT_STRING and an unknown shape.
@@ -174,27 +174,25 @@ Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
}

Status AddShuffleNode(MutableGraphView* graph, const NodeDef& add_before,
                      const string& buffer_node) {
                      const string& op_name, const string& buffer_size_node,
                      const string& seed_node, const string& seed2_node,
                      bool reshuffle_each_iteration) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));

  NodeDef new_node;
  new_node.set_op(kShuffleDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
                                      &new_node);

  NodeDef* seed = graph_utils::AddScalarConstNode<int64>(1, graph);
  NodeDef* seed2 = graph_utils::AddScalarConstNode<int64>(2, graph);
  AttrValue reshuffle;
  reshuffle.set_b(false);
  new_node.set_op(op_name);
  graph_utils::SetUniqueGraphNodeName(op_name, graph->graph(), &new_node);

  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_node);
  new_node.add_input(seed->name());
  new_node.add_input(seed2->name());
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_node);
  new_node.add_input(seed2_node);

  graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
  graph_utils::CopyAttribute("output_types", *add_after, &new_node);
  (*new_node.mutable_attr())["reshuffle_each_iteration"] = reshuffle;

  AttrValue reshuffle_attr;
  reshuffle_attr.set_b(reshuffle_each_iteration);
  (*new_node.mutable_attr())["reshuffle_each_iteration"] = reshuffle_attr;

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

@@ -223,19 +221,23 @@ bool ReaderOpInFunction(const NodeDef& node,

Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
                            absl::flat_hash_set<string>* nodes_to_delete,
                            bool* shuffle_removed,
                            string* buffer_size_node_name) {
                            string* op_name, string* buffer_size_node,
                            string* seed_node, string* seed2_node,
                            bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetOpName) {
    *shuffle_removed = true;
    *buffer_size_node_name = node.input(1);
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *reshuffle_each_iteration = node.attr().at("reshuffle_each_iteration").b();
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDataset(graph, *fanin.node, nodes_to_delete,
                                            shuffle_removed,
                                            buffer_size_node_name));
    TF_RETURN_IF_ERROR(RemoveShuffleDataset(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
@@ -245,15 +247,21 @@ Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
                                absl::flat_hash_set<string>* nodes_to_delete,
                                int64 num_workers, int64 index) {
  bool shuffle_removed = false;
  string buffer_size_node_name = "";
  string shuffle_op_name = "";
  string buffer_size_node = "";
  string seed_node = "";
  string seed2_node = "";
  bool reshuffle_each_iteration;

  TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
  TF_RETURN_IF_ERROR(RemoveShuffleDataset(
      graph, node, nodes_to_delete, &shuffle_removed, &buffer_size_node_name));
      graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
      &seed_node, &seed2_node, &reshuffle_each_iteration));

  if (shuffle_removed) {
    TF_RETURN_IF_ERROR(AddShuffleNode(graph, node, buffer_size_node_name));
  if (!shuffle_op_name.empty()) {
    TF_RETURN_IF_ERROR(AddShuffleNode(graph, node, shuffle_op_name,
                                      buffer_size_node, seed_node, seed2_node,
                                      reshuffle_each_iteration));
  }

  return Status::OK();
@@ -383,7 +391,6 @@ Status AutoShard::OptimizeAndCollectStats(Cluster* /* cluster */,
                                          GraphDef* output,
                                          OptimizationStats* stats) {
  *output = item.graph;

  TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_, output));
  stats->num_changes++;
  return Status::OK();
@@ -60,12 +60,12 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
       {"output_types", gtl::ArraySlice<DataType>{}}});
}

NodeDef MakeParallelInterleaveNode(StringPiece name,
                                   StringPiece input_node_name,
                                   StringPiece cycle_length_node_name,
                                   StringPiece block_length_node_name,
                                   StringPiece num_parallel_calls_node_name,
                                   StringPiece function_name, bool sloppy) {
NodeDef MakeParallelInterleaveV2Node(StringPiece name,
                                     StringPiece input_node_name,
                                     StringPiece cycle_length_node_name,
                                     StringPiece block_length_node_name,
                                     StringPiece num_parallel_calls_node_name,
                                     StringPiece function_name, bool sloppy) {
  return test::function::NDef(
      name, "ParallelInterleaveDatasetV2",
      {string(input_node_name), string(cycle_length_node_name),
@@ -107,6 +107,22 @@ NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
      });
}

NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
                          StringPiece buffer_size_node_name,
                          StringPiece seed_generator_node_name) {
  return test::function::NDef(
      name, "ShuffleDatasetV2",
      {
          string(input_node_name),
          string(buffer_size_node_name),
          string(seed_generator_node_name),
      },
      {
          {"output_shapes", gtl::ArraySlice<TensorShape>{}},
          {"output_types", gtl::ArraySlice<DataType>{}},
      });
}

}  // namespace graph_tests_utils
}  // namespace grappler
}  // namespace tensorflow
@@ -38,13 +38,13 @@ NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
                            StringPiece drop_remainder_node_name,
                            StringPiece function_name = "XTimesTwo");

// Creates a test NodeDef for ParallelInterleaveDataset.
NodeDef MakeParallelInterleaveNode(StringPiece name,
                                   StringPiece input_node_name,
                                   StringPiece cycle_length_node_name,
                                   StringPiece block_length_node_name,
                                   StringPiece num_parallel_calls_node_name,
                                   StringPiece function_name, bool sloppy);
// Creates a test NodeDef for ParallelInterleaveDatasetV2.
NodeDef MakeParallelInterleaveV2Node(StringPiece name,
                                     StringPiece input_node_name,
                                     StringPiece cycle_length_node_name,
                                     StringPiece block_length_node_name,
                                     StringPiece num_parallel_calls_node_name,
                                     StringPiece function_name, bool sloppy);

// Creates a test NodeDef for ParallelMapDataset.
NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name,
@@ -56,6 +56,11 @@ NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
                             StringPiece num_parallel_calls_node_name,
                             bool sloppy);

// Creates a test NodeDef for ShuffleDatasetV2.
NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
                          StringPiece buffer_size_node_name,
                          StringPiece seed_generator_node_name);

}  // namespace graph_tests_utils
}  // namespace grappler
}  // namespace tensorflow
@@ -29,10 +29,6 @@ namespace tensorflow {
namespace grappler {
namespace {

using graph_tests_utils::MakeParallelInterleaveNode;
using graph_tests_utils::MakeParallelMapNode;
using graph_tests_utils::MakeParseExampleNode;

TEST(MakeSloppy, ParallelInterleave) {
  using test::function::NDef;
  GrapplerItem item;
@@ -45,9 +41,9 @@ TEST(MakeSloppy, ParallelInterleave) {
       NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("num_parallel_calls", "Const", {},
            {{"value", 1}, {"dtype", DT_INT32}}),
       MakeParallelInterleaveNode("interleave", "range", "cycle_length",
                                  "block_length", "num_parallel_calls",
                                  "XTimesTwo", /*sloppy=*/false)},
       graph_tests_utils::MakeParallelInterleaveV2Node(
           "interleave", "range", "cycle_length", "block_length",
           "num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
      // FunctionLib
      {
          test::function::XTimesTwo(),
@@ -71,8 +67,9 @@ TEST(MakeSloppy, ParallelMap) {
       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
       NDef("num_parallel_calls", "Const", {},
            {{"value", 1}, {"dtype", DT_INT32}}),
       MakeParallelMapNode("map", "range", "num_parallel_calls", "XTimesTwo",
                           /*sloppy=*/false)},
       graph_tests_utils::MakeParallelMapNode("map", "range",
                                              "num_parallel_calls", "XTimesTwo",
                                              /*sloppy=*/false)},
      // FunctionLib
      {
          test::function::XTimesTwo(),
@@ -96,8 +93,9 @@ TEST(MakeSloppy, ParseExampleDataset) {
       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
       NDef("num_parallel_calls", "Const", {},
            {{"value", 1}, {"dtype", DT_INT32}}),
       MakeParseExampleNode("parse_example", "range", "num_parallel_calls",
                            /*sloppy=*/false)},
       graph_tests_utils::MakeParseExampleNode("parse_example", "range",
                                               "num_parallel_calls",
                                               /*sloppy=*/false)},
      // FunctionLib
      {});
tensorflow/core/grappler/optimizers/data/make_stateless.cc (new file, 65 lines)
@@ -0,0 +1,65 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/make_stateless.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kShuffleDataset[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";

}  // namespace

Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster,
                                              const GrapplerItem& item,
                                              GraphDef* output,
                                              OptimizationStats* stats) {
  *output = item.graph;
  MutableGraphView graph(output);

  NodeDef* zero_node = graph_utils::AddScalarConstNode<int64>(0, &graph);

  for (NodeDef& node : *output->mutable_node()) {
    if (node.op() == kShuffleDatasetV2) {
      *node.mutable_op() = kShuffleDataset;
      // remove `seed_generator` input
      node.mutable_input()->RemoveLast();
      // add `seed` input
      node.add_input(zero_node->name());
      // add `seed2` input
      node.add_input(zero_node->name());
      // set `reshuffle_each_iteration` attr
      (*node.mutable_attr())[kReshuffleEachIteration].set_b(true);
    }
  }

  return Status::OK();
}

REGISTER_GRAPH_OPTIMIZER_AS(MakeStateless, "make_stateless");

}  // namespace grappler
}  // namespace tensorflow
tensorflow/core/grappler/optimizers/data/make_stateless.h (new file, 54 lines)
@@ -0,0 +1,54 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_

#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"

namespace tensorflow {
namespace grappler {

// This rewrite replaces transformations that depend on external state (such as
// `ShuffleDatasetV2`) with a stateless alternative so that the input pipeline
// graph can be cloned.
//
// Note that this rewrite may change the observable behavior of the input
// pipeline (e.g. `reshuffle_each_iteration` will not work) and is a stopgap
// solution to enable cloning until a better mechanism exists.
class MakeStateless : public TFDataOptimizerBase {
 public:
  MakeStateless() = default;
  ~MakeStateless() override = default;

  string name() const override { return "make_stateless"; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return Status::OK();
  }

  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* output,
                                 OptimizationStats* stats) override;

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override {}
};

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
@@ -0,0 +1,57 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/make_stateless.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

TEST(MakeStateless, Shuffle) {
  using test::function::NDef;
  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
       NDef("buffer_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT64}}),
       NDef("seed_generator", "Const", {},
            {{"value", 1}, {"dtype", DT_RESOURCE}}),
       graph_tests_utils::MakeShuffleV2Node("shuffle", "range", "buffer_size",
                                            "seed_generator")},
      {});

  MakeStateless optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("shuffle", output));
  int index = graph_utils::FindGraphNodeWithName("shuffle", output);
  EXPECT_EQ(output.node(index).op(), "ShuffleDataset");
  EXPECT_EQ(output.node(index).input_size(), 4);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
@@ -795,6 +795,7 @@ tf_kernel_library(
    hdrs = ["shuffle_dataset_op.h"],
    deps = [
        ":name_utils",
        ":random_seed_ops",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
@@ -1319,3 +1320,17 @@ tf_cc_test(
        "//tensorflow/core/kernels:function_ops",
    ],
)

tf_kernel_library(
    name = "random_seed_ops",
    srcs = ["random_seed_ops.cc"],
    hdrs = ["random_seed_ops.h"],
    deps = [
        ":dataset_utils",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:functional_ops_op_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)
@@ -41,6 +41,8 @@ namespace tensorflow {
namespace data {
namespace {

constexpr char kDelimiter[] = "@@";

void AddFakeSinks(FunctionDef* function_def) {
  int counter = 0;
  for (const auto& output : function_def->signature().output_arg()) {
@@ -136,134 +138,6 @@ Status ApplyRewrites(OpKernelContext* ctx,
  return Status::OK();
}

}  // anonymous namespace

Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                  SerializationContext&& serialization_ctx,
                  GraphDef* graph_def) {
  GraphDefBuilder b;
  DatasetBase::DatasetGraphDefBuilder db(&b);
  Node* output_node = nullptr;
  TF_RETURN_IF_ERROR(
      db.AddInputDataset(&serialization_ctx, dataset, &output_node));
  // Insert a purely symbolic _Retval node to indicate to consumers which
  // Tensor represents this Dataset.
  ops::UnaryOp("_Retval", output_node,
               b.opts()
                   .WithName("dataset")
                   .WithAttr("T", DT_VARIANT)
                   .WithAttr("index", 0));
  TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
  return Status::OK();
}

Status ConnectCancellationManagers(CancellationManager* parent,
                                   CancellationManager* child,
                                   std::function<void()>* deregister_fn) {
  if (parent) {
    CancellationToken token = parent->get_cancellation_token();
    if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
      return errors::Cancelled("Operation was cancelled");
    }
    *deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
  } else {
    VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
               "not be propagated to the child cancellation manager.";
    *deregister_fn = []() {};
  }
  return Status::OK();
}

Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool optimize_function_library,
                      DatasetBase** rewritten_input) {
  SerializationContext::Params params;
  std::vector<std::pair<string, Tensor>> input_list;
  params.input_list = &input_list;
  params.optimization_only = true;
  SerializationContext serialization_ctx(params);
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(
      AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));

  string output_node;
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      output_node = node.input(0);
    }
  }

  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
  TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
                                   optimize_function_library, &graph_def,
                                   &output_node));
  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  FunctionLibraryRuntime* flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));

  // Some functions may have been modified without having their names
  // changed (for example, nested dataset graphs from FlatMap or
  // Interleave).
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(flr->device());

  TF_RETURN_IF_ERROR(
      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
  (*rewritten_input)->Ref();
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != received[i]) {
      return errors::InvalidArgument("Data type mismatch at component ", i,
                                     ": expected ", DataTypeString(expected[i]),
                                     " but got ", DataTypeString(received[i]),
                                     ".");
    }
  }
  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (!expected[i].IsCompatibleWith(received[i])) {
      return errors::InvalidArgument("Incompatible shapes at component ", i,
                                     ": expected ", expected[i].DebugString(),
                                     " but got ", received[i].DebugString(),
                                     ".");
    }
  }

  return Status::OK();
}

namespace {

uint64 DefaultDependencyLoopNodeHash() {
  static const uint64 hash = Hash64("DependencyLoopNode");
  return hash;
@@ -496,7 +370,131 @@ uint64 HashSubgraphFunctionImpl(
  return final_hash;
}

}  // namespace
}  // anonymous namespace

Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                  SerializationContext&& serialization_ctx,
                  GraphDef* graph_def) {
  GraphDefBuilder b;
  DatasetBase::DatasetGraphDefBuilder db(&b);
  Node* output_node = nullptr;
  TF_RETURN_IF_ERROR(
      db.AddInputDataset(&serialization_ctx, dataset, &output_node));
  // Insert a purely symbolic _Retval node to indicate to consumers which
  // Tensor represents this Dataset.
  ops::UnaryOp("_Retval", output_node,
               b.opts()
                   .WithName("dataset")
                   .WithAttr("T", DT_VARIANT)
                   .WithAttr("index", 0));
  TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
  return Status::OK();
}

Status ConnectCancellationManagers(CancellationManager* parent,
                                   CancellationManager* child,
                                   std::function<void()>* deregister_fn) {
  if (parent) {
    CancellationToken token = parent->get_cancellation_token();
    if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
      return errors::Cancelled("Operation was cancelled");
    }
    *deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
  } else {
    VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
               "not be propagated to the child cancellation manager.";
    *deregister_fn = []() {};
  }
  return Status::OK();
}

Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool optimize_function_library,
                      DatasetBase** rewritten_input) {
  SerializationContext::Params params;
  std::vector<std::pair<string, Tensor>> input_list;
  params.input_list = &input_list;
  params.optimization_only = true;
  SerializationContext serialization_ctx(params);
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(
      AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));

  string output_node;
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      output_node = node.input(0);
    }
  }

  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
  TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
                                   optimize_function_library, &graph_def,
                                   &output_node));
  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  FunctionLibraryRuntime* flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));

  // Some functions may have been modified without having their names
  // changed (for example, nested dataset graphs from FlatMap or
  // Interleave).
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(flr->device());

  TF_RETURN_IF_ERROR(
      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
  (*rewritten_input)->Ref();
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != received[i]) {
      return errors::InvalidArgument("Data type mismatch at component ", i,
                                     ": expected ", DataTypeString(expected[i]),
                                     " but got ", DataTypeString(received[i]),
                                     ".");
    }
  }
  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (!expected[i].IsCompatibleWith(received[i])) {
      return errors::InvalidArgument("Incompatible shapes at component ", i,
                                     ": expected ", expected[i].DebugString(),
                                     " but got ", received[i].DebugString(),
                                     ".");
    }
  }

  return Status::OK();
}

uint64 HashSubgraphFunction(const FunctionDefLibrary& library,
                            const FunctionDef* f) {
@@ -511,11 +509,6 @@ uint64 HashSubgraph(const GraphDef& g, const NodeDef* node) {
  return HashSubgraphImpl(grappler::GraphView(&g), node, &visited, &cache);
}

namespace {

constexpr char kDelimiter[] = "@@";

}  // namespace

VariantTensorDataReader::VariantTensorDataReader(
    const tensorflow::VariantTensorData* data)
@@ -15,13 +15,67 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace data {

template <typename T>
class AnonymousResourceOp : public OpKernel {
 public:
  static std::atomic<int64> resource_id_counter_;

  explicit AnonymousResourceOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    string container_name = name();
    string unique_name =
        strings::StrCat(container_name, resource_id_counter_.fetch_add(1));
    ResourceMgr* mgr = ctx->resource_manager();
    OP_REQUIRES_OK(ctx, mgr->Create<T>(container_name, unique_name, resource));

    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    ResourceHandle handle = MakeResourceHandle(ctx, container_name, unique_name,
                                               MakeTypeIndex<T>());
    handle_t->scalar<ResourceHandle>()() = handle;

    if (create_deleter_) {
      Tensor* deleter_t;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t));
      deleter_t->scalar<Variant>()() =
          ResourceDeleter(handle, ctx->resource_manager());
    }
  }

 protected:
  virtual string name() = 0;

  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

  bool create_deleter_ = true;
};

template <typename T>
std::atomic<int64> AnonymousResourceOp<T>::resource_id_counter_;

// Returns a GraphDef representation of the given dataset.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                  SerializationContext&& serialization_ctx,
@@ -28,6 +28,7 @@ namespace experimental {
/* static */ constexpr const char* const AutoShardDatasetOp::kOutputTypes;
/* static */ constexpr const char* const AutoShardDatasetOp::kOutputShapes;

constexpr char kMakeStateless[] = "make_stateless";
constexpr char kOptimizerName[] = "tf_auto_shard";

AutoShardDatasetOp::AutoShardDatasetOp(OpKernelConstruction* ctx)
@@ -63,17 +64,21 @@ RewriterConfig AutoShardDatasetOp::CreateConfig(int64 num_workers,
                                                int64 index) {
  RewriterConfig rewriter_config;
  rewriter_config.set_fail_on_optimizer_errors(true);
  rewriter_config.add_optimizers(kOptimizerName);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);

  rewriter_config.add_optimizers(kMakeStateless);
  auto custom_optimizer = rewriter_config.add_custom_optimizers();
  custom_optimizer->set_name(kOptimizerName);
  custom_optimizer->set_name(kMakeStateless);

  rewriter_config.add_optimizers(kOptimizerName);
  auto custom_optimizer2 = rewriter_config.add_custom_optimizers();
  custom_optimizer2->set_name(kOptimizerName);
  AttrValue num_workers_attr;
  num_workers_attr.set_i(num_workers);
  (*custom_optimizer->mutable_parameter_map())[kNumWorkers] = num_workers_attr;

  (*custom_optimizer2->mutable_parameter_map())[kNumWorkers] = num_workers_attr;
  AttrValue index_attr;
  index_attr.set_i(index);
  (*custom_optimizer->mutable_parameter_map())[kIndex] = index_attr;
  (*custom_optimizer2->mutable_parameter_map())[kIndex] = index_attr;

  return rewriter_config;
}
@@ -52,7 +52,11 @@ namespace {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kAnonymousIterator[] = "AnonymousIterator";
const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

}  // namespace

@@ -259,8 +263,8 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
// resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}

@@ -367,19 +371,14 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : AnonymousIteratorResourceOp<IteratorResource>(context),
    : AnonymousResourceOp<IteratorResource>(context),
      graph_def_version_(context->graph_def_version()) {
  create_deleter_ = context->def().op() == "AnonymousIteratorV2";
  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
  create_deleter_ = context->def().op() == kAnonymousIteratorV2;
}

static std::atomic<int64> current_iterator_id_;

void AnonymousIteratorHandleOp::GenerateContainerNames(string* unique_name,
                                                       string* container_name) {
  *unique_name =
      strings::StrCat("AnonymousIterator", current_iterator_id_.fetch_add(1));
  *container_name = "AnonymousIterator";
}
string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }

Status AnonymousIteratorHandleOp::CreateResource(
    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
@@ -539,8 +538,8 @@ class ReduceDatasetOp : public AsyncOpKernel {
    params.is_multi_device_function = true;
    OP_REQUIRES_OK(ctx,
                   FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
@@ -714,8 +713,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

  ~OneShotIteratorOp() override {
@@ -1007,8 +1006,8 @@ void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
IteratorFromStringHandleOp::IteratorFromStringHandleOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES(
      ctx,
      output_dtypes_.empty() || output_shapes_.empty() ||
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"

@@ -136,70 +137,16 @@ class IteratorHandleOp : public OpKernel {
  string name_;
};

template <typename T>
class AnonymousIteratorResourceOp : public OpKernel {
 public:
  explicit AnonymousIteratorResourceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
  }

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    string unique_name, container_name;
    GenerateContainerNames(&unique_name, &container_name);
    ResourceMgr* mgr = ctx->resource_manager();
    OP_REQUIRES_OK(ctx, mgr->Create<T>(container_name, unique_name, resource));

    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    ResourceHandle handle = MakeResourceHandle(ctx, container_name, unique_name,
                                               MakeTypeIndex<T>());
    handle_t->scalar<ResourceHandle>()() = handle;

    if (create_deleter_) {
      Tensor* deleter_t;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t));
      deleter_t->scalar<Variant>()() =
          ResourceDeleter(handle, ctx->resource_manager());
    }
  }

 protected:
  virtual void GenerateContainerNames(string* unique_name,
                                      string* container_name) = 0;

  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
  bool create_deleter_ = true;
};

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
class AnonymousIteratorHandleOp
    : public AnonymousIteratorResourceOp<IteratorResource> {
class AnonymousIteratorHandleOp : public AnonymousResourceOp<IteratorResource> {
 public:
  explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);

 private:
  void GenerateContainerNames(string* unique_name,
                              string* container_name) override;
  string name() override;

  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
@@ -207,6 +154,8 @@ class AnonymousIteratorHandleOp
                        FunctionLibraryRuntime* lib,
                        IteratorResource** resource) override;

  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
  const int graph_def_version_;
};

@@ -36,6 +36,11 @@ namespace tensorflow {
namespace data {
namespace {

const char kAnonymousMultiDeviceIterator[] = "AnonymousMultiDeviceIterator";
const char kDevices[] = "devices";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

struct HostBufferElement {
  Status status;
  bool end_of_sequence;
@@ -399,11 +404,11 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
  }

  // The resource is deleted from the resource manager only when it is private
@@ -443,7 +448,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
    if (name_ == ResourceHandle::ANONYMOUS_NAME) {
      unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
                                    current_id_.fetch_add(1));
      container_name = "AnonymousMultiDeviceIterator";
      container_name = kAnonymousMultiDeviceIterator;
      resource = new MultiDeviceIterator(
          context->env(), output_types_, output_shapes_, devices_,
          std::move(flib_def), std::move(pflr), flr,
@@ -511,26 +516,18 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
                        MultiDeviceIteratorHandleOp);

// This atomic is used to ensure that each new AnonymousMultiDeviceIterator
// handle is unique.
static std::atomic<int64> current_multi_device_iterator_id_;

class AnonymousMultiDeviceIteratorOp
    : public AnonymousIteratorResourceOp<MultiDeviceIterator> {
    : public AnonymousResourceOp<MultiDeviceIterator> {
 public:
  explicit AnonymousMultiDeviceIteratorOp(OpKernelConstruction* ctx)
      : AnonymousIteratorResourceOp<MultiDeviceIterator>(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
      : AnonymousResourceOp<MultiDeviceIterator>(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 private:
  void GenerateContainerNames(string* unique_name,
                              string* container_name) override {
    *unique_name =
        strings::StrCat("_AnonymousMultiDeviceIterator",
                        current_multi_device_iterator_id_.fetch_add(1));
    *container_name = "AnonymousMultiDeviceIterator";
  }
  string name() override { return kAnonymousMultiDeviceIterator; }

  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
@@ -546,9 +543,11 @@ class AnonymousMultiDeviceIteratorOp
  }

  std::vector<string> devices_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name("AnonymousMultiDeviceIterator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name(kAnonymousMultiDeviceIterator).Device(DEVICE_CPU),
                        AnonymousMultiDeviceIteratorOp);

// Calls init on the MultiDeviceIterator.
@@ -657,8 +656,8 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
    OP_REQUIRES(
        ctx,
        output_types_.empty() || output_shapes_.empty() ||
tensorflow/core/kernels/data/random_seed_ops.cc (new file, 128 lines)
@@ -0,0 +1,128 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/random_seed_ops.h"

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

namespace tensorflow {
namespace data {
namespace {

const char kNumRandomSamples[] = "num_random_samples";
const char kRandomSeedGenerator[] = "RandomSeedGenerator";
const char kSeed[] = "seed";
const char kSeed2[] = "seed2";

}  // namespace

string RandomSeedGenerator::DebugString() const {
  return "RandomSeedGenerator";
}

void RandomSeedGenerator::GenerateRandomSeeds(int64* seed1, int64* seed2) {
  mutex_lock l(mu_);
  num_random_samples_++;
  *seed1 = generator_();
  num_random_samples_++;
  *seed2 = generator_();
}

int64 RandomSeedGenerator::num_random_samples() {
  tf_shared_lock l(mu_);
  return num_random_samples_;
}

void RandomSeedGenerator::set_num_random_samples(int64 num_random_samples) {
  mutex_lock l(mu_);
  num_random_samples_ = num_random_samples;
}

void RandomSeedGenerator::Reset() {
  mutex_lock l(mu_);
  // Reset the generators based on the current seeds.
  parent_generator_ = random::PhiloxRandom(seed_, seed2_);
  generator_ =
      random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
  generator_.Skip(num_random_samples_);
}

void RandomSeedGenerator::Serialize(OpKernelContext* ctx) {
  mutex_lock l(mu_);
  Tensor* num_random_samples;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(kNumRandomSamples, TensorShape({}),
                                           &num_random_samples));
  num_random_samples->scalar<int64>()() = num_random_samples_;
  Tensor* seed;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(kSeed, TensorShape({}), &seed));
  seed->scalar<int64>()() = seed_;
  Tensor* seed2;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(kSeed2, TensorShape({}), &seed2));
  seed2->scalar<int64>()() = seed2_;
}

AnonymousRandomSeedGeneratorHandleOp::AnonymousRandomSeedGeneratorHandleOp(
    OpKernelConstruction* ctx)
    : AnonymousResourceOp<RandomSeedGenerator>(ctx) {}

void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed_));
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2_));
  AnonymousResourceOp<RandomSeedGenerator>::Compute(ctx);
}

string AnonymousRandomSeedGeneratorHandleOp::name() {
  return kRandomSeedGenerator;
}

Status AnonymousRandomSeedGeneratorHandleOp::CreateResource(
    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* lib, RandomSeedGenerator** resource) {
  *resource = new RandomSeedGenerator(seed_, seed2_);
  return Status::OK();
}

void DeleteRandomSeedGeneratorOp::Compute(OpKernelContext* ctx) {
  ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
  // The resource is guaranteed to exist because the variant tensor wrapping
  // the deleter is provided as an unused input to this op, which guarantees
  // that it has not run yet.
  Status s = ctx->resource_manager()->Delete(handle);
  if (errors::IsNotFound(s)) {
    // TODO(b/135948230): Investigate why the above statement is not true and
    // then get rid of the special case.
    ctx->SetStatus(Status::OK());
    return;
  }
  ctx->SetStatus(s);
}

namespace {

REGISTER_KERNEL_BUILDER(Name("AnonymousRandomSeedGenerator").Device(DEVICE_CPU),
                        AnonymousRandomSeedGeneratorHandleOp);

REGISTER_KERNEL_BUILDER(Name("DeleteRandomSeedGenerator").Device(DEVICE_CPU),
                        DeleteRandomSeedGeneratorOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow
tensorflow/core/kernels/data/random_seed_ops.h (new file, 86 lines)
@@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

namespace tensorflow {
namespace data {

// A random seed generator resource.
class RandomSeedGenerator : public ResourceBase {
 public:
  RandomSeedGenerator(int64 seed, int64 seed2)
      : seed_(seed),
        seed2_(seed2),
        parent_generator_(seed, seed2),
        generator_(&parent_generator_) {}

  int64 num_random_samples();
  void set_num_random_samples(int64 num_random_samples);

  string DebugString() const override;
  void GenerateRandomSeeds(int64* seed1, int64* seed2);
  void Reset();
  void Serialize(OpKernelContext* ctx);

 private:
  const int64 seed_;
  const int64 seed2_;
  mutex mu_;
  random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
  random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
  int64 num_random_samples_ GUARDED_BY(mu_) = 0;
};

// Creates an instance of random seed generator resource and transfers
// ownership to the caller.
class AnonymousRandomSeedGeneratorHandleOp
    : public AnonymousResourceOp<RandomSeedGenerator> {
 public:
  explicit AnonymousRandomSeedGeneratorHandleOp(OpKernelConstruction* ctx);
  void Compute(OpKernelContext* ctx) override;

 private:
  string name() override;
  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                        FunctionLibraryRuntime* lib,
                        RandomSeedGenerator** resource) override;

  int64 seed_;
  int64 seed2_;
};

// Deletes an instance of random seed generator resource.
class DeleteRandomSeedGeneratorOp : public OpKernel {
 public:
  explicit DeleteRandomSeedGeneratorOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
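For orientation, a minimal sketch of how these kernels are meant to be wired together from Python, mirroring the `_RandomSeedGenerator` wrapper added to `dataset_ops.py` later in this commit. The `gen_dataset_ops` bindings and the protected `_variant_tensor` / `_flat_structure` accessors are taken from that change; an eager TF 2.0 context is an assumption, and the snippet is illustrative only, not part of the commit:

# Illustrative sketch only; assumes an eager TF 2.0 context so the
# anonymous resource ops execute immediately.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import gen_dataset_ops

# Create the resource; the op returns a handle plus a "deleter" variant.
handle, deleter = gen_dataset_ops.anonymous_random_seed_generator(seed=1,
                                                                  seed2=2)

# The handle becomes the third input of ShuffleDatasetV2, so every iterator
# created from the resulting dataset pulls fresh seeds from one shared
# RandomSeedGenerator resource.
ds = dataset_ops.DatasetV2.range(10)
variant = gen_dataset_ops.shuffle_dataset_v2(
    ds._variant_tensor,  # pylint: disable=protected-access
    buffer_size=10,
    seed_generator=handle,
    **ds._flat_structure)  # pylint: disable=protected-access

# Handing the deleter back guarantees the resource still exists and is
# freed exactly once.
gen_dataset_ops.delete_random_seed_generator(handle=handle, deleter=deleter)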
@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/random_seed_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
@ -64,6 +66,7 @@ constexpr char kTFData[] = "tf_data";
constexpr char kDSNumRandomSamples[] = "ds_num_random_samples";
constexpr char kFixedSeedDatasetPrefix[] = "FixedSeed";
constexpr char kReshufflingDatasetPrefix[] = "Reshuffling";
constexpr char kShuffleDataset[] = "ShuffleDataset";

ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}
@ -385,12 +388,6 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
  const int64 count_;
};

ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
    : ShuffleDatasetOpBase(ctx) {
  OP_REQUIRES_OK(
      ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
}

// A dataset that uses a pseudorandom sequence of seeds for the iterators
// created from it. Used when `reshuffle_each_iteration` is true.
class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase {
@ -417,59 +414,9 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase {
  }

 protected:
  class RandomSeedGenerator : public ResourceBase {
   public:
    RandomSeedGenerator(int64 seed, int64 seed2)
        : seed_(seed),
          seed2_(seed2),
          parent_generator_(seed, seed2),
          generator_(&parent_generator_) {}

    string DebugString() const override {
      return strings::StrCat(kReshufflingDatasetPrefix, name_utils::kDelimiter,
                             kRandomSeedGenerator);
    }

    void GenerateRandomSeeds(int64* seed1, int64* seed2) {
      mutex_lock l(mu_);
      num_random_samples_++;
      *seed1 = generator_();
      num_random_samples_++;
      *seed2 = generator_();
    }

    int64 num_random_samples() {
      tf_shared_lock l(mu_);
      return num_random_samples_;
    }

    void set_num_random_samples(int64 num_random_samples) {
      mutex_lock l(mu_);
      num_random_samples_ = num_random_samples;
    }

    void Reset() {
      mutex_lock l(mu_);
      // Reset the generators based on the current seeds.
      parent_generator_ = random::PhiloxRandom(seed_, seed2_);
      generator_ =
          random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
      generator_.Skip(num_random_samples_);
    }

   private:
    const int64 seed_;
    const int64 seed2_;
    mutex mu_;
    random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
    random::SingleSampleAdapter<random::PhiloxRandom> generator_
        GUARDED_BY(mu_);
    int64 num_random_samples_ GUARDED_BY(mu_) = 0;
  };

  class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> {
   public:
    explicit Iterator(const Params& params, int64 seed, int64 seed2)
    Iterator(const Params& params, int64 seed, int64 seed2)
        : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed,
                                                           seed2) {}

@ -502,13 +449,10 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase {
              new RandomSeedGenerator(dataset_seed, dataset_seed2);
          return Status::OK();
        }));
    // Now use the seed generator to update the base class Iterator seeds
    // and random number generator with generated seeds for the current
    // repetition.
    mutex_lock l(mu_);
    seed_generator->GenerateRandomSeeds(&seed_, &seed2_);
    ResetRngs();
    seed_generator_ = seed_generator;
    seed_generator_->GenerateRandomSeeds(&seed_, &seed2_);
    mutex_lock l(mu_);
    ResetRngs();
    return Status::OK();
  }

@ -575,6 +519,108 @@ class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase {
  const int64 seed2_;
};

// A dataset that uses a pseudorandom sequence of seeds for the iterators
// created from it. Used in TF 2.0 when `reshuffle_each_iteration` is true.
class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase {
 public:
  ReshufflingDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                       int64 buffer_size, int64 count,
                       const Tensor& resource_handle,
                       RandomSeedGenerator* seed_generator)
      : ShuffleDatasetBase(ctx, input, buffer_size, count),
        resource_handle_(resource_handle),
        seed_generator_(seed_generator) {}

  ~ReshufflingDatasetV2() override { seed_generator_->Unref(); }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kReshufflingDatasetPrefix;
    params.set_args(buffer_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  bool IsStateful() const override { return true; }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(
        Iterator::Params{this,
                         name_utils::IteratorPrefix(kDatasetType, prefix)},
        seed_generator_);
  }

 protected:
  class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDatasetV2> {
   public:
    Iterator(const Params& params, RandomSeedGenerator* seed_generator)
        : ShuffleDatasetBase::Iterator<ReshufflingDatasetV2>(params, 0, 0),
          seed_generator_(seed_generator) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      seed_generator_->GenerateRandomSeeds(&seed_, &seed2_);
      ResetRngs();
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    Status SaveInternal(IteratorStateWriter* writer) override {
      // Save state of the seed generator.
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kDSNumRandomSamples),
                              seed_generator_->num_random_samples()));

      // Save the iterator state.
      return ShuffleDatasetBase::Iterator<ReshufflingDatasetV2>::SaveInternal(
          writer);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // Restore state of the seed generator.
      int64 num_random_samples;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kDSNumRandomSamples),
                                            &num_random_samples));
      seed_generator_->set_num_random_samples(num_random_samples);
      seed_generator_->Reset();

      // Restore the iterator state.
      return ShuffleDatasetBase::Iterator<
          ReshufflingDatasetV2>::RestoreInternal(ctx, reader);
    }

   private:
    RandomSeedGenerator* seed_generator_;
  };

  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
    Node* resource_handle_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {input_graph_node, buffer_size_node, resource_handle_node},  // Inputs
        {},                                                          // Attrs
        output));
    return Status::OK();
  }

 private:
  const Tensor resource_handle_;
  RandomSeedGenerator* seed_generator_ = nullptr;
};
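The save/restore logic above checkpoints only the number of samples drawn and replays the generator on restore. A minimal sketch of this count-and-skip pattern in plain Python; `random.Random` stands in for `PhiloxRandom` (an assumption made purely for illustration), and the class below is not part of the commit:

# Toy stand-in for RandomSeedGenerator's checkpointing scheme.
import random

class CountingGenerator(object):

  def __init__(self, seed):
    self._seed = seed
    self._rng = random.Random(seed)
    self.num_random_samples = 0  # the only state that is checkpointed

  def draw(self):
    self.num_random_samples += 1
    return self._rng.getrandbits(64)

  def reset(self, num_random_samples):
    # Rebuild from the original seed, then skip the recorded number of
    # draws, leaving the stream exactly where it was at save time.
    self._rng = random.Random(self._seed)
    for _ in range(num_random_samples):
      self._rng.getrandbits(64)
    self.num_random_samples = num_random_samples

g = CountingGenerator(seed=42)
g.draw(); g.draw()
saved = g.num_random_samples       # what SaveInternal writes
third = g.draw()                   # next value in the original stream
restored = CountingGenerator(seed=42)
restored.reset(saved)              # what RestoreInternal does
assert restored.draw() == third    # the stream resumes where it left off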
// A dataset that uses the same fixed seed for all iterators created from it.
// Used when `reshuffle_each_iteration` is false.
class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase {
@ -628,6 +674,15 @@ class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase {
  const int64 seed2_;
};

ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
    : ShuffleDatasetOpBase(ctx),
      op_version_(ctx->def().op() == kShuffleDataset ? 1 : 2) {
  if (ctx->HasAttr(kReshuffleEachIteration)) {
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
  }
}

void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                   DatasetBase** output) {
  int64 buffer_size = 0;
@ -637,6 +692,16 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
      ctx, buffer_size > 0,
      errors::InvalidArgument("buffer_size must be greater than zero."));

  int64 count = 1;
  if (op_version_ == 2) {
    RandomSeedGenerator* seed_generator = nullptr;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator));
    *output = new ReshufflingDatasetV2(ctx, input, buffer_size, count,
                                       ctx->input(2), seed_generator);
    return;
  }

  int64 seed;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));

@ -650,7 +715,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
    seed2 = random::New64();
  }

  int64 count = 1;
  if (reshuffle_each_iteration_) {
    *output =
        new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count);
@ -748,6 +812,9 @@ namespace {
REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
                        ShuffleAndRepeatDatasetOp);
}  // namespace
@ -49,7 +49,9 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {

 private:
  class ReshufflingDataset;
  class ReshufflingDatasetV2;
  class FixedSeedDataset;
  int op_version_;
  bool reshuffle_each_iteration_;
};
@ -354,6 +354,22 @@ REGISTER_OP("RangeDataset")
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("AnonymousRandomSeedGenerator")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("DeleteRandomSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

REGISTER_OP("ShuffleDataset")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
@ -372,6 +388,21 @@ REGISTER_OP("ShuffleDataset")
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("ShuffleDatasetV2")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed_generator: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size and seed_generator should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("ShuffleAndRepeatDataset")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
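Since both new ops are marked HIDDEN in the api_def files at the top of this commit, they surface only through the generated `gen_dataset_ops` bindings (and the raw-ops goldens updated below), not as public `tf.*` API. An illustrative check, assuming a build that contains this change:

from tensorflow.python.ops import gen_dataset_ops

# These wrappers are generated from the REGISTER_OP calls above.
assert hasattr(gen_dataset_ops, "anonymous_random_seed_generator")
assert hasattr(gen_dataset_ops, "delete_random_seed_generator")
assert hasattr(gen_dataset_ops, "shuffle_dataset_v2")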
@ -19,12 +19,12 @@ visibility = [
    "//bazel_pip/tensorflow/lite/toco/python:__pkg__",
]

load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "tf_cuda_library", "tf_gen_op_wrapper_py", "py_test", "tf_py_test", "py_tests", "tf_py_build_info_genrule", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library", "tf_proto_library", "tf_proto_library_py", "tf_additional_lib_deps", "tf_additional_all_protos", "tf_protos_grappler", "tf_additional_cupti_test_flags")  # @unused
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps", "tf_additional_verbs_deps", "tf_additional_mpi_deps", "tf_additional_gdr_deps", "if_static")
load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler")  # @unused
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static", "tf_additional_gdr_deps", "tf_additional_mpi_deps", "tf_additional_plugin_deps", "tf_additional_verbs_deps")
load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
load(
    "//third_party/ngraph:build_defs.bzl",
@ -6367,6 +6367,7 @@ tf_py_test(
    additional_deps = [
        ":array_ops",
        ":client_testlib",
        ":framework_combinations",
        ":framework_for_generated_wrappers",
        ":tf_item",
        "//tensorflow/core:protos_all_py",
@ -102,15 +102,17 @@ class MakeTFRecordDatasetTest(

  def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
                    seed=None):
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        num_parallel_reads=num_parallel_reads,
        shuffle=True,
        shuffle_seed=seed)

    next_element = self.getNext(dataset)
    def dataset_fn():
      return readers.make_tf_record_dataset(
          file_pattern=self.test_filenames,
          num_epochs=num_epochs,
          batch_size=batch_size,
          num_parallel_reads=num_parallel_reads,
          shuffle=True,
          shuffle_seed=seed)

    next_element = self.getNext(dataset_fn())
    first_batches = []
    try:
      while True:
@ -118,7 +120,7 @@ class MakeTFRecordDatasetTest(
    except errors.OutOfRangeError:
      pass

    next_element = self.getNext(dataset)
    next_element = self.getNext(dataset_fn())
    second_batches = []
    try:
      while True:
@ -17,9 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@ -29,8 +31,13 @@ from tensorflow.python.platform import test

class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):

  def testShuffleAndRepeatFusion(self):
    if tf2.enabled() and context.executing_eagerly():
      expected = "Shuffle"
    else:
      expected = "ShuffleAndRepeat"

    dataset = dataset_ops.Dataset.range(10).apply(
        optimization.assert_next(["ShuffleAndRepeat"])).shuffle(10).repeat(2)
        optimization.assert_next([expected])).shuffle(10).repeat(2)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.shuffle_and_repeat_fusion = True
@ -46,7 +46,8 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
    initial_dist = [0.2] * 5 if initial_known else None
    classes = math_ops.cast(classes, dtypes.int64)  # needed for Windows build.
    dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
        200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
        200, seed=21, reshuffle_each_iteration=False).map(
            lambda c: (c, string_ops.as_string(c))).repeat()

    get_next = self.getNext(
        dataset.apply(
@ -79,8 +79,9 @@ class ListFilesTest(test_base.DatasetTestBase):
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    dataset = dataset_ops.Dataset.list_files(
        path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
    def dataset_fn():
      return dataset_ops.Dataset.list_files(
          path.join(self.tmp_dir, '*'), shuffle=True, seed=37)

    expected_filenames = [
        compat.as_bytes(path.join(self.tmp_dir, filename))
@ -90,7 +91,7 @@ class ListFilesTest(test_base.DatasetTestBase):
    all_actual_filenames = []
    for _ in range(3):
      actual_filenames = []
      next_element = self.getNext(dataset, requires_initialization=True)
      next_element = self.getNext(dataset_fn(), requires_initialization=True)
      try:
        while True:
          actual_filenames.append(self.evaluate(next_element()))
@ -24,19 +24,18 @@ import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops

from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testShuffleDataset(self):
    components = (
        np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
@ -115,8 +114,8 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @test_util.run_deprecated_v1
  def testSkipEagerSeedZero(self):
  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
  def testSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=0))
@ -141,6 +140,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultArguments(self):
    components = [0, 1, 2, 3, 4]
    dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
@ -154,42 +154,20 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
    for i in range(5):
      self.assertEqual(10, counts[i])

  @parameterized.named_parameters(
      ("Reshuffle", True),
      ("NoReshuffle", False),
  )
  def testReshuffle(self, reshuffle):
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, reshuffle_each_iteration=reshuffle).repeat(2)
    next_element = self.getNext(dataset)

    first_epoch = []
    for _ in range(10):
      first_epoch.append(self.evaluate(next_element()))

    second_epoch = []
    for _ in range(10):
      second_epoch.append(self.evaluate(next_element()))

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @parameterized.named_parameters(
      ("ReshuffleGraphLevelSeed", True, 38, None),
      ("ReshuffleOpLevelSeed", True, None, 42),
      ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
      ("NoReshuffleGraphLevelSeed", False, 38, None),
      ("NoReshuffleOpLevelSeed", False, None, 42),
      ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
  )
  def testSkipEagerShuffleSeed(self, reshuffle, graph_level_seed,
                               op_level_seed):
  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=[1, 2], mode="graph"),
          combinations.combine(reshuffle=[True, False]),
          combinations.combine(graph_seed=38, op_seed=None) +
          combinations.combine(graph_seed=None, op_seed=42) +
          combinations.combine(graph_seed=38, op_seed=42)))
  def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
    results = []
    for _ in range(2):
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(graph_level_seed)
        random_seed.set_random_seed(graph_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat(
                3)
            10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

@ -203,15 +181,13 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertAllEqual(results[0], results[1])

  # TODO(b/117581999): fails for eager mode with result[0] equal to result[1],
  # debug.
  @parameterized.named_parameters(
      ("ReshuffleOneShot", True, False),
      ("ReshuffleInitializable", True, True),
      ("NoReshuffleOneShot", False, False),
      ("NoReshuffleInitializable", False, True),
  )
  def testSkipEagerMultipleIterators(self, reshuffle, initializable):
  # TODO(b/117581999): enable this test for eager-mode.
  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=[1, 2], mode="graph"),
          combinations.combine(
              reshuffle=[True, False], initializable=[True, False])))
  def testMultipleIterators(self, reshuffle, initializable):
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)
@ -239,6 +215,43 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertNotEqual(results[0], results[1])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleRepeatEpochs(self, reshuffle, seed):
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
    next_element = self.getNext(dataset)

    first_epoch = []
    for _ in range(10):
      first_epoch.append(self.evaluate(next_element()))

    second_epoch = []
    for _ in range(10):
      second_epoch.append(self.evaluate(next_element()))

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=2, mode="eager"),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleIterationEpochs(self, reshuffle, seed):
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle)

    first_epoch = []
    for elem in dataset:
      first_epoch.append(elem.numpy())

    second_epoch = []
    for elem in dataset:
      second_epoch.append(elem.numpy())

    self.assertEqual(first_epoch == second_epoch, not reshuffle)


if __name__ == "__main__":
  test.main()
@ -30,6 +30,7 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
@ -1616,13 +1617,13 @@ class DatasetV1(DatasetV2):
        raise AttributeError("Please use _variant_tensor instead of "
                             "_as_variant_tensor() to obtain the variant "
                             "associated with a dataset")
      raise AttributeError("A likely cause of this error is that the super "
      raise AttributeError("{}: A likely cause of this error is that the super "
                           "call for this dataset is not the last line of the "
                           "__init__ method. The base class causes the "
                           "_as_variant_tensor call in its constructor and "
                           "if that uses attributes defined in the __init__ "
                           "method, those attrs need to be defined before the "
                           "super call.")
                           "super call.".format(e))
    super(DatasetV1, self).__init__(variant_tensor)

  @abc.abstractmethod
@ -2258,8 +2259,7 @@ class Options(options_lib.OptionsBase):

    if self.experimental_deterministic is False:
      result.append("make_sloppy")
    exp_stats_options = self.experimental_stats
    if exp_stats_options and exp_stats_options.latency_all_edges:
    if self.experimental_stats and self.experimental_stats.latency_all_edges:
      result.append("latency_all_edges")
    if self.experimental_slack:
      result.append("slack")
@ -2942,6 +2942,48 @@ class CacheDataset(UnaryUnchangedStructureDataset):
    super(CacheDataset, self).__init__(input_dataset, variant_tensor)


class _RandomSeedGeneratorDeleter(object):
  """An object which cleans up an anonymous random seed generator resource.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectable.
  """

  def __init__(self, handle, device, deleter):
    self._deleter = deleter
    self._handle = handle
    self._device = device
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    with ops.device(self._device):
      # Make sure the resource is deleted in the same mode as it was created
      # in.
      if self._eager_mode:
        with context.eager_mode():
          gen_dataset_ops.delete_random_seed_generator(
              handle=self._handle, deleter=self._deleter)
      else:
        with context.graph_mode():
          gen_dataset_ops.delete_random_seed_generator(
              handle=self._handle, deleter=self._deleter)


class _RandomSeedGenerator(object):
  """Represents a random seed generator resource."""

  def __init__(self, seed, seed2):
    super(_RandomSeedGenerator, self).__init__()
    self._device = context.context().device_name
    self._handle, self._deleter = (
        gen_dataset_ops.anonymous_random_seed_generator(seed=seed, seed2=seed2))
    self._resource_deleter = _RandomSeedGeneratorDeleter(
        handle=self._handle, device=self._device, deleter=self._deleter)

  @property
  def handle(self):
    return self._handle


class ShuffleDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that randomly shuffles the elements of its input."""

@ -2978,13 +3020,24 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
      self._reshuffle_each_iteration = True
    else:
      self._reshuffle_each_iteration = reshuffle_each_iteration
    variant_tensor = gen_dataset_ops.shuffle_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        seed=self._seed,
        seed2=self._seed2,
        reshuffle_each_iteration=self._reshuffle_each_iteration,
        **self._flat_structure)

    if tf2.enabled() and self._reshuffle_each_iteration and (
        context.executing_eagerly() or
        ops.get_default_graph()._building_function):  # pylint: disable=protected-access
      self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2)
      variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          buffer_size=self._buffer_size,
          seed_generator=self._seed_generator.handle,
          **self._flat_structure)
    else:
      variant_tensor = gen_dataset_ops.shuffle_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          buffer_size=self._buffer_size,
          seed=self._seed,
          seed2=self._seed2,
          reshuffle_each_iteration=self._reshuffle_each_iteration,
          **self._flat_structure)
    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
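In user terms, the new branch above means that in TF 2.0 eager mode re-iterating a shuffled dataset reshuffles it, mirroring testReshuffleIterationEpochs earlier in this commit. A sketch assuming a TF 2.0 eager build that includes this change:

import tensorflow as tf  # assumes TF 2.0 with this commit applied

dataset = tf.data.Dataset.range(10).shuffle(10, reshuffle_each_iteration=True)
first_epoch = [elem.numpy() for elem in dataset]   # one Python iterator
second_epoch = [elem.numpy() for elem in dataset]  # a second, fresh iterator
# Each iterator draws new seeds from the shared generator, so the two orders
# almost surely differ; with reshuffle_each_iteration=False they would match.
print(first_epoch != second_epoch)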
@ -795,8 +795,7 @@ class IteratorSpec(type_spec.TypeSpec):
    return IteratorSpec(value.element_spec)  # pylint: disable=protected-access


# TODO(b/71645805): Expose trackable stateful objects from dataset
# attributes(potential).
# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""
@ -60,7 +60,7 @@ class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
    # from the fit history should be equal to the final element in the output
    # of evaluating the model on the same eval dataset.
    self.assertAlmostEqual(history.history["val_mean_absolute_error"][-1],
                           evaluation[-1])
                           evaluation[-1], places=5)


class PrintTrainingInfoTest(parameterized.TestCase):
@ -100,6 +100,10 @@ tf_module {
    name: "AnonymousMultiDeviceIterator"
    argspec: "args=[\'devices\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "AnonymousRandomSeedGenerator"
    argspec: "args=[\'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "Any"
    argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@ -952,6 +956,10 @@ tf_module {
    name: "DeleteMultiDeviceIterator"
    argspec: "args=[\'multi_device_iterator\', \'iterators\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DeleteRandomSeedGenerator"
    argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DeleteSessionTensor"
    argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -3628,6 +3636,10 @@ tf_module {
    name: "ShuffleDataset"
    argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
  }
  member_method {
    name: "ShuffleDatasetV2"
    argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed_generator\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ShutdownDistributedTPU"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -100,6 +100,10 @@ tf_module {
    name: "AnonymousMultiDeviceIterator"
    argspec: "args=[\'devices\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "AnonymousRandomSeedGenerator"
    argspec: "args=[\'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "Any"
    argspec: "args=[\'input\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@ -952,6 +956,10 @@ tf_module {
    name: "DeleteMultiDeviceIterator"
    argspec: "args=[\'multi_device_iterator\', \'iterators\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DeleteRandomSeedGenerator"
    argspec: "args=[\'handle\', \'deleter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DeleteSessionTensor"
    argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -3628,6 +3636,10 @@ tf_module {
    name: "ShuffleDataset"
    argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
  }
  member_method {
    name: "ShuffleDatasetV2"
    argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed_generator\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ShutdownDistributedTPU"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "