[tf.data] This CL changes how the shuffle seed generator is managed, making it possible for the shuffle dataset to support both a) sharing of the seed generator across iterators and b) serialization. As a consequence, this CL enables reshuffling across iterations for the tf.distribute and tf.data service use cases, which require both capabilities.
In itself, this CL is a fairly large refactoring of the shuffle dataset implementation: it unifies the previously separate op kernels for shuffling with fixed seeds, shuffling with pseudorandom seeds, and the fused shuffle-and-repeat transformation. It also removes the `make_stateless` graph rewrite, which is no longer needed.

PiperOrigin-RevId: 308064029
Change-Id: I2f1d7916fe9958cf99d4e1b197da95c46b5d8b5f
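To make the ownership model concrete, here is a minimal, self-contained C++ sketch of the sharing behavior described above. The class names mirror the ones introduced in the diff (SeedGenerator, FixedSeedGenerator, RandomSeedGenerator, SeedGeneratorManager), but the bodies are simplified stand-ins rather than the actual TensorFlow implementation.

// Illustrative sketch only; NOT the real TensorFlow classes.
#include <cstdint>
#include <iostream>
#include <memory>
#include <random>

// Base interface: hands out a (seed, seed2) pair per iterator/epoch.
class SeedGenerator {
 public:
  virtual ~SeedGenerator() = default;
  virtual void GenerateSeeds(int64_t* seed1, int64_t* seed2) = 0;
};

// Always returns the original seeds -> every iteration shuffles identically.
class FixedSeedGenerator : public SeedGenerator {
 public:
  FixedSeedGenerator(int64_t seed, int64_t seed2) : seed_(seed), seed2_(seed2) {}
  void GenerateSeeds(int64_t* seed1, int64_t* seed2) override {
    *seed1 = seed_;
    *seed2 = seed2_;
  }
 private:
  const int64_t seed_, seed2_;
};

// Returns fresh pseudo-random seeds -> reshuffling across iterations.
class RandomSeedGenerator : public SeedGenerator {
 public:
  explicit RandomSeedGenerator(int64_t seed) : rng_(static_cast<uint64_t>(seed)) {}
  void GenerateSeeds(int64_t* seed1, int64_t* seed2) override {
    *seed1 = static_cast<int64_t>(rng_());
    *seed2 = static_cast<int64_t>(rng_());
  }
 private:
  std::mt19937_64 rng_;
};

// Analogue of SeedGeneratorManager: a single shared generator that both the
// dataset and all of its iterators reference, so seed state is shared.
struct SeedGeneratorManager {
  std::shared_ptr<SeedGenerator> generator;
};

int main() {
  SeedGeneratorManager manager{std::make_shared<RandomSeedGenerator>(42)};
  // Two "iterators" created from the same dataset share one generator and
  // therefore draw different seeds -> reshuffle across iterations.
  for (int iterator = 0; iterator < 2; ++iterator) {
    int64_t s1 = 0, s2 = 0;
    manager.generator->GenerateSeeds(&s1, &s2);
    std::cout << "iterator " << iterator << " seeds: " << s1 << ", " << s2 << "\n";
  }
  return 0;
}

Because the manager hands out a std::shared_ptr, every iterator created from the dataset draws seeds from the same generator: with the random generator each iteration sees fresh seeds (reshuffling), while swapping in FixedSeedGenerator reproduces the same order every time.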
commit ca9d421ddd (parent 50abd7d652)
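The serialization half of the change rests on the same idea the diff uses in RandomSeedGenerator::Reset() and the iterator's epoch_num_random_samples checkpoint entry: record how many seed samples the generator has produced, then restore by re-seeding and skipping that many draws. Below is a simplified, hedged sketch of that pattern; std::mt19937_64 stands in for the Philox generator used in the real code.

// Illustrative sketch only; the real state restore lives in
// RandomSeedGenerator::Reset() in random_seed_ops.cc.
#include <cstdint>
#include <iostream>
#include <random>

class RestorableSeedGenerator {
 public:
  explicit RestorableSeedGenerator(uint64_t seed) : seed_(seed), rng_(seed) {}

  int64_t Next() {
    ++num_random_samples_;
    return static_cast<int64_t>(rng_());
  }

  // What gets written to the checkpoint (cf. "epoch_num_random_samples").
  int64_t num_random_samples() const { return num_random_samples_; }

  // Restore: rebuild the RNG from the original seed and fast-forward it.
  void Reset(int64_t num_random_samples) {
    num_random_samples_ = num_random_samples;
    rng_.seed(seed_);
    rng_.discard(static_cast<unsigned long long>(num_random_samples));
  }

 private:
  const uint64_t seed_;
  std::mt19937_64 rng_;
  int64_t num_random_samples_ = 0;
};

int main() {
  RestorableSeedGenerator gen(7);
  gen.Next();
  gen.Next();
  const int64_t checkpoint = gen.num_random_samples();  // saved to the checkpoint
  const int64_t third = gen.Next();                     // value drawn after the checkpoint

  RestorableSeedGenerator restored(7);
  restored.Reset(checkpoint);                           // restore from the checkpoint
  std::cout << (restored.Next() == third ? "restored in sync\n" : "mismatch\n");
  return 0;
}

Persisting this count through the iterator checkpoint is what lets the new ShuffleDatasetV3 op be serialized without losing its reshuffling state.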
@ -0,0 +1,4 @@
op {
graph_op_name: "DummySeedGenerator"
visibility: HIDDEN
}
@ -0,0 +1,4 @@
op {
graph_op_name: "ShuffleDatasetV3"
visibility: HIDDEN
}
@ -20,7 +20,6 @@ cc_library(
":inject_prefetch",
":latency_all_edges",
":make_sloppy",
":make_stateless",
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
@ -43,6 +42,7 @@ cc_library(
":optimizer_base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
@ -390,37 +390,6 @@ tf_cc_test(
],
)

cc_library(
name = "make_stateless",
srcs = ["make_stateless.cc"],
hdrs = ["make_stateless.h"],
deps = [
":graph_utils",
":optimizer_base",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
] + tf_protos_all(),
alwayslink = 1,
)

tf_cc_test(
name = "make_stateless_test",
srcs = ["make_stateless_test.cc"],
deps = [
":graph_test_utils",
":graph_utils",
":make_stateless",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
],
)

cc_library(
name = "map_and_batch_fusion",
srcs = ["map_and_batch_fusion.cc"],
@ -17,6 +17,7 @@ limitations under the License.

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@ -38,10 +39,14 @@ namespace {
constexpr char kShardDatasetOpName[] = "ShardDataset";
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";

constexpr char kNumWorkersAttrName[] = "num_workers";
constexpr char kIndexAttrName[] = "index";
constexpr char kAutoShardPolicyAttrName[] = "auto_shard_policy";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";

constexpr std::array<const char*, 6> kReaderDatasetOps = {
"FixedLengthRecordDataset",
@ -57,7 +62,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ZipDataset"
};

constexpr std::array<const char*, 30> kPassThroughOps = {
constexpr std::array<const char*, 31> kPassThroughOps = {
"_Retval",
"AssertCardinalityDataset",
"AssertNextDataset",
@ -85,6 +90,7 @@ constexpr std::array<const char*, 30> kPassThroughOps = {
"ShuffleAndRepeatDataset",
"ShuffleDataset",
"ShuffleDatasetV2",
"ShuffleDatasetV3",
"SkipDataset",
"TakeDataset",
"WindowDataset",
@ -146,28 +152,27 @@ Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
// Add shapes and other attributes
NodeDef* add_after = graph->GetNode(add_before.input(0));

if (absl::EndsWith(add_after->op(), "Dataset") ||
absl::EndsWith(add_after->op(), "DatasetV2")) {
if (absl::StrContains(add_after->op(), "Dataset")) {
// We still may or may not have the right attributes because Datasets like
// TFRecordDataset doesn't have a output type or shape, and by default we
// set them to DT_STRING and an unknown shape.
if (add_after->attr().count("output_shapes") > 0) {
graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
if (add_after->attr().count(kOutputShapes) > 0) {
graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
} else {
tensorflow::TensorShapeProto* shape =
(*(new_node.mutable_attr()))["output_shapes"]
(*(new_node.mutable_attr()))[kOutputShapes]
.mutable_list()
->add_shape();
shape->set_unknown_rank(true);
}

if (add_after->attr().count("output_types") > 0) {
graph_utils::CopyAttribute("output_types", *add_after, &new_node);
if (add_after->attr().count(kOutputTypes) > 0) {
graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
} else if (add_after->attr().count("Toutput_types") > 0) {
(*(new_node.mutable_attr()))["output_types"] =
(*(new_node.mutable_attr()))[kOutputTypes] =
add_after->attr().at("Toutput_types");
} else {
(*(new_node.mutable_attr()))["output_types"].mutable_list()->add_type(
(*(new_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type(
tensorflow::DataType::DT_STRING);
}
} else {
@ -189,9 +194,10 @@ Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
return Status::OK();
}

Status AddShuffleNode(MutableGraphView* graph, const NodeDef& add_before,
const string& buffer_size_node, const string& seed_node,
const string& seed2_node, bool reshuffle_each_iteration) {
Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before,
const string& buffer_size_node,
const string& seed_node, const string& seed2_node,
bool reshuffle_each_iteration) {
NodeDef* add_after = graph->GetNode(add_before.input(0));
NodeDef new_node;
new_node.set_op(kShuffleDatasetOpName);
@ -203,12 +209,12 @@ Status AddShuffleNode(MutableGraphView* graph, const NodeDef& add_before,
new_node.add_input(seed_node);
new_node.add_input(seed2_node);

graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
graph_utils::CopyAttribute("output_types", *add_after, &new_node);
graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

AttrValue reshuffle_attr;
reshuffle_attr.set_b(reshuffle_each_iteration);
(*new_node.mutable_attr())["reshuffle_each_iteration"] = reshuffle_attr;
(*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;

NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

@ -217,9 +223,9 @@ Status AddShuffleNode(MutableGraphView* graph, const NodeDef& add_before,
return Status::OK();
}

Status AddShuffleV2Node(MutableGraphView* graph, const NodeDef& add_before,
const string& buffer_size_node,
const string& seed_generator_node) {
Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before,
const string& buffer_size_node,
const string& seed_generator_node) {
NodeDef* add_after = graph->GetNode(add_before.input(0));
NodeDef new_node;
new_node.set_op(kShuffleDatasetV2OpName);
@ -230,8 +236,39 @@ Status AddShuffleV2Node(MutableGraphView* graph, const NodeDef& add_before,
new_node.add_input(buffer_size_node);
new_node.add_input(seed_generator_node);

graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
graph_utils::CopyAttribute("output_types", *add_after, &new_node);
graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

TF_RETURN_IF_ERROR(
graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
return Status::OK();
}

Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before,
const string& buffer_size_node,
const string& seed_node, const string& seed2_node,
const string& seed_generator_node,
bool reshuffle_each_iteration) {
NodeDef* add_after = graph->GetNode(add_before.input(0));
NodeDef new_node;
new_node.set_op(kShuffleDatasetV3OpName);
graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV3OpName, graph->graph(),
&new_node);

new_node.add_input(add_before.input(0));
new_node.add_input(buffer_size_node);
new_node.add_input(seed_node);
new_node.add_input(seed2_node);
new_node.add_input(seed_generator_node);

graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

AttrValue reshuffle_attr;
reshuffle_attr.set_b(reshuffle_each_iteration);
(*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;

NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

@ -268,7 +305,7 @@ Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
*buffer_size_node = node.input(1);
*seed_node = node.input(2);
*seed2_node = node.input(3);
*reshuffle_each_iteration = node.attr().at("reshuffle_each_iteration").b();
*reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
nodes_to_delete->insert(node.name());
}
@ -305,6 +342,33 @@ Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
return Status::OK();
}

Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node,
absl::flat_hash_set<string>* nodes_to_delete,
string* op_name, string* buffer_size_node,
string* seed_node, string* seed2_node,
string* seed_generator_node,
bool* reshuffle_each_iteration) {
if (node.op() == kShuffleDatasetV3OpName) {
*op_name = node.op();
*buffer_size_node = node.input(1);
*seed_node = node.input(2);
*seed2_node = node.input(3);
*seed_generator_node = node.input(4);
*reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
nodes_to_delete->insert(node.name());
}

for (const auto& fanin : graph->GetFanins(node, true)) {
TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
seed_node, seed2_node, seed_generator_node, reshuffle_each_iteration));
}

// TODO(frankchn): Traverse functions too.
return Status::OK();
}

Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
absl::flat_hash_set<string>* nodes_to_delete,
int64 num_workers, int64 index) {
@ -324,13 +388,24 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
&buffer_size_node, &seed_generator_node));
}
if (shuffle_op_name.empty()) {
TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
&seed_node, &seed2_node, &seed_generator_node,
&reshuffle_each_iteration));
}

if (shuffle_op_name == kShuffleDatasetOpName) {
TF_RETURN_IF_ERROR(AddShuffleNode(graph, node, buffer_size_node, seed_node,
seed2_node, reshuffle_each_iteration));
TF_RETURN_IF_ERROR(AddShuffleDataset(graph, node, buffer_size_node,
seed_node, seed2_node,
reshuffle_each_iteration));
} else if (shuffle_op_name == kShuffleDatasetV2OpName) {
TF_RETURN_IF_ERROR(
AddShuffleV2Node(graph, node, buffer_size_node, seed_generator_node));
TF_RETURN_IF_ERROR(AddShuffleDatasetV2(graph, node, buffer_size_node,
seed_generator_node));
} else if (shuffle_op_name == kShuffleDatasetV3OpName) {
TF_RETURN_IF_ERROR(AddShuffleDatasetV3(
graph, node, buffer_size_node, seed_node, seed2_node,
seed_generator_node, reshuffle_each_iteration));
}

return Status::OK();
@ -1,65 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/make_stateless.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kShuffleDataset[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";

} // namespace

Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output,
OptimizationStats* stats) {
*output = item.graph;
MutableGraphView graph(output);

NodeDef* zero_node = graph_utils::AddScalarConstNode<int64>(0, &graph);

for (NodeDef& node : *output->mutable_node()) {
if (node.op() == kShuffleDatasetV2) {
*node.mutable_op() = kShuffleDataset;
// remove `seed_generator` input
node.mutable_input()->RemoveLast();
// add `seed` input
node.add_input(zero_node->name());
// add `seed2` input
node.add_input(zero_node->name());
// set `reshuffle_each_iteration` attr
(*node.mutable_attr())[kReshuffleEachIteration].set_b(true);
}
}

return Status::OK();
}

REGISTER_GRAPH_OPTIMIZER_AS(MakeStateless, "make_stateless");

} // namespace grappler
} // namespace tensorflow
@ -1,56 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_

#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"

namespace tensorflow {
namespace grappler {

// This rewrite replaces transformations that depend on external state (such as
// `ShuffleDatasetV2`) with a stateless alternative so that the input pipeline
// graph can be cloned.
//
// Note that this rewrites may change observable behavior of the input pipeline
// (e.g. `reshuffle_each_iteration` will not work) and is a stop gap solution
// to enable cloning until a better mechanism exists.
class MakeStateless : public TFDataOptimizerBase {
public:
MakeStateless() = default;
~MakeStateless() override = default;

string name() const override { return "make_stateless"; }

bool UsesFunctionLibrary() const override { return false; }

Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
return Status::OK();
}

Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
GraphDef* output,
OptimizationStats* stats) override;

void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override {}
};

} // namespace grappler
} // namespace tensorflow

#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_STATELESS_H_
@ -1,56 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/make_stateless.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

TEST(MakeStateless, Shuffle) {
using test::function::NDef;
GrapplerItem item;
item.graph = test::function::GDef(
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
NDef("buffer_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT64}}),
NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}),
graph_tests_utils::MakeShuffleV2Node("shuffle", "range", "buffer_size",
"handle")},
{});

MakeStateless optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("shuffle", output));
int index = graph_utils::FindGraphNodeWithName("shuffle", output);
EXPECT_EQ(output.node(index).op(), "ShuffleDataset");
EXPECT_EQ(output.node(index).input_size(), 4);
}

} // namespace
} // namespace grappler
} // namespace tensorflow
@ -43,8 +43,7 @@ using ConfigMap =
std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;

// tf.data optimizations, in the order we want to perform them.
constexpr std::array<const char*, 16> kTFDataOptimizations = {
"make_stateless",
constexpr std::array<const char*, 15> kTFDataOptimizations = {
"noop_elimination",
"shuffle_and_repeat_fusion",
"map_fusion",
@ -51,7 +51,7 @@ bool IsDatasetNodeOfType(const NodeDef& node,
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ZipDataset", "ConcatenateDataset"};

constexpr std::array<const char*, 21> kPassThroughOps = {
constexpr std::array<const char*, 22> kPassThroughOps = {
"CacheDataset",
"CacheDatasetV2",
"ExperimentalMaxIntraOpParallelismDataset",
@ -70,6 +70,7 @@ constexpr std::array<const char*, 21> kPassThroughOps = {
"ShuffleAndRepeatDataset",
"ShuffleDataset",
"ShuffleDatasetV2",
"ShuffleDatasetV3",
"SkipDataset",
"TakeDataset",
"WindowDataset",
@ -39,40 +39,6 @@ Status CreateHandle(OpKernelContext* ctx, T* resource,
return Status::OK();
}

// A wrapper class that manages the lifetime of a resource handle from its
// creation to its deletion from the resource manager.
class OwnedResourceHandle {
public:
template <typename T>
static Status Create(OpKernelContext* ctx, T* resource, const string& name,
std::unique_ptr<OwnedResourceHandle>* result) {
ResourceHandle handle;
TF_RETURN_IF_ERROR(CreateHandle<T>(ctx, resource, name, &handle));
// We need to increase the refcount to match the decrease that occurs when
// the resource is deleted from the resource manager.
resource->Ref();
*result = absl::make_unique<OwnedResourceHandle>(ctx, std::move(handle));
return Status::OK();
}

OwnedResourceHandle(OpKernelContext* ctx, ResourceHandle&& handle)
: mgr_(ctx->resource_manager()), handle_(handle) {}

~OwnedResourceHandle() {
Status s = mgr_->Delete(handle_);
if (!s.ok()) {
VLOG(2) << s.ToString();
}
}

// Returns the wrapped `ResourceHandle` object.
const ResourceHandle& handle() const { return handle_; }

private:
ResourceMgr* mgr_; // not owned
const ResourceHandle handle_;
};

template <typename T>
class AnonymousResourceOp : public OpKernel {
public:
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {
@ -28,8 +29,6 @@ namespace {

const char kAnonymousRandomSeedGenerator[] = "AnonymousRandomSeedGenerator";
const char kNumRandomSamples[] = "num_random_samples";
const char kFixedSeedGenerator[] = "FixedSeedGenerator";
const char kRandomSeedGenerator[] = "RandomSeedGenerator";
const char kSeedGenerator[] = "SeedGenerator";
const char kSeed[] = "seed";
const char kSeed2[] = "seed2";
@ -37,27 +36,15 @@ const char kReshuffle[] = "reshuffle";

} // namespace

int64 SeedGenerator::num_random_samples() {
tf_shared_lock l(mu_);
return num_random_samples_;
}

void SeedGenerator::set_num_random_samples(int64 num_random_samples) {
mutex_lock l(mu_);
num_random_samples_ = num_random_samples;
}

string FixedSeedGenerator::DebugString() const { return kFixedSeedGenerator; }
string SeedGeneratorManager::DebugString() const { return kSeedGenerator; }

void FixedSeedGenerator::GenerateSeeds(int64* seed1, int64* seed2) {
mutex_lock l(mu_);
num_random_samples_++;
*seed1 = seed_;
*seed2 = seed2_;
*seed1 = seeds_.seed();
*seed2 = seeds_.seed2();
}

string RandomSeedGenerator::DebugString() const { return kRandomSeedGenerator; }

void RandomSeedGenerator::GenerateSeeds(int64* seed1, int64* seed2) {
mutex_lock l(mu_);
num_random_samples_++;
@ -69,7 +56,7 @@ void RandomSeedGenerator::GenerateSeeds(int64* seed1, int64* seed2) {
void RandomSeedGenerator::Reset() {
mutex_lock l(mu_);
// Reset the generators based on the current seeds.
parent_generator_ = random::PhiloxRandom(seed_, seed2_);
parent_generator_ = random::PhiloxRandom(seeds_.seed(), seeds_.seed2());
generator_ =
random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
generator_.Skip(num_random_samples_);
@ -77,29 +64,17 @@ void RandomSeedGenerator::Reset() {

AnonymousSeedGeneratorHandleOp::AnonymousSeedGeneratorHandleOp(
OpKernelConstruction* ctx)
: AnonymousResourceOp<SeedGenerator>(ctx) {}
: AnonymousResourceOp<SeedGeneratorManager>(ctx) {}

void AnonymousSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
int64 seed;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
int64 seed2;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
if (seed == 0 && seed2 == 0) {
seed = random::New64();
seed2 = random::New64();
}
seed_ = seed;
seed2_ = seed2;

// TODO(b/151115950): Remove this case when the forward compatibility window
// expires.
if (ctx->op_kernel().def().op() == kAnonymousRandomSeedGenerator) {
reshuffle_ = true;
} else {
OP_REQUIRES_OK(ctx,
ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_));
}
AnonymousResourceOp<SeedGenerator>::Compute(ctx);
// Seeds will be consumed by `CreateResource`, which is called via `Compute`.
seeds_ = absl::make_unique<RandomSeeds>(seed, seed2);
OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_));
AnonymousResourceOp<SeedGeneratorManager>::Compute(ctx);
}

std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; }
@ -107,12 +82,13 @@ std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; }
Status AnonymousSeedGeneratorHandleOp::CreateResource(
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib, SeedGenerator** resource) {
FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) {
if (reshuffle_) {
*resource = new RandomSeedGenerator(seed_, seed2_);
*manager = new SeedGeneratorManager(new RandomSeedGenerator(*seeds_));
} else {
*resource = new FixedSeedGenerator(seed_, seed2_);
*manager = new SeedGeneratorManager(new FixedSeedGenerator(*seeds_));
}
seeds_ = nullptr;
return Status::OK();
}

@ -137,6 +113,9 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousRandomSeedGenerator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeleteRandomSeedGenerator").Device(DEVICE_CPU),
DeleteSeedGeneratorOp);

REGISTER_KERNEL_BUILDER(Name("DummySeedGenerator").Device(DEVICE_CPU),
DummyResourceOp<SeedGenerator>);

} // namespace
} // namespace data
} // namespace tensorflow
@ -25,51 +25,102 @@ limitations under the License.
namespace tensorflow {
namespace data {

// Represents a pair of random seeds. By TensorFlow convention, if both seeds
// are 0, then pseudo-random values are used instead.
class RandomSeeds {
public:
RandomSeeds(int64 seed, int64 seed2)
: input_seed_(seed),
input_seed2_(seed2),
seed_((seed | seed2) == 0 ? random::New64() : seed),
seed2_((seed | seed2) == 0 ? random::New64() : seed2) {}

int64 input_seed() const { return input_seed_; }
int64 input_seed2() const { return input_seed2_; }
int64 seed() const { return seed_; }
int64 seed2() const { return seed2_; }

private:
const int64 input_seed_;
const int64 input_seed2_;
const int64 seed_;
const int64 seed2_;
};

// Base class for seed generator resources. Subclasses customize how seeds are
// generated.
class SeedGenerator : public ResourceBase {
class SeedGenerator {
public:
virtual ~SeedGenerator() {}

virtual int64 seed() const = 0;
virtual int64 seed2() const = 0;
virtual bool reshuffle_each_iteration() const = 0;

virtual void GenerateSeeds(int64* seed1, int64* seed2) = 0;
virtual void Reset() = 0;

virtual int64 num_random_samples();
virtual void set_num_random_samples(int64 num_random_samples);
virtual int64 num_random_samples() const {
tf_shared_lock l(mu_);
return num_random_samples_;
}
virtual void set_num_random_samples(int64 num_random_samples) {
mutex_lock l(mu_);
num_random_samples_ = num_random_samples;
}

protected:
mutex mu_;
mutable mutex mu_;
int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0;
};

// A resource wrapping a shared instance of a seed generator.
class SeedGeneratorManager : public ResourceBase {
public:
explicit SeedGeneratorManager(SeedGenerator* seed_generator)
: seed_generator_(seed_generator) {}

std::string DebugString() const override;

std::shared_ptr<SeedGenerator> get() { return seed_generator_; }

private:
std::shared_ptr<SeedGenerator> seed_generator_;
};

// Always generates the specified seed values.
class FixedSeedGenerator : public SeedGenerator {
public:
FixedSeedGenerator(int64 seed, int64 seed2) : seed_(seed), seed2_(seed2) {}
explicit FixedSeedGenerator(RandomSeeds seeds) : seeds_(std::move(seeds)) {}

int64 seed() const override { return seeds_.seed(); }
int64 seed2() const override { return seeds_.seed2(); }
bool reshuffle_each_iteration() const override { return false; }

std::string DebugString() const override;
void GenerateSeeds(int64* seed1, int64* seed2) override;
void Reset() override {}

private:
const int64 seed_;
const int64 seed2_;
const RandomSeeds seeds_;
};

// Generates different (but deterministically chosen) seed values.
class RandomSeedGenerator : public SeedGenerator {
public:
RandomSeedGenerator(int64 seed, int64 seed2)
: seed_(seed),
seed2_(seed2),
parent_generator_(seed, seed2),
explicit RandomSeedGenerator(RandomSeeds seeds)
: seeds_(std::move(seeds)),
parent_generator_(seeds_.seed(), seeds_.seed2()),
generator_(&parent_generator_) {}

std::string DebugString() const override;
int64 seed() const override { return seeds_.seed(); }
int64 seed2() const override { return seeds_.seed2(); }
bool reshuffle_each_iteration() const override { return true; }

void GenerateSeeds(int64* seed1, int64* seed2) override;
void Reset() override;

private:
const int64 seed_;
const int64 seed2_;
const RandomSeeds seeds_;
random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
random::SingleSampleAdapter<random::PhiloxRandom> generator_
TF_GUARDED_BY(mu_);
@ -78,7 +129,7 @@ class RandomSeedGenerator : public SeedGenerator {
// Creates an instance of seed generator resource and transfers ownership
// to the caller.
class AnonymousSeedGeneratorHandleOp
: public AnonymousResourceOp<SeedGenerator> {
: public AnonymousResourceOp<SeedGeneratorManager> {
public:
explicit AnonymousSeedGeneratorHandleOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
@ -89,10 +140,9 @@ class AnonymousSeedGeneratorHandleOp
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib,
SeedGenerator** resource) override;
SeedGeneratorManager** manager) override;

int64 seed_;
int64 seed2_;
std::unique_ptr<RandomSeeds> seeds_ = nullptr;
bool reshuffle_;
};
@ -68,34 +68,10 @@ constexpr char kBuffer[] = "buffer";
constexpr char kSize[] = "size";
constexpr char kSeedGenerator[] = "SeedGenerator";
constexpr char kTFData[] = "tf_data";
constexpr char kDSNumRandomSamples[] = "ds_num_random_samples";
constexpr char kFixedSeedDatasetPrefix[] = "FixedSeed";
constexpr char kDatasetPrefix[] = "Dataset";
constexpr char kDatasetV2Prefix[] = "DatasetV2";
constexpr char kShuffleDataset[] = "ShuffleDataset";

namespace {
class Seeds {
public:
Seeds(int64 seed, int64 seed2) {
input_seed_ = seed;
input_seed2_ = seed2;
seed_ = seed;
seed2_ = seed2;
// By TensorFlow convention, if both seeds are 0, then shuffling should be
// seeded non-deterministically.
if (seed == 0 && seed2 == 0) {
seed_ = random::New64();
seed2_ = random::New64();
}
}

int64 input_seed_;
int64 input_seed2_;
int64 seed_;
int64 seed2_;
};
} // namespace
constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";

ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
@ -104,10 +80,12 @@ ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
public:
ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size, int64 count)
int64 buffer_size,
std::shared_ptr<SeedGenerator> seed_generator, int64 count)
: DatasetBase(DatasetContext(ctx)),
input_(input),
buffer_size_(buffer_size),
seed_generator_(std::move(seed_generator)),
count_(count),
traceme_metadata_(
{{"buffer_size",
@ -117,6 +95,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {

~ShuffleDatasetBase() override { input_->Unref(); }

virtual string op_type() const = 0;

const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
@ -139,37 +119,40 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
return input_->CheckExternalState();
}

protected:
// Adds the seeds to the given graphdef builder. `preserve_random_seeds`
// controls whether to add the input seeds or the resolved seeds.
Status AddSeeds(Seeds seeds, bool preserve_random_seeds,
DatasetGraphDefBuilder* b, Node** seed, Node** seed2) const {
int64 seed_to_add = preserve_random_seeds ? seeds.input_seed_ : seeds.seed_;
int64 seed2_to_add =
preserve_random_seeds ? seeds.input_seed2_ : seeds.seed2_;
TF_RETURN_IF_ERROR(b->AddScalar(seed_to_add, seed));
TF_RETURN_IF_ERROR(b->AddScalar(seed2_to_add, seed2));
return Status::OK();
string DebugString() const override {
name_utils::DatasetDebugStringParams params;
params.set_args(buffer_size_, seed_generator_->seed(),
seed_generator_->seed2(), count_);
return name_utils::DatasetDebugString(op_type(), params);
}

template <class T>
class Iterator : public DatasetIterator<T> {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this, name_utils::IteratorPrefix(op_type(), prefix)},
seed_generator_.get());
}

protected:
class Iterator : public DatasetIterator<ShuffleDatasetBase> {
public:
explicit Iterator(const typename DatasetIterator<T>::Params& params,
int64 seed, int64 seed2)
: DatasetIterator<T>(params),
seed_(seed),
seed2_(seed2),
input_impl_(nullptr),
epoch_(0),
num_elements_(0),
parent_generator_(seed, seed2),
explicit Iterator(const Params& params, SeedGenerator* seed_generator)
: DatasetIterator<ShuffleDatasetBase>(params),
seed_generator_(seed_generator),
parent_generator_(seed_generator->seed(), seed_generator->seed2()),
generator_(&parent_generator_) {
buffer_ = absl::make_unique<std::vector<Tensor>[]>(
params.dataset->buffer_size_);
slices_.push_back(absl::make_unique<Slice>(0, 0));
}

Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
seed_generator_->GenerateSeeds(&seed_, &seed2_);
ResetRngs();
return Status::OK();
}

Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@ -283,6 +266,9 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
// Save state needed to restore the random number generators.
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kEpochNumRandomSamples),
seed_generator_->num_random_samples()));
TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
num_random_samples_));
TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_));
@ -337,6 +323,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
IteratorStateReader* reader) override {
mutex_lock l(mu_);
// Restore the random number generators.
int64 num_random_samples;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kEpochNumRandomSamples),
&num_random_samples));
seed_generator_->set_num_random_samples(num_random_samples);
seed_generator_->Reset();
TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kNumRandomSamples),
&num_random_samples_));
TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_));
@ -402,10 +393,6 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
return this->dataset()->traceme_metadata_;
}

mutex mu_;
int64 seed_ TF_GUARDED_BY(mu_);
int64 seed2_ TF_GUARDED_BY(mu_);

private:
// Used to represent slices of `buffer_` that belong to different epochs.
// The invariant maintained by the implementation is: `start` <= `end`.
@ -426,10 +413,14 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
return out;
}

mutex mu_;
SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_); // Not owned.
std::unique_ptr<std::vector<Tensor>[]> buffer_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
int64 epoch_ TF_GUARDED_BY(mu_);
int64 num_elements_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_) = nullptr;
int64 epoch_ TF_GUARDED_BY(mu_) = 0;
int64 num_elements_ TF_GUARDED_BY(mu_) = 0;
int64 seed_ TF_GUARDED_BY(mu_) = 0;
int64 seed2_ TF_GUARDED_BY(mu_) = 0;
// Indices into `buffer_` indicating which data belongs to which epoch.
// The slice at the front of the deque references data from the earliest
// buffered epoch. It is an invariant that all slices reference
@ -444,135 +435,59 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {

const DatasetBase* const input_;
const int64 buffer_size_;
const std::shared_ptr<SeedGenerator> seed_generator_;
// The number of epochs to run for. Normally this is just 1, but sometimes we
// fuse shuffle and repeat together, and make the shuffle dataset op
// responsible for repeating as well.
const int64 count_;
const TraceMeMetadata traceme_metadata_;
};
}; // ShuffleDatasetBase

// This version of the shuffle dataset has exclusive ownership of the seed
// generator resource. It supports sharing of the seed generator across
// different iterations of the `repeat` transformation but not across different
// iterators.
class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
Seeds seeds, int64 count, bool reshuffle_each_iteration)
: ShuffleDatasetBase(ctx, input, buffer_size, count),
seeds_(seeds),
reshuffle_each_iteration_(reshuffle_each_iteration) {}
int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
ResourceHandle&& resource_handle)
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
manager_(manager),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()),
seeds_(std::move(seeds)) {}

string DebugString() const override {
name_utils::DatasetDebugStringParams params;
params.dataset_prefix = kDatasetPrefix;
params.set_args(buffer_size_, seeds_.seed_, seeds_.seed2_);
return name_utils::DatasetDebugString(kDatasetType, params);
~Dataset() override {
manager_->Unref();
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
}
}

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this,
name_utils::IteratorPrefix(kDatasetType, prefix)},
seeds_.seed_, seeds_.seed2_);
}
string op_type() const override { return kDatasetType; }

protected:
class Iterator : public ShuffleDatasetBase::Iterator<Dataset> {
public:
Iterator(const Params& params, int64 seed, int64 seed2)
: ShuffleDatasetBase::Iterator<Dataset>(params, seed, seed2) {}

~Iterator() override { seed_generator_->Unref(); }

Status Initialize(IteratorContext* ctx) override {
// Firstly, lookup or create a seed generator from the IteratorResource
// resource_mgr.
ResourceMgr* mgr = ctx->resource_mgr();
SeedGenerator* seed_generator;
const string name = strings::StrCat(
prefix(), name_utils::kDelimiter, dataset()->type_string(),
name_utils::kDelimiter, kSeedGenerator);

int64 dataset_seed, dataset_seed2;
{
tf_shared_lock l(mu_);
// Ideally we'd like to hold this lock in the LookupOrCreate method,
// but that trips up our Deadlock detection code.
dataset_seed = seed_;
dataset_seed2 = seed2_;
}
TF_RETURN_IF_ERROR(mgr->LookupOrCreate<SeedGenerator>(
kTFData, name, &seed_generator,
[this, dataset_seed, dataset_seed2](SeedGenerator** seed_generator) {
// On the first iterator creation, use the original seeds from the
// dataset to seed a `SeedGenerator` that will provide seeds
// for subsequent repetitions of the same dataset.
if (dataset()->reshuffle_each_iteration_) {
*seed_generator =
new RandomSeedGenerator(dataset_seed, dataset_seed2);
} else {
*seed_generator =
new FixedSeedGenerator(dataset_seed, dataset_seed2);
}
return Status::OK();
}));
seed_generator_ = seed_generator;
seed_generator_->GenerateSeeds(&seed_, &seed2_);
mutex_lock l(mu_);
ResetRngs();
return Status::OK();
}

protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}

Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
// Save RNG state of Dataset.
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kDSNumRandomSamples),
seed_generator_->num_random_samples()));

// Save the Iterator.
return ShuffleDatasetBase::Iterator<Dataset>::SaveInternal(ctx, writer);
}

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
// Restore RNG state of Dataset.
int64 num_random_samples;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kDSNumRandomSamples),
&num_random_samples));
seed_generator_->set_num_random_samples(num_random_samples);
seed_generator_->Reset();

// Restore the Iterator.
return ShuffleDatasetBase::Iterator<Dataset>::RestoreInternal(ctx,
reader);
}

private:
SeedGenerator* seed_generator_;
};

Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* buffer_size = nullptr;
Node* seed = nullptr;
Node* seed2 = nullptr;
Node* buffer_size_node = nullptr;
Node* seed_node = nullptr;
Node* seed2_node = nullptr;
AttrValue reshuffle_each_iteration;

TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
TF_RETURN_IF_ERROR(
AddSeeds(seeds_, /*preserve_random_seeds=*/true, b, &seed, &seed2));
b->BuildAttrValue(reshuffle_each_iteration_, &reshuffle_each_iteration);
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
&reshuffle_each_iteration);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
this,
{input_graph_node, buffer_size_node, seed_node, seed2_node}, // Inputs
{std::make_pair(kReshuffleEachIteration,
reshuffle_each_iteration)}, // Attrs
output));
@ -580,92 +495,41 @@ class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
}

private:
const Seeds seeds_;
const bool reshuffle_each_iteration_;
SeedGeneratorManager* const manager_; // Owned.
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
const RandomSeeds seeds_;
};

// A shuffle dataset that uses an external seed generator resource to choose the
// shuffle seeds for each iteration.
// This version of shuffle dataset has a shared ownership of the seed generator
// resource. It supports sharing of the generator state across different
// iterations of the `repeat` transformation and also across different
// iterators.
class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
public:
DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
int64 count, SeedGenerator* seed_generator,
std::unique_ptr<OwnedResourceHandle> handle)
: ShuffleDatasetBase(ctx, input, buffer_size, count),
seed_generator_(seed_generator),
handle_(std::move(handle)) {}
int64 count, SeedGeneratorManager* manager,
ResourceHandle&& resource_handle, bool owns_resource)
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
manager_(manager),
owns_resource_(owns_resource),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()) {}

~DatasetV2() override { seed_generator_->Unref(); }

string DebugString() const override {
name_utils::DatasetDebugStringParams params;
params.dataset_prefix = kDatasetV2Prefix;
params.set_args(buffer_size_);
return name_utils::DatasetDebugString(kDatasetType, params);
~DatasetV2() override {
manager_->Unref();
if (owns_resource_) {
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
}
}
}

Status CheckExternalState() const override {
return errors::FailedPrecondition(
DebugString(), " depends on random seed generator resource.");
}

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this,
name_utils::IteratorPrefix(kDatasetType, prefix)},
seed_generator_);
}
string op_type() const override { return kDatasetType; }

protected:
class Iterator : public ShuffleDatasetBase::Iterator<DatasetV2> {
public:
Iterator(const Params& params, SeedGenerator* seed_generator)
: ShuffleDatasetBase::Iterator<DatasetV2>(params, 0, 0),
seed_generator_(seed_generator) {}

Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
seed_generator_->GenerateSeeds(&seed_, &seed2_);
ResetRngs();
return Status::OK();
}

protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
}

Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
// Save state of the seed generator.
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kDSNumRandomSamples),
seed_generator_->num_random_samples()));

// Save the iterator state.
return ShuffleDatasetBase::Iterator<DatasetV2>::SaveInternal(ctx, writer);
}

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
// Restore state of the seed generator.
int64 num_random_samples;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kDSNumRandomSamples),
&num_random_samples));
seed_generator_->set_num_random_samples(num_random_samples);
seed_generator_->Reset();

// Restore the iterator state.
return ShuffleDatasetBase::Iterator<DatasetV2>::RestoreInternal(ctx,
reader);
}

private:
SeedGenerator* seed_generator_;
};

Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
@ -675,7 +539,7 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
Node* resource_handle_node = nullptr;
Tensor handle(DT_RESOURCE, TensorShape({}));
handle.scalar<ResourceHandle>()() = handle_->handle();
handle.scalar<ResourceHandle>()() = resource_handle_;
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
TF_RETURN_IF_ERROR(b->AddDataset(
this,
@ -686,33 +550,39 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
}

private:
SeedGenerator* seed_generator_ = nullptr;
std::unique_ptr<OwnedResourceHandle> handle_;
SeedGeneratorManager* const manager_; // Owned.
const bool owns_resource_;
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
};

// A dataset that uses the same fixed seed for all iterators created from it.
// Used when `reshuffle_each_iteration` is false.
// TODO(b/151115950): delete this class.
class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase {
// This version of shuffle dataset extends the functionality of DatasetV2 with
// the ability to preserve seed generator configuration (i.e. initial seeds and
// whether to reshuffle each iteration) across serialization of the dataset.
class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase {
public:
FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size, Seeds seeds, int64 count)
: ShuffleDatasetBase(ctx, input, buffer_size, count), seeds_(seeds) {}
DatasetV3(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
ResourceHandle&& resource_handle, bool owns_resource)
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
manager_(manager),
owns_resource_(owns_resource),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()),
seeds_(std::move(seeds)) {}

string DebugString() const override {
name_utils::DatasetDebugStringParams params;
params.dataset_prefix = kFixedSeedDatasetPrefix;
params.set_args(buffer_size_, seeds_.seed_, seeds_.seed2_);
return name_utils::DatasetDebugString(kDatasetType, params);
~DatasetV3() override {
manager_->Unref();
if (owns_resource_) {
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
}
}
}

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>(
ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)},
seeds_.seed_, seeds_.seed2_);
}
string op_type() const override { return kDatasetType; }

protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -720,30 +590,47 @@ class ShuffleDatasetOp::FixedSeedDataset : public ShuffleDatasetBase {
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
Node* buffer_size = nullptr;
|
||||
Node* seed = nullptr;
|
||||
Node* seed2 = nullptr;
|
||||
Node* buffer_size_node = nullptr;
|
||||
Node* seed_node = nullptr;
|
||||
Node* seed2_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
|
||||
Node* resource_handle_node = nullptr;
|
||||
Tensor handle(DT_RESOURCE, TensorShape({}));
|
||||
handle.scalar<ResourceHandle>()() = resource_handle_;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
|
||||
AttrValue reshuffle_each_iteration;
|
||||
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
|
||||
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
|
||||
&reshuffle_each_iteration);
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddSeeds(seeds_, ctx->preserve_random_seeds(), b, &seed, &seed2));
|
||||
b->BuildAttrValue(false, &reshuffle_each_iteration);
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
|
||||
{std::make_pair(kReshuffleEachIteration,
|
||||
reshuffle_each_iteration)}, // Attrs
|
||||
output));
|
||||
b->AddDataset(this,
|
||||
{input_graph_node, buffer_size_node, seed_node,
|
||||
seed2_node, resource_handle_node}, // Inputs
|
||||
{std::make_pair(kReshuffleEachIteration,
|
||||
reshuffle_each_iteration)}, // Attrs
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const Seeds seeds_;
|
||||
SeedGeneratorManager* const manager_; // Owned
|
||||
const bool owns_resource_;
|
||||
const ResourceHandle resource_handle_;
|
||||
ResourceMgr* const resource_mgr_; // Not owned.
|
||||
const RandomSeeds seeds_;
|
||||
};
|
||||
|
||||
ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
|
||||
: ShuffleDatasetOpBase(ctx),
|
||||
op_version_(ctx->def().op() == kShuffleDataset ? 1 : 2) {
|
||||
: ShuffleDatasetOpBase(ctx) {
|
||||
auto& op_name = ctx->def().op();
|
||||
if (op_name == kShuffleDatasetV3) {
|
||||
op_version_ = 3;
|
||||
} else if (op_name == kShuffleDatasetV2) {
|
||||
op_version_ = 2;
|
||||
} else if (op_name == kShuffleDatasetV1) {
|
||||
op_version_ = 1;
|
||||
}
|
||||
if (ctx->HasAttr(kReshuffleEachIteration)) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
|
||||
@ -760,71 +647,133 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
errors::InvalidArgument("buffer_size must be greater than zero."));
|
||||
|
||||
int64 count = 1;
|
||||
if (op_version_ == 2) {
|
||||
SeedGenerator* seed_generator = nullptr;
|
||||
Status s = LookupResource(ctx, HandleFromInput(ctx, 2), &seed_generator);
|
||||
static std::atomic<int64> resource_id_counter(0);
|
||||
const string& container = ctx->resource_manager()->default_container();
|
||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
|
||||
resource_id_counter.fetch_add(1));
|
||||
if (op_version_ == 3) {
|
||||
auto handle = HandleFromInput(ctx, 4);
|
||||
SeedGeneratorManager* manager = nullptr;
|
||||
Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
|
||||
handle.container(), handle.name(), &manager);
|
||||
int64 seed;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
|
||||
int64 seed2;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager,
|
||||
[reshuffle = reshuffle_each_iteration_,
|
||||
&seeds](SeedGeneratorManager** manager) {
|
||||
if (reshuffle) {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
} else {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
|
||||
// Ownership of manager is transferred onto `DatasetV3`.
|
||||
*output = new ShuffleDatasetOp::DatasetV3(ctx, input, buffer_size, count,
|
||||
std::move(seeds), manager,
|
||||
std::move(handle), owns_resource);
|
||||
} else if (op_version_ == 2) {
|
||||
auto handle = HandleFromInput(ctx, 2);
|
||||
SeedGeneratorManager* manager = nullptr;
|
||||
Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
|
||||
handle.container(), handle.name(), &manager);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
|
||||
"using a non-deterministically-seeded seed generator.";
|
||||
*output =
|
||||
new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, Seeds(0, 0),
|
||||
count, reshuffle_each_iteration_);
|
||||
return;
|
||||
"using a non-deterministically seeded generator and "
|
||||
"reshuffling each iteration.";
|
||||
RandomSeeds seeds(0, 0);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager,
|
||||
[&seeds](SeedGeneratorManager** manager) {
|
||||
*manager = new SeedGeneratorManager(
|
||||
new RandomSeedGenerator(seeds));
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
|
||||
// Create a fresh handle for the resource because the input handle can
|
||||
// become invalid after this op executes.
|
||||
std::unique_ptr<OwnedResourceHandle> handle;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, OwnedResourceHandle::Create(
|
||||
ctx, seed_generator, seed_generator->DebugString(), &handle));
|
||||
|
||||
// Ownership of seed generator is transferred onto `DatasetV2`.
|
||||
*output = new ShuffleDatasetOp::DatasetV2(
|
||||
ctx, input, buffer_size, count, seed_generator, std::move(handle));
|
||||
return;
|
||||
}
|
||||
|
||||
int64 seed;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
|
||||
|
||||
int64 seed2;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||
|
||||
if (!reshuffle_each_iteration_) {
|
||||
// This dataset is only needed to support old clients running v2 eager with
|
||||
// reshuffle_each_iteration_=false. We can't tell here whether we are in v2
|
||||
// eager, so we conservatively always use FixedSeedDataset when
|
||||
// reshuffle_each_iteration=false.
|
||||
*output = new FixedSeedDataset(ctx, input, buffer_size, Seeds(seed, seed2),
|
||||
count);
|
||||
// Ownership of manager is transferred onto `DatasetV2`.
|
||||
*output =
|
||||
new ShuffleDatasetOp::DatasetV2(ctx, input, buffer_size, count, manager,
|
||||
std::move(handle), owns_resource);
|
||||
} else {
|
||||
*output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size,
|
||||
Seeds(seed, seed2), count,
|
||||
reshuffle_each_iteration_);
|
||||
if (op_version_ != 1) {
|
||||
LOG(WARNING) << "Unsupported version of shuffle dataset op: "
|
||||
<< op_version_ << ". Defaulting to version 1.";
|
||||
}
|
||||
int64 seed;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
|
||||
int64 seed2;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
SeedGeneratorManager* manager;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager,
|
||||
[reshuffle = reshuffle_each_iteration_,
|
||||
&seeds](SeedGeneratorManager** manager) {
|
||||
if (reshuffle) {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
} else {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle =
|
||||
MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
|
||||
// Ownership of manager is transferred onto `Dataset`.
|
||||
*output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, count,
|
||||
std::move(seeds), manager,
|
||||
std::move(handle));
|
||||
}
|
||||
}
|
||||
|
||||
class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
|
||||
Seeds seeds, int64 count)
|
||||
: ShuffleDatasetBase(ctx, input, buffer_size, count), seeds_(seeds) {}
|
||||
RandomSeeds&& seeds, SeedGeneratorManager* manager, int64 count,
|
||||
ResourceHandle&& resource_handle)
|
||||
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
|
||||
manager_(manager),
|
||||
resource_handle_(std::move(resource_handle)),
|
||||
resource_mgr_(ctx->resource_manager()),
|
||||
seeds_(std::move(seeds)) {}
|
||||
|
||||
string DebugString() const override {
|
||||
name_utils::DatasetDebugStringParams params;
|
||||
params.set_args(buffer_size_, seeds_.seed_, seeds_.seed2_);
|
||||
return name_utils::DatasetDebugString(kDatasetType, params);
|
||||
~Dataset() override {
|
||||
manager_->Unref();
|
||||
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
|
||||
resource_handle_.container(), resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return absl::make_unique<ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>(
|
||||
ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{
|
||||
this, name_utils::IteratorPrefix(kDatasetType, prefix)},
|
||||
seeds_.seed_, seeds_.seed2_);
|
||||
}
|
||||
string op_type() const override { return kDatasetType; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -838,8 +787,8 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
Node* count = nullptr;
|
||||
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddSeeds(seeds_, /*preserve_random_seeds=*/true, b, &seed, &seed2));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
|
||||
@ -849,7 +798,10 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
}
|
||||
|
||||
private:
|
||||
const Seeds seeds_;
|
||||
SeedGeneratorManager* const manager_; // Owned.
|
||||
const ResourceHandle resource_handle_;
|
||||
ResourceMgr* const resource_mgr_; // Not owned.
|
||||
const RandomSeeds seeds_;
|
||||
};
|
||||
|
||||
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
|
||||
@ -874,11 +826,29 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
int64 count;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
|
||||
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
|
||||
OP_REQUIRES(ctx, count > 0 || count == -1,
|
||||
errors::InvalidArgument(
|
||||
"count must be greater than zero or equal to -1."));
|
||||
|
||||
*output = new Dataset(ctx, input, buffer_size, Seeds(seed, seed2), count);
|
||||
static std::atomic<int64> resource_id_counter(0);
|
||||
const string& container = ctx->resource_manager()->default_container();
|
||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
|
||||
resource_id_counter.fetch_add(1));
|
||||
SeedGeneratorManager* manager;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager, [&seeds](SeedGeneratorManager** manager) {
|
||||
*manager = new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
|
||||
// Ownership of manager is transferred onto `Dataset`.
|
||||
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
|
||||
count, std::move(handle));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -888,6 +858,9 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
|
||||
ShuffleDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
|
||||
ShuffleDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
|
||||
ShuffleAndRepeatDatasetOp);
|
||||
} // namespace
|
||||
|
@ -50,8 +50,8 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
||||
private:
|
||||
class Dataset;
|
||||
class DatasetV2;
|
||||
class FixedSeedDataset;
|
||||
int op_version_;
|
||||
class DatasetV3;
|
||||
int op_version_ = 0;
|
||||
bool reshuffle_each_iteration_;
|
||||
};
|
||||
|
||||
|
@ -297,23 +297,23 @@ std::vector<GetNextTestCase<ShuffleDatasetParams>> GetNextTestCases() {
|
||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}),
|
||||
{{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5},
|
||||
{4}, {3}, {0}, {5}, {8}, {2}, {6}, {9}, {7}, {1}}),
|
||||
{{2}, {6}, {1}, {3}, {9}, {5}, {0}, {8}, {7}, {4},
|
||||
{0}, {5}, {1}, {7}, {2}, {9}, {8}, {4}, {6}, {3}}),
|
||||
/*expected_reshuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||
{2}, {4}, {5}, {4}, {3}, {0}, {5},
|
||||
{8}, {2}, {6}, {9}, {7}, {1}})},
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {6}, {0}, {5}, {2}, {7}, {4},
|
||||
{3}, {9}, {8}, {6}, {5}, {0}, {9},
|
||||
{4}, {7}, {2}, {8}, {1}, {3}})},
|
||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {1}, {2}, {0}, {1}, {0},
|
||||
{2}, {2}, {0}, {1}, {1}, {0}, {2}, {2}, {1}, {0}}),
|
||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}}),
|
||||
/*expected_reshuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {1}, {2}, {0}, {1}, {0},
|
||||
{2}, {2}, {0}, {1}, {1}, {0}, {2}, {2}, {1}, {0}})}};
|
||||
{{1}, {0}, {2}, {0}, {1}, {2}, {2}, {1}, {0}, {0}, {1},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {0}, {2}})}};
|
||||
}
|
||||
|
||||
class ParameterizedGetNextTest : public ShuffleDatasetOpTest,
|
||||
@ -496,16 +496,16 @@ IteratorSaveAndRestoreTestCases() {
|
||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||
/*breakpoints=*/{0, 5, 22},
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||
{2}, {4}, {5}, {4}, {3}, {0}, {5},
|
||||
{8}, {2}, {6}, {9}, {7}, {1}})},
|
||||
CreateTensors<int64>(TensorShape({}), {{2}, {6}, {1}, {3}, {9}, {5}, {0},
|
||||
{8}, {7}, {4}, {0}, {5}, {1}, {7},
|
||||
{2}, {9}, {8}, {4}, {6}, {3}})},
|
||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||
/*breakpoints=*/{0, 5, 20},
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {1}, {2}, {0}, {1}, {0},
|
||||
{2}, {2}, {0}, {1}, {1}, {0}, {2}, {2}, {1}, {0}})}};
|
||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}})}};
|
||||
}
|
||||
|
||||
class ParameterizedIteratorSaveAndRestoreTest
|
||||
|
@ -438,6 +438,13 @@ REGISTER_OP("DeleteRandomSeedGenerator")
|
||||
.Input("deleter: variant")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("DummySeedGenerator")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->Scalar());
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("ShuffleDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("buffer_size: int64")
|
||||
@ -465,12 +472,32 @@ REGISTER_OP("ShuffleDatasetV2")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, and seed2 should be scalars.
|
||||
// buffer_size and seed_generator should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ShuffleDatasetV3")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("buffer_size: int64")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
.Input("seed_generator: resource")
|
||||
.Output("handle: variant")
|
||||
.Attr("reshuffle_each_iteration: bool = true")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, seed2, and seed_generator should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ShuffleAndRepeatDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("buffer_size: int64")
|
||||
|
@ -77,14 +77,6 @@ class DistributeOptions(options.OptionsBase):
|
||||
"files to shard.",
|
||||
default_factory=lambda: AutoShardPolicy.AUTO)
|
||||
|
||||
_make_stateless = options.create_option(
|
||||
name="_make_stateless",
|
||||
ty=bool,
|
||||
docstring=
|
||||
"Determines whether the input pipeline should be rewritten to not "
|
||||
"contain stateful transformations (so that its graph can be moved "
|
||||
"between devices).")
|
||||
|
||||
num_devices = options.create_option(
|
||||
name="num_devices",
|
||||
ty=int,
|
||||
|
@ -23,8 +23,11 @@ import functools
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -334,9 +337,12 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.graph_only_combinations() +
|
||||
combinations.combine(mode=["eager"], tf_api_version=1),
|
||||
combinations.combine(mode=["eager"]),
|
||||
combinations.combine(reshuffle=[True, False])))
|
||||
def testRerandomizeOnReplicate(self, reshuffle):
|
||||
if tf2.enabled() and not compat.forward_compatible(2020, 5, 22):
|
||||
self.skipTest("Functionality currently not supported.")
|
||||
|
||||
random_seed.set_random_seed(None)
|
||||
# When no seeds are fixed, each instantiation of the shuffle dataset should
|
||||
# produce elements in a different order.
|
||||
@ -351,6 +357,22 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertCountEqual(shuffle_1, shuffle_2)
|
||||
self.assertNotEqual(shuffle_1, shuffle_2)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCoordinateShuffling(self):
|
||||
if not compat.forward_compatible(
|
||||
2020, 5, 22) and tf2.enabled() and context.executing_eagerly():
|
||||
self.skipTest("Functionality currently not supported.")
|
||||
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds = ds.shuffle(num_elements, seed=42)
|
||||
ds = dataset_ops.Dataset.zip((ds, ds))
|
||||
get_next = self.getNext(ds)
|
||||
|
||||
for _ in range(100):
|
||||
x, y = self.evaluate(get_next())
|
||||
self.assertEqual(x, y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -2824,9 +2824,6 @@ class Options(options_lib.OptionsBase):
|
||||
result.append("latency_all_edges")
|
||||
if self.experimental_slack:
|
||||
result.append("slack")
|
||||
if (self.experimental_distribute and
|
||||
self.experimental_distribute._make_stateless): # pylint: disable=protected-access
|
||||
result.append("make_stateless")
|
||||
return result
|
||||
|
||||
def _graph_rewrite_configs(self):
|
||||
@ -3597,6 +3594,8 @@ class CacheDataset(UnaryUnchangedStructureDataset):
|
||||
super(CacheDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/22.
|
||||
class _SeedGeneratorDeleter(object):
|
||||
"""An object which cleans up an anonymous seed generator resource.
|
||||
|
||||
@ -3623,63 +3622,22 @@ class _SeedGeneratorDeleter(object):
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/22.
|
||||
class _SeedGenerator(object):
|
||||
"""Represents a fixed seed generator resource."""
|
||||
|
||||
def __init__(self, seed, seed2, reshuffle):
|
||||
super(_SeedGenerator, self).__init__()
|
||||
self._device = context.context().device_name
|
||||
self._handle, self._deleter = (
|
||||
gen_dataset_ops.anonymous_seed_generator(
|
||||
seed=seed, seed2=seed2, reshuffle=reshuffle))
|
||||
self._resource_deleter = _SeedGeneratorDeleter(
|
||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
return self._handle
|
||||
|
||||
|
||||
# TODO(b/151115950): Remove this class after forward compatibility window
|
||||
# expires
|
||||
class _RandomSeedGeneratorDeleter(object):
|
||||
"""An object which cleans up an anonymous random seed generator resource.
|
||||
|
||||
An alternative to defining a __del__ method on an object. Even if the parent
|
||||
object is part of a reference cycle, the cycle will be collectable.
|
||||
"""
|
||||
|
||||
def __init__(self, handle, device, deleter):
|
||||
self._deleter = deleter
|
||||
self._handle = handle
|
||||
self._device = device
|
||||
self._eager_mode = context.executing_eagerly()
|
||||
|
||||
def __del__(self):
|
||||
with ops.device(self._device):
|
||||
# Make sure the resource is deleted in the same mode as it was created in.
|
||||
if self._eager_mode:
|
||||
with context.eager_mode():
|
||||
gen_dataset_ops.delete_random_seed_generator(
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
else:
|
||||
with context.graph_mode():
|
||||
gen_dataset_ops.delete_random_seed_generator(
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
|
||||
|
||||
# TODO(b/151115950): Remove this class after forward compatibility window
|
||||
# expires
|
||||
class _RandomSeedGenerator(object):
|
||||
"""Represents a random seed generator resource."""
|
||||
|
||||
def __init__(self, seed, seed2):
|
||||
super(_RandomSeedGenerator, self).__init__()
|
||||
self._device = context.context().device_name
|
||||
self._handle, self._deleter = (
|
||||
gen_dataset_ops.anonymous_random_seed_generator(seed=seed, seed2=seed2))
|
||||
self._resource_deleter = _RandomSeedGeneratorDeleter(
|
||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||
if compat.forward_compatible(2020, 5, 22):
|
||||
self._handle = gen_dataset_ops.dummy_seed_generator()
|
||||
else:
|
||||
self._device = context.context().device_name
|
||||
self._handle, self._deleter = (
|
||||
gen_dataset_ops.anonymous_seed_generator(
|
||||
seed=seed, seed2=seed2, reshuffle=reshuffle))
|
||||
self._resource_deleter = _SeedGeneratorDeleter(
|
||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
@ -3717,25 +3675,29 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
|
||||
self._buffer_size = ops.convert_to_tensor(
|
||||
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
||||
|
||||
if reshuffle_each_iteration is None:
|
||||
self._reshuffle_each_iteration = True
|
||||
else:
|
||||
self._reshuffle_each_iteration = reshuffle_each_iteration
|
||||
reshuffle_each_iteration = True
|
||||
self._reshuffle_each_iteration = reshuffle_each_iteration
|
||||
|
||||
if (tf2.enabled() and (self._reshuffle_each_iteration or
|
||||
compat.forward_compatible(2020, 4, 10)) and
|
||||
if (tf2.enabled() and
|
||||
(context.executing_eagerly() or ops.inside_function())):
|
||||
if compat.forward_compatible(2020, 4, 10):
|
||||
self._seed_generator = _SeedGenerator(self._seed, self._seed2,
|
||||
self._reshuffle_each_iteration)
|
||||
self._seed_generator = _SeedGenerator(self._seed, self._seed2,
|
||||
self._reshuffle_each_iteration)
|
||||
if compat.forward_compatible(2020, 5, 22):
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
seed_generator=self._seed_generator.handle,
|
||||
reshuffle_each_iteration=self._reshuffle_each_iteration,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2)
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed_generator=self._seed_generator.handle,
|
||||
**self._flat_structure)
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed_generator=self._seed_generator.handle,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
@ -4145,29 +4107,17 @@ class ParallelMapDataset(UnaryDataset):
|
||||
else:
|
||||
self._deterministic = "false"
|
||||
self._preserve_cardinality = preserve_cardinality
|
||||
if deterministic is not None or compat.forward_compatible(2020, 3, 6):
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
deterministic=self._deterministic,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**self._flat_structure)
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
deterministic=self._deterministic,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**self._flat_structure)
|
||||
super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
@ -4294,30 +4244,17 @@ class ParallelInterleaveDataset(UnaryDataset):
|
||||
else:
|
||||
deterministic_string = "false"
|
||||
|
||||
if (buffer_output_elements != AUTOTUNE or
|
||||
prefetch_input_elements != AUTOTUNE or
|
||||
compat.forward_compatible(2020, 3, 6)):
|
||||
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._buffer_output_elements,
|
||||
self._prefetch_input_elements,
|
||||
self._num_parallel_calls,
|
||||
f=self._map_func.function,
|
||||
deterministic=deterministic_string,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._num_parallel_calls,
|
||||
f=self._map_func.function,
|
||||
deterministic=deterministic_string,
|
||||
**self._flat_structure)
|
||||
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._buffer_output_elements,
|
||||
self._prefetch_input_elements,
|
||||
self._num_parallel_calls,
|
||||
f=self._map_func.function,
|
||||
deterministic=deterministic_string,
|
||||
**self._flat_structure)
|
||||
super(ParallelInterleaveDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
|
@ -511,11 +511,6 @@ class DistributedDataset(_IterableInput):
|
||||
else:
|
||||
raise
|
||||
|
||||
# TODO(b/138745411): Remove once stateful transformations are supported.
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_distribute._make_stateless = True # pylint: disable=protected-access
|
||||
dataset = dataset.with_options(options)
|
||||
|
||||
self._cloned_datasets = []
|
||||
if input_context:
|
||||
# Between-graph where we rely on the input_context for sharding
|
||||
@ -1034,10 +1029,6 @@ def _create_iterators_per_worker_with_input_context(input_contexts,
|
||||
worker = input_workers.worker_devices[i]
|
||||
with ops.device(worker):
|
||||
dataset = dataset_fn(ctx)
|
||||
# TODO(b/138745411): Remove once stateful transformations are supported.
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_distribute._make_stateless = True # pylint: disable=protected-access
|
||||
dataset = dataset.with_options(options)
|
||||
devices = input_workers.compute_devices_for_worker(i)
|
||||
iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
|
||||
iterators.append(iterator)
|
||||
|
@ -26,6 +26,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
@ -524,15 +525,51 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
))
|
||||
def testCache(self, distribution):
|
||||
self.skipTest("Disable due to breakage.")
|
||||
dataset = dataset_ops.Dataset.range(10).shuffle(10).cache().batch(1)
|
||||
def testCacheAcrossIteration(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10).shuffle(10).cache().batch(2)
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
|
||||
first_epoch = list(x.numpy() for x in dist_dataset)
|
||||
second_epoch = list(x.numpy() for x in dist_dataset)
|
||||
first_epoch = list(
|
||||
distribution.experimental_local_results(x) for x in dist_dataset)
|
||||
second_epoch = list(
|
||||
distribution.experimental_local_results(x) for x in dist_dataset)
|
||||
|
||||
self.assertEqual(first_epoch, second_epoch)
|
||||
self.assertAllEqual(first_epoch, second_epoch)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
reshuffle=[True, False]))
|
||||
def testShuffleAcrossIterations(self, distribution, reshuffle):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
|
||||
if not reshuffle and not compat.forward_compatible(2020, 5, 22):
|
||||
self.skipTest("Functionality currently not supported.")
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10).shuffle(
|
||||
10, reshuffle_each_iteration=reshuffle).batch(2)
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
|
||||
first_epoch = list(
|
||||
distribution.experimental_local_results(x) for x in dist_dataset)
|
||||
second_epoch = list(
|
||||
distribution.experimental_local_results(x) for x in dist_dataset)
|
||||
|
||||
if reshuffle:
|
||||
self.assertNotAllEqual(first_epoch, second_epoch)
|
||||
else:
|
||||
self.assertAllEqual(first_epoch, second_epoch)
|
||||
|
||||
|
||||
class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
||||
|
@ -1180,6 +1180,10 @@ tf_module {
|
||||
name: "DummyMemoryCache"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DummySeedGenerator"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DynamicPartition"
|
||||
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -3904,6 +3908,10 @@ tf_module {
|
||||
name: "ShuffleDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed_generator\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleDatasetV3"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShutdownDistributedTPU"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1180,6 +1180,10 @@ tf_module {
|
||||
name: "DummyMemoryCache"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DummySeedGenerator"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DynamicPartition"
|
||||
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -3904,6 +3908,10 @@ tf_module {
|
||||
name: "ShuffleDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed_generator\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleDatasetV3"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShutdownDistributedTPU"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user