diff --git a/tensorflow/core/api_def/base_api/api_def_ShuffleAndRepeatDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ShuffleAndRepeatDatasetV2.pbtxt
new file mode 100644
index 00000000000..93135231b87
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ShuffleAndRepeatDatasetV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ShuffleAndRepeatDatasetV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index 6f5c32edf26..64bb4528f62 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -26,12 +26,141 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/strcat.h"
 
 namespace tensorflow {
 namespace grappler {
 namespace {
 
-constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
+constexpr char kShuffleDataset[] = "ShuffleDataset";
+constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
+constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
+constexpr char kRepeatDataset[] = "RepeatDataset";
+constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";
+constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
+
+constexpr char kOutputShapes[] = "output_shapes";
+constexpr char kOutputTypes[] = "output_types";
+constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
+
+// Fills `fused_node` with a `ShuffleAndRepeatDataset` equivalent to
+// `shuffle_node` (a `ShuffleDataset`) followed by `repeat_node`. Inputs are
+// the shuffle's input, buffer_size, seed and seed2, then the repeat's count;
+// output and `reshuffle_each_iteration` attributes come from the shuffle.
+Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
+                              const NodeDef& repeat_node,
+                              MutableGraphView* graph, GraphDef* output,
+                              NodeDef* fused_node) {
+  fused_node->set_op(kShuffleAndRepeatDataset);
+  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
+                                      fused_node);
+
+  // Set the `input` input argument.
+  fused_node->add_input(shuffle_node.input(0));
+
+  // Set the `buffer_size` input argument.
+  fused_node->add_input(shuffle_node.input(1));
+
+  // Set the `seed` input argument.
+  fused_node->add_input(shuffle_node.input(2));
+
+  // Set the `seed2` input argument.
+  fused_node->add_input(shuffle_node.input(3));
+
+  // Set the `count` input argument.
+  fused_node->add_input(repeat_node.input(1));
+
+  // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
+  // attributes.
+  for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
+    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
+  }
+
+  return Status::OK();
+}
+
+// Fills `fused_node` with a `ShuffleAndRepeatDatasetV2` equivalent to
+// `shuffle_node` (a `ShuffleDatasetV2`) followed by `repeat_node`. The seed
+// generator is the shuffle's third input; `seed` and `seed2` are wired to an
+// int64 zero constant added to `graph`, and `reshuffle_each_iteration` is set
+// to true.
+Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
+                              const NodeDef& repeat_node,
+                              MutableGraphView* graph, GraphDef* output,
+                              NodeDef* fused_node) {
+  fused_node->set_op(kShuffleAndRepeatDatasetV2);
+  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
+                                      fused_node);
+
+  const NodeDef& zero_node = *graph_utils::AddScalarConstNode<int64>(0, graph);
+
+  // Set the `input` input argument.
+  fused_node->add_input(shuffle_node.input(0));
+
+  // Set the `buffer_size` input argument.
+  fused_node->add_input(shuffle_node.input(1));
+
+  // Default the `seed` input argument to 0.
+  fused_node->add_input(zero_node.name());
+
+  // Default the `seed2` input argument to 0.
+  fused_node->add_input(zero_node.name());
+
+  // Set the `count` input argument.
+  fused_node->add_input(repeat_node.input(1));
+
+  // Set the `seed_generator` input argument.
+  fused_node->add_input(shuffle_node.input(2));
+
+  // Set `output_types` and `output_shapes` attributes.
+  for (auto key : {kOutputShapes, kOutputTypes}) {
+    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
+  }
+
+  // Default the `reshuffle_each_iteration` attribute to true.
+  (*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
+
+  return Status::OK();
+}
+
+// Fills `fused_node` with a `ShuffleAndRepeatDatasetV2` equivalent to
+// `shuffle_node` (a `ShuffleDatasetV3`) followed by `repeat_node`. The
+// shuffle's seeds, seed generator and attributes are forwarded as-is, with the
+// repeat's count placed before the seed generator input.
+Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
+                              const NodeDef& repeat_node,
+                              MutableGraphView* graph, GraphDef* output,
+                              NodeDef* fused_node) {
+  fused_node->set_op(kShuffleAndRepeatDatasetV2);
+  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
+                                      fused_node);
+
+  // Set the `input` input argument.
+  fused_node->add_input(shuffle_node.input(0));
+
+  // Set the `buffer_size` input argument.
+  fused_node->add_input(shuffle_node.input(1));
+
+  // Set the `seed` input argument.
+  fused_node->add_input(shuffle_node.input(2));
+
+  // Set the `seed2` input argument.
+  fused_node->add_input(shuffle_node.input(3));
+
+  // Set the `count` input argument.
+  fused_node->add_input(repeat_node.input(1));
+
+  // Set the `seed_generator` input argument.
+  fused_node->add_input(shuffle_node.input(4));
+
+  // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
+  // attributes.
+  for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
+    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
+  }
+
+  return Status::OK();
+}
 
 }  // namespace
 
@@ -42,65 +171,46 @@ Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
   MutableGraphView graph(output);
   absl::flat_hash_set<string> nodes_to_delete;
 
-  auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node,
-                                                const NodeDef& repeat_node) {
-    NodeDef new_node;
-    new_node.set_op(kFusedOpName);
-    graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node);
-
-    // Set the `input` input argument.
-    new_node.add_input(shuffle_node.input(0));
-
-    // Set the `buffer_size` input argument.
-    new_node.add_input(shuffle_node.input(1));
-
-    // Set the `seed` input argument.
-    new_node.add_input(shuffle_node.input(2));
-
-    // Set the `seed2` input argument.
-    new_node.add_input(shuffle_node.input(3));
-
-    // Set the `count` input argument.
- new_node.add_input(repeat_node.input(1)); - - // Set `output_types` and `output_shapes` attributes. - for (auto key : {"output_shapes", "output_types"}) { - graph_utils::CopyAttribute(key, repeat_node, &new_node); - } - return new_node; - }; - - for (const NodeDef& node : item.graph.node()) { - if (node.op() != "RepeatDataset") { + for (const NodeDef& repeat_node : item.graph.node()) { + if (repeat_node.op() != kRepeatDataset) { continue; } - // Use a more descriptive variable name now that we know the node type. - const NodeDef& repeat_node = node; - NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph); + const NodeDef& shuffle_node = + *graph_utils::GetInputNode(repeat_node, graph); - if (node2->op() != "ShuffleDataset") { + NodeDef fused_node; + if (shuffle_node.op() == kShuffleDataset) { + TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node, + &graph, output, &fused_node)); + } else if (shuffle_node.op() == kShuffleDatasetV2) { + TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node, + &graph, output, &fused_node)); + + } else if (shuffle_node.op() == kShuffleDatasetV3) { + TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node, + &graph, output, &fused_node)); + } else { continue; } - // Use a more descriptive variable name now that we know the node type. - const NodeDef& shuffle_node = *node2; - - // TODO(b/129712758): Remove when the fused kernel supports disabling - // reshuffling for each iteration. 
- if (HasNodeAttr(shuffle_node, "reshuffle_each_iteration") && - !shuffle_node.attr().at("reshuffle_each_iteration").b()) { - continue; - } - - NodeDef* shuffle_and_repeat_node = - graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node)); + NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node)); TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(), - shuffle_and_repeat_node->name())); + shuffle_and_repeat_node.name())); + // Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of + // control dependencies. + TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(), + shuffle_and_repeat_node.name())); - // Mark the `Shuffle` and `Repeat` nodes for removal. - nodes_to_delete.insert(shuffle_node.name()); - nodes_to_delete.insert(repeat_node.name()); + // Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of + // them needs to be preserved). + const auto nodes_to_preserve = item.NodesToPreserve(); + if (nodes_to_preserve.find(shuffle_node.name()) == + nodes_to_preserve.end() && + nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) { + nodes_to_delete.insert(shuffle_node.name()); + nodes_to_delete.insert(repeat_node.name()); + } stats->num_changes++; } diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc index 556e1d3ab57..9a5c454ad0c 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc @@ -25,17 +25,21 @@ namespace tensorflow { namespace grappler { namespace { -TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) { +constexpr char kOutputShapes[] = "output_shapes"; +constexpr char kOutputTypes[] = "output_types"; +constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration"; + +TEST(ShuffleAndRepeatFusionTest, 
FuseShuffleV1AndRepeat) { GrapplerItem item; MutableGraphView graph(&item.graph); std::vector> common_attrs(2); AttrValue shapes_attr; - SetAttrValue("output_shapes", &shapes_attr); - common_attrs[0] = std::make_pair("output_shapes", shapes_attr); + SetAttrValue(kOutputShapes, &shapes_attr); + common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr); AttrValue types_attr; - SetAttrValue("output_types", &types_attr); - common_attrs[1] = std::make_pair("output_types", types_attr); + SetAttrValue(kOutputTypes, &types_attr); + common_attrs[1] = std::make_pair(kOutputTypes, types_attr); NodeDef *start_node = graph_utils::AddScalarConstNode(0, &graph); NodeDef *stop_node = graph_utils::AddScalarConstNode(10, &graph); @@ -59,6 +63,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) { shuffle_inputs[3] = seed2_node->name(); NodeDef *shuffle_node = graph_utils::AddNode( "", "ShuffleDataset", shuffle_inputs, common_attrs, &graph); + (*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true); NodeDef *count_node = graph_utils::AddScalarConstNode(-1, &graph); std::vector repeat_inputs(2); @@ -85,12 +90,148 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) { EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2)); EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3)); EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1)); + for (const auto &attr : + {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) { + EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr), + shuffle_node->attr().at(attr))); + } +} + +TEST(ShuffleAndRepeatFusionTest, FuseShuffleV2AndRepeat) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + std::vector> common_attrs(2); + AttrValue shapes_attr; + SetAttrValue(kOutputShapes, &shapes_attr); + common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr); + AttrValue types_attr; + SetAttrValue(kOutputTypes, &types_attr); + 
common_attrs[1] = std::make_pair(kOutputTypes, types_attr); + + NodeDef *start_node = graph_utils::AddScalarConstNode(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode(1, &graph); + + std::vector range_inputs(3); + range_inputs[0] = start_node->name(); + range_inputs[1] = stop_node->name(); + range_inputs[2] = step_node->name(); + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + common_attrs, &graph); + + NodeDef *buffer_size_node = + graph_utils::AddScalarConstNode(128, &graph); + NodeDef *seed_generator_node = + graph_utils::AddScalarConstNode("dummy_resource", &graph); + std::vector shuffle_inputs(3); + shuffle_inputs[0] = range_node->name(); + shuffle_inputs[1] = buffer_size_node->name(); + shuffle_inputs[2] = seed_generator_node->name(); + NodeDef *shuffle_node = graph_utils::AddNode( + "", "ShuffleDatasetV2", shuffle_inputs, common_attrs, &graph); + + NodeDef *count_node = graph_utils::AddScalarConstNode(-1, &graph); + std::vector repeat_inputs(2); + repeat_inputs[0] = shuffle_node->name(); + repeat_inputs[1] = count_node->name(); + NodeDef *repeat_node = graph_utils::AddNode( + "", "RepeatDataset", repeat_inputs, common_attrs, &graph); + + ShuffleAndRepeatFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output)); EXPECT_TRUE( - AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"), - repeat_node->attr().at("output_shapes"))); + graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output)); + NodeDef shuffle_and_repeat_node = output.node( + graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output)); + EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6); + EXPECT_EQ(shuffle_and_repeat_node.input(0), 
shuffle_node->input(0)); + EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1)); + EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1)); + EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(2)); + for (const auto &attr : {kOutputShapes, kOutputTypes}) { + EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr), + shuffle_node->attr().at(attr))); + } + EXPECT_TRUE(shuffle_and_repeat_node.attr().at(kReshuffleEachIteration).b()); +} + +TEST(ShuffleAndRepeatFusionTest, FuseShuffleV3AndRepeat) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + std::vector> common_attrs(2); + AttrValue shapes_attr; + SetAttrValue(kOutputShapes, &shapes_attr); + common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr); + AttrValue types_attr; + SetAttrValue(kOutputTypes, &types_attr); + common_attrs[1] = std::make_pair(kOutputTypes, types_attr); + + NodeDef *start_node = graph_utils::AddScalarConstNode(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode(1, &graph); + + std::vector range_inputs(3); + range_inputs[0] = start_node->name(); + range_inputs[1] = stop_node->name(); + range_inputs[2] = step_node->name(); + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + common_attrs, &graph); + + NodeDef *buffer_size_node = + graph_utils::AddScalarConstNode(128, &graph); + NodeDef *seed_node = graph_utils::AddScalarConstNode(-1, &graph); + NodeDef *seed2_node = graph_utils::AddScalarConstNode(-1, &graph); + NodeDef *seed_generator_node = + graph_utils::AddScalarConstNode("dummy_resource", &graph); + std::vector shuffle_inputs(5); + shuffle_inputs[0] = range_node->name(); + shuffle_inputs[1] = buffer_size_node->name(); + shuffle_inputs[2] = seed_node->name(); + shuffle_inputs[3] = seed2_node->name(); + shuffle_inputs[4] = seed_generator_node->name(); + NodeDef *shuffle_node = graph_utils::AddNode( + "", 
"ShuffleDatasetV3", shuffle_inputs, common_attrs, &graph); + (*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true); + + NodeDef *count_node = graph_utils::AddScalarConstNode(-1, &graph); + std::vector repeat_inputs(2); + repeat_inputs[0] = shuffle_node->name(); + repeat_inputs[1] = count_node->name(); + NodeDef *repeat_node = graph_utils::AddNode( + "", "RepeatDataset", repeat_inputs, common_attrs, &graph); + + ShuffleAndRepeatFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output)); EXPECT_TRUE( - AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"), - repeat_node->attr().at("output_types"))); + graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output)); + NodeDef shuffle_and_repeat_node = output.node( + graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output)); + EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6); + EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0)); + EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1)); + EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2)); + EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3)); + EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1)); + EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(4)); + for (const auto &attr : + {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) { + EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr), + shuffle_node->attr().at(attr))); + } } TEST(ShuffleAndRepeatFusionTest, NoChange) { @@ -99,11 +240,11 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) { std::vector> common_attrs(2); AttrValue shapes_attr; - SetAttrValue("output_shapes", &shapes_attr); - common_attrs[0] = std::make_pair("output_shapes", 
shapes_attr);
+  SetAttrValue(kOutputShapes, &shapes_attr);
+  common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
   AttrValue types_attr;
-  SetAttrValue("output_types", &types_attr);
-  common_attrs[1] = std::make_pair("output_types", types_attr);
+  SetAttrValue(kOutputTypes, &types_attr);
+  common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
 
   NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
   NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 852ba23e774..3e549246a95 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -44,10 +44,10 @@ namespace data {
 /* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
 /* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
 /* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
+/* static */ constexpr const char* const
+    ShuffleDatasetOpBase::kReshuffleEachIteration;
 
 /* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
-/* static */ constexpr const char* const
-    ShuffleDatasetOp::kReshuffleEachIteration;
 
 /* static */ constexpr const char* const
     ShuffleAndRepeatDatasetOp::kDatasetType;
@@ -72,6 +72,8 @@ constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
 constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
 constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
 constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
+constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
+constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
 
 ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
     : UnaryDatasetOpKernel(ctx) {}
@@ -225,6 +227,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
         while (!slices_.empty() &&
slices_.front()->start == slices_.front()->end) { slices_.pop_front(); + // Reinitialize the RNG state for the next epoch. + num_random_samples_ = 0; + seed_generator_->GenerateSeeds(&seed_, &seed2_); + ResetRngs(); } DCHECK(!slices_.empty()); // Choose an element to produce uniformly at random from the first @@ -663,6 +669,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, RandomSeeds seeds(seed, seed2); bool owns_resource = false; if (errors::IsNotFound(s)) { + owns_resource = true; OP_REQUIRES_OK( ctx, ctx->resource_manager()->LookupOrCreate( @@ -679,7 +686,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, return Status::OK(); })); handle = MakeResourceHandle(ctx, container, name); - owns_resource = true; } else { OP_REQUIRES_OK(ctx, s); } @@ -695,6 +701,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, handle.container(), handle.name(), &manager); bool owns_resource = false; if (errors::IsNotFound(s)) { + owns_resource = true; LOG(WARNING) << "Failed to find seed generator resource. 
Falling back to " "using a non-deterministically seeded generator and " "reshuffling each iteration."; @@ -708,7 +715,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, return Status::OK(); })); handle = MakeResourceHandle(ctx, container, name); - owns_resource = true; } else { OP_REQUIRES_OK(ctx, s); } @@ -790,9 +796,13 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed)); TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2)); TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); + AttrValue reshuffle_each_iteration; + b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(), + &reshuffle_each_iteration); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs - {}, // Attrs + {std::make_pair(kReshuffleEachIteration, + reshuffle_each_iteration)}, // Attrs output)); return Status::OK(); } @@ -804,8 +814,83 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { const RandomSeeds seeds_; }; +class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase { + public: + DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, + int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager, + ResourceHandle&& resource_handle, bool owns_resource) + : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count), + manager_(manager), + owns_resource_(owns_resource), + resource_handle_(std::move(resource_handle)), + resource_mgr_(ctx->resource_manager()), + seeds_(std::move(seeds)) {} + + ~DatasetV2() override { + manager_->Unref(); + if (owns_resource_) { + Status s = resource_mgr_->Delete( + resource_handle_.container(), resource_handle_.name()); + if (!s.ok()) { + LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString(); + } + } + } + + string op_type() const override { return kDatasetType; } + + protected: + Status 
AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* buffer_size_node = nullptr; + Node* seed_node = nullptr; + Node* seed2_node = nullptr; + Node* count_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node)); + TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node)); + TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node)); + TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node)); + Node* resource_handle_node = nullptr; + Tensor handle(DT_RESOURCE, TensorShape({})); + handle.scalar()() = resource_handle_; + TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node)); + AttrValue reshuffle_each_iteration; + b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(), + &reshuffle_each_iteration); + TF_RETURN_IF_ERROR( + b->AddDataset(this, + {input_graph_node, buffer_size_node, seed_node, + seed2_node, count_node, resource_handle_node}, // Inputs + {std::make_pair(kReshuffleEachIteration, + reshuffle_each_iteration)}, // Attrs + output)); + return Status::OK(); + } + + private: + SeedGeneratorManager* const manager_; // Owned + const bool owns_resource_; + const ResourceHandle resource_handle_; + ResourceMgr* const resource_mgr_; // Not owned. 
+ const RandomSeeds seeds_; +}; + ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx) - : ShuffleDatasetOpBase(ctx) {} + : ShuffleDatasetOpBase(ctx) { + auto& op_name = ctx->def().op(); + if (op_name == kShuffleAndRepeatDatasetV2) { + op_version_ = 2; + } else if (op_name == kShuffleAndRepeatDatasetV1) { + op_version_ = 1; + } + if (ctx->HasAttr(kReshuffleEachIteration)) { + OP_REQUIRES_OK( + ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_)); + } +} void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, @@ -826,29 +911,76 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCount, &count)); - RandomSeeds seeds(seed, seed2); - OP_REQUIRES(ctx, count > 0 || count == -1, errors::InvalidArgument( "count must be greater than zero or equal to -1.")); + RandomSeeds seeds(seed, seed2); + static std::atomic resource_id_counter(0); const string& container = ctx->resource_manager()->default_container(); auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_", resource_id_counter.fetch_add(1)); - SeedGeneratorManager* manager; - OP_REQUIRES_OK( - ctx, - ctx->resource_manager()->LookupOrCreate( - container, name, &manager, [&seeds](SeedGeneratorManager** manager) { - *manager = new SeedGeneratorManager(new RandomSeedGenerator(seeds)); - return Status::OK(); - })); - auto handle = MakeResourceHandle(ctx, container, name); + if (op_version_ == 2) { + auto handle = HandleFromInput(ctx, 5); + SeedGeneratorManager* manager = nullptr; + Status s = ctx->resource_manager()->Lookup( + handle.container(), handle.name(), &manager); + bool owns_resource = false; + if (errors::IsNotFound(s)) { + owns_resource = true; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate( + container, name, &manager, + [reshuffle = reshuffle_each_iteration_, + &seeds](SeedGeneratorManager** manager) { + if (reshuffle) { 
+ *manager = + new SeedGeneratorManager(new RandomSeedGenerator(seeds)); + } else { + *manager = + new SeedGeneratorManager(new FixedSeedGenerator(seeds)); + } + return Status::OK(); + })); + handle = MakeResourceHandle(ctx, container, name); + } else { + OP_REQUIRES_OK(ctx, s); + } - // Ownership of manager is transferred onto `Dataset`. - *output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager, - count, std::move(handle)); + // Ownership of manager is transferred onto `DatasetV2`. + *output = new ShuffleAndRepeatDatasetOp::DatasetV2( + ctx, input, buffer_size, count, std::move(seeds), manager, + std::move(handle), owns_resource); + } else { + if (op_version_ != 1) { + LOG(WARNING) << "Unsupported version of shuffle dataset op: " + << op_version_ << ". Defaulting to version 1."; + } + SeedGeneratorManager* manager; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate( + container, name, &manager, + [reshuffle = reshuffle_each_iteration_, + &seeds](SeedGeneratorManager** manager) { + if (reshuffle) { + *manager = + new SeedGeneratorManager(new RandomSeedGenerator(seeds)); + } else { + *manager = + new SeedGeneratorManager(new FixedSeedGenerator(seeds)); + } + return Status::OK(); + })); + auto handle = + MakeResourceHandle(ctx, container, name); + + // Ownership of manager is transferred onto `Dataset`. 
+ *output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager, + count, std::move(handle)); + } } namespace { @@ -863,6 +995,9 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU), ShuffleAndRepeatDatasetOp); + +REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU), + ShuffleAndRepeatDatasetOp); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.h b/tensorflow/core/kernels/data/shuffle_dataset_op.h index 7aa3c0e3ef0..f33f75c84eb 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.h +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.h @@ -28,6 +28,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { static constexpr const char* const kSeed2 = "seed2"; static constexpr const char* const kOutputTypes = "output_types"; static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kReshuffleEachIteration = + "reshuffle_each_iteration"; explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx); @@ -38,8 +40,6 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { class ShuffleDatasetOp : public ShuffleDatasetOpBase { public: static constexpr const char* const kDatasetType = "Shuffle"; - static constexpr const char* const kReshuffleEachIteration = - "reshuffle_each_iteration"; explicit ShuffleDatasetOp(OpKernelConstruction* ctx); @@ -52,7 +52,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { class DatasetV2; class DatasetV3; int op_version_ = 0; - bool reshuffle_each_iteration_; + bool reshuffle_each_iteration_ = true; }; class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { @@ -68,6 +68,9 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { private: class Dataset; + class DatasetV2; + int op_version_ = 0; + bool reshuffle_each_iteration_ = true; }; } // namespace data 
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc index 6d16d76ea61..65f6855b7fa 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc @@ -72,10 +72,8 @@ class ShuffleDatasetParams : public DatasetParams { output_dtypes_); attr_vector->emplace_back(ShuffleDatasetOpBase::kOutputShapes, output_shapes_); - if (count_ == 1) { - attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration, - reshuffle_each_iteration_); - } + attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration, + reshuffle_each_iteration_); return Status::OK(); } @@ -297,23 +295,23 @@ std::vector> GetNextTestCases() { {/*dataset_params=*/ShuffleDatasetParams7(), /*expected_shuffle_outputs=*/ CreateTensors(TensorShape({}), - {{2}, {6}, {1}, {3}, {9}, {5}, {0}, {8}, {7}, {4}, - {0}, {5}, {1}, {7}, {2}, {9}, {8}, {4}, {6}, {3}}), + {{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5}, + {9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5}}), /*expected_reshuffle_outputs=*/ - CreateTensors(TensorShape({}), {{1}, {6}, {0}, {5}, {2}, {7}, {4}, - {3}, {9}, {8}, {6}, {5}, {0}, {9}, - {4}, {7}, {2}, {8}, {1}, {3}})}, + CreateTensors(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7}, + {2}, {4}, {5}, {9}, {0}, {8}, {6}, + {1}, {3}, {7}, {2}, {4}, {5}})}, {/*dataset_params=*/ShuffleDatasetParams8(), /*expected_shuffle_outputs=*/ CreateTensors( TensorShape({}), - {{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0}, - {2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}}), + {{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, + {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}}), /*expected_reshuffle_outputs=*/ CreateTensors( TensorShape({}), - {{1}, {0}, {2}, {0}, {1}, {2}, {2}, {1}, {0}, {0}, {1}, - {2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {0}, {2}})}}; + {{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, + {1}, {2}, {0}, {1}, {2}, {0}, 
{1}, {2}, {0}, {1}})}}; } class ParameterizedGetNextTest : public ShuffleDatasetOpTest, @@ -496,16 +494,16 @@ IteratorSaveAndRestoreTestCases() { {/*dataset_params=*/ShuffleDatasetParams7(), /*breakpoints=*/{0, 5, 22}, /*expected_shuffle_outputs=*/ - CreateTensors(TensorShape({}), {{2}, {6}, {1}, {3}, {9}, {5}, {0}, - {8}, {7}, {4}, {0}, {5}, {1}, {7}, - {2}, {9}, {8}, {4}, {6}, {3}})}, + CreateTensors(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7}, + {2}, {4}, {5}, {9}, {0}, {8}, {6}, + {1}, {3}, {7}, {2}, {4}, {5}})}, {/*dataset_params=*/ShuffleDatasetParams8(), /*breakpoints=*/{0, 5, 20}, /*expected_shuffle_outputs=*/ CreateTensors( TensorShape({}), - {{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0}, - {2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}})}}; + {{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, + {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}}; } class ParameterizedIteratorSaveAndRestoreTest diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 6dc2280feae..ab2cf35fa08 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -507,6 +507,7 @@ REGISTER_OP("ShuffleAndRepeatDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("reshuffle_each_iteration: bool = true") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle unused; // buffer_size, seed, seed2, and count should be scalars. 
@@ -517,6 +518,28 @@ REGISTER_OP("ShuffleAndRepeatDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("ShuffleAndRepeatDatasetV2") + .Input("input_dataset: variant") + .Input("buffer_size: int64") + .Input("seed: int64") + .Input("seed2: int64") + .Input("count: int64") + .Input("seed_generator: resource") + .Output("handle: variant") + .Attr("reshuffle_each_iteration: bool = true") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, seed2, count, and seed_generator should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("AnonymousMemoryCache") .Output("handle: resource") .Output("deleter: variant") diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py index ad1a98134b8..9dfeec75c95 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py @@ -19,11 +19,9 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from 
tensorflow.python.platform import test @@ -34,11 +32,7 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase, @combinations.generate(test_base.default_test_combinations()) def testShuffleAndRepeatFusion(self): - if tf2.enabled() and context.executing_eagerly(): - expected = "Shuffle" - else: - expected = "ShuffleAndRepeat" - + expected = "ShuffleAndRepeat" dataset = dataset_ops.Dataset.range(10).apply( testing.assert_next([expected])).shuffle(10).repeat(2) options = dataset_ops.Options() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 1f050e933ed..cf6b807502c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3898,7 +3898,11 @@ tf_module { } member_method { name: "ShuffleAndRepeatDataset" - argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "ShuffleAndRepeatDatasetV2" + argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "ShuffleDataset" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 1f050e933ed..cf6b807502c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3898,7 +3898,11 @@ tf_module { } member_method { name: 
"ShuffleAndRepeatDataset" - argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "ShuffleAndRepeatDatasetV2" + argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "ShuffleDataset"