[tf.data] Supporting shuffle + repeat fusion in TF 2 (and with correct reshuffle_each_iteration
semantics).
PiperOrigin-RevId: 309135252 Change-Id: I157a09027d43d8943e764573b359755bc9380167
This commit is contained in:
parent
3dc9712013
commit
36c49a6013
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "ShuffleAndRepeatDatasetV2"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -26,12 +26,128 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
|
constexpr char kShuffleDataset[] = "ShuffleDataset";
|
||||||
|
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||||
|
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
||||||
|
constexpr char kRepeatDataset[] = "RepeatDataset";
|
||||||
|
constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";
|
||||||
|
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
|
||||||
|
|
||||||
|
constexpr char kOutputShapes[] = "output_shapes";
|
||||||
|
constexpr char kOutputTypes[] = "output_types";
|
||||||
|
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||||
|
|
||||||
|
// Populates `fused_node` as a `ShuffleAndRepeatDataset` node that stands in
// for a `ShuffleDataset` node (`shuffle_node`) feeding a `RepeatDataset` node
// (`repeat_node`).
//
// Inputs of the fused node, in order: the shuffle node's `input`,
// `buffer_size`, `seed` and `seed2` inputs, then the repeat node's `count`
// input. The `output_shapes`, `output_types` and `reshuffle_each_iteration`
// attributes are copied from the shuffle node. `output` is only used to pick a
// unique node name; `graph` is not used by this variant. Always returns OK.
Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
                              const NodeDef& repeat_node,
                              MutableGraphView* graph, GraphDef* output,
                              NodeDef* fused_node) {
  fused_node->set_op(kShuffleAndRepeatDataset);
  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
                                      fused_node);

  // Forward shuffle inputs 0..3: `input`, `buffer_size`, `seed`, `seed2`.
  constexpr int kNumForwardedShuffleInputs = 4;
  for (int index = 0; index < kNumForwardedShuffleInputs; ++index) {
    fused_node->add_input(shuffle_node.input(index));
  }

  // The repeat node's second input is `count`.
  fused_node->add_input(repeat_node.input(1));

  // Attributes all come from the shuffle node.
  graph_utils::CopyAttribute(kOutputShapes, shuffle_node, fused_node);
  graph_utils::CopyAttribute(kOutputTypes, shuffle_node, fused_node);
  graph_utils::CopyAttribute(kReshuffleEachIteration, shuffle_node,
                             fused_node);

  return Status::OK();
}
|
||||||
|
|
||||||
|
// Populates `fused_node` as a `ShuffleAndRepeatDatasetV2` node that stands in
// for a `ShuffleDatasetV2` node (`shuffle_node`) feeding a `RepeatDataset`
// node (`repeat_node`).
//
// `ShuffleDatasetV2` carries no `seed`/`seed2` inputs, so a scalar int64
// constant 0 is added to `graph` and wired to both. The shuffle node's third
// input is forwarded as the fused node's `seed_generator` input, and
// `reshuffle_each_iteration` is set to true rather than copied.
// `output` is only used to pick a unique node name. Always returns OK.
Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
                              const NodeDef& repeat_node,
                              MutableGraphView* graph, GraphDef* output,
                              NodeDef* fused_node) {
  fused_node->set_op(kShuffleAndRepeatDatasetV2);
  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
                                      fused_node);

  // Only the constant's name is needed, so keep a pointer to the node that
  // was added to the graph instead of copying the whole NodeDef.
  const NodeDef* zero_node = graph_utils::AddScalarConstNode<int64>(0, graph);

  // Set the `input` input argument.
  fused_node->add_input(shuffle_node.input(0));

  // Set the `buffer_size` input argument.
  fused_node->add_input(shuffle_node.input(1));

  // Default the `seed` input argument to 0.
  fused_node->add_input(zero_node->name());

  // Default the `seed2` input argument to 0.
  fused_node->add_input(zero_node->name());

  // Set the `count` input argument.
  fused_node->add_input(repeat_node.input(1));

  // Set the `seed_generator` input argument.
  fused_node->add_input(shuffle_node.input(2));

  // Set `output_types` and `output_shapes` attributes.
  for (auto key : {kOutputShapes, kOutputTypes}) {
    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
  }

  // Default the `reshuffle_each_iteration` attribute to true.
  (*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);

  return Status::OK();
}
|
||||||
|
|
||||||
|
// Populates `fused_node` as a `ShuffleAndRepeatDatasetV2` node that stands in
// for a `ShuffleDatasetV3` node (`shuffle_node`) feeding a `RepeatDataset`
// node (`repeat_node`).
//
// Inputs of the fused node, in order: the shuffle node's `input`,
// `buffer_size`, `seed` and `seed2` inputs, the repeat node's `count` input,
// and the shuffle node's fifth input as `seed_generator`. The `output_shapes`,
// `output_types` and `reshuffle_each_iteration` attributes are copied from the
// shuffle node. `output` is only used to pick a unique node name; `graph` is
// not used by this variant. Always returns OK.
Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
                              const NodeDef& repeat_node,
                              MutableGraphView* graph, GraphDef* output,
                              NodeDef* fused_node) {
  fused_node->set_op(kShuffleAndRepeatDatasetV2);
  // Derive the node name from the same op name that was just set, so the
  // generated name reflects the V2 op (as FuseShuffleV2AndRepeat does).
  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
                                      fused_node);

  // Set the `input` input argument.
  fused_node->add_input(shuffle_node.input(0));

  // Set the `buffer_size` input argument.
  fused_node->add_input(shuffle_node.input(1));

  // Set the `seed` input argument.
  fused_node->add_input(shuffle_node.input(2));

  // Set the `seed2` input argument.
  fused_node->add_input(shuffle_node.input(3));

  // Set the `count` input argument.
  fused_node->add_input(repeat_node.input(1));

  // Set the `seed_generator` input argument.
  fused_node->add_input(shuffle_node.input(4));

  // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
  // attributes.
  for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
  }

  return Status::OK();
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -42,65 +158,46 @@ Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
|
|||||||
MutableGraphView graph(output);
|
MutableGraphView graph(output);
|
||||||
absl::flat_hash_set<string> nodes_to_delete;
|
absl::flat_hash_set<string> nodes_to_delete;
|
||||||
|
|
||||||
auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node,
|
for (const NodeDef& repeat_node : item.graph.node()) {
|
||||||
const NodeDef& repeat_node) {
|
if (repeat_node.op() != kRepeatDataset) {
|
||||||
NodeDef new_node;
|
|
||||||
new_node.set_op(kFusedOpName);
|
|
||||||
graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node);
|
|
||||||
|
|
||||||
// Set the `input` input argument.
|
|
||||||
new_node.add_input(shuffle_node.input(0));
|
|
||||||
|
|
||||||
// Set the `buffer_size` input argument.
|
|
||||||
new_node.add_input(shuffle_node.input(1));
|
|
||||||
|
|
||||||
// Set the `seed` input argument.
|
|
||||||
new_node.add_input(shuffle_node.input(2));
|
|
||||||
|
|
||||||
// Set the `seed2` input argument.
|
|
||||||
new_node.add_input(shuffle_node.input(3));
|
|
||||||
|
|
||||||
// Set the `count` input argument.
|
|
||||||
new_node.add_input(repeat_node.input(1));
|
|
||||||
|
|
||||||
// Set `output_types` and `output_shapes` attributes.
|
|
||||||
for (auto key : {"output_shapes", "output_types"}) {
|
|
||||||
graph_utils::CopyAttribute(key, repeat_node, &new_node);
|
|
||||||
}
|
|
||||||
return new_node;
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const NodeDef& node : item.graph.node()) {
|
|
||||||
if (node.op() != "RepeatDataset") {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a more descriptive variable name now that we know the node type.
|
const NodeDef& shuffle_node =
|
||||||
const NodeDef& repeat_node = node;
|
*graph_utils::GetInputNode(repeat_node, graph);
|
||||||
NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
|
|
||||||
|
|
||||||
if (node2->op() != "ShuffleDataset") {
|
NodeDef fused_node;
|
||||||
|
if (shuffle_node.op() == kShuffleDataset) {
|
||||||
|
TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node,
|
||||||
|
&graph, output, &fused_node));
|
||||||
|
} else if (shuffle_node.op() == kShuffleDatasetV2) {
|
||||||
|
TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node,
|
||||||
|
&graph, output, &fused_node));
|
||||||
|
|
||||||
|
} else if (shuffle_node.op() == kShuffleDatasetV3) {
|
||||||
|
TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node,
|
||||||
|
&graph, output, &fused_node));
|
||||||
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a more descriptive variable name now that we know the node type.
|
NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node));
|
||||||
const NodeDef& shuffle_node = *node2;
|
|
||||||
|
|
||||||
// TODO(b/129712758): Remove when the fused kernel supports disabling
|
|
||||||
// reshuffling for each iteration.
|
|
||||||
if (HasNodeAttr(shuffle_node, "reshuffle_each_iteration") &&
|
|
||||||
!shuffle_node.attr().at("reshuffle_each_iteration").b()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
NodeDef* shuffle_and_repeat_node =
|
|
||||||
graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
|
|
||||||
TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
|
TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
|
||||||
shuffle_and_repeat_node->name()));
|
shuffle_and_repeat_node.name()));
|
||||||
|
// Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of
|
||||||
|
// control dependencies.
|
||||||
|
TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(),
|
||||||
|
shuffle_and_repeat_node.name()));
|
||||||
|
|
||||||
// Mark the `Shuffle` and `Repeat` nodes for removal.
|
// Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of
|
||||||
|
// them needs to be preserved).
|
||||||
|
const auto nodes_to_preserve = item.NodesToPreserve();
|
||||||
|
if (nodes_to_preserve.find(shuffle_node.name()) ==
|
||||||
|
nodes_to_preserve.end() &&
|
||||||
|
nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) {
|
||||||
nodes_to_delete.insert(shuffle_node.name());
|
nodes_to_delete.insert(shuffle_node.name());
|
||||||
nodes_to_delete.insert(repeat_node.name());
|
nodes_to_delete.insert(repeat_node.name());
|
||||||
|
}
|
||||||
stats->num_changes++;
|
stats->num_changes++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,17 +25,21 @@ namespace tensorflow {
|
|||||||
namespace grappler {
|
namespace grappler {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
constexpr char kOutputShapes[] = "output_shapes";
|
||||||
|
constexpr char kOutputTypes[] = "output_types";
|
||||||
|
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||||
|
|
||||||
|
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV1AndRepeat) {
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
MutableGraphView graph(&item.graph);
|
MutableGraphView graph(&item.graph);
|
||||||
|
|
||||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||||
AttrValue shapes_attr;
|
AttrValue shapes_attr;
|
||||||
SetAttrValue("output_shapes", &shapes_attr);
|
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||||
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
|
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||||
AttrValue types_attr;
|
AttrValue types_attr;
|
||||||
SetAttrValue("output_types", &types_attr);
|
SetAttrValue(kOutputTypes, &types_attr);
|
||||||
common_attrs[1] = std::make_pair("output_types", types_attr);
|
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||||
|
|
||||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||||
@ -59,6 +63,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
|||||||
shuffle_inputs[3] = seed2_node->name();
|
shuffle_inputs[3] = seed2_node->name();
|
||||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||||
"", "ShuffleDataset", shuffle_inputs, common_attrs, &graph);
|
"", "ShuffleDataset", shuffle_inputs, common_attrs, &graph);
|
||||||
|
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||||
|
|
||||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||||
std::vector<string> repeat_inputs(2);
|
std::vector<string> repeat_inputs(2);
|
||||||
@ -85,12 +90,148 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
|||||||
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
||||||
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
||||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||||
|
for (const auto &attr :
|
||||||
|
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||||
|
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||||
|
shuffle_node->attr().at(attr)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV2AndRepeat) {
|
||||||
|
GrapplerItem item;
|
||||||
|
MutableGraphView graph(&item.graph);
|
||||||
|
|
||||||
|
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||||
|
AttrValue shapes_attr;
|
||||||
|
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||||
|
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||||
|
AttrValue types_attr;
|
||||||
|
SetAttrValue(kOutputTypes, &types_attr);
|
||||||
|
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||||
|
|
||||||
|
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||||
|
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||||
|
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||||
|
|
||||||
|
std::vector<string> range_inputs(3);
|
||||||
|
range_inputs[0] = start_node->name();
|
||||||
|
range_inputs[1] = stop_node->name();
|
||||||
|
range_inputs[2] = step_node->name();
|
||||||
|
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
|
||||||
|
common_attrs, &graph);
|
||||||
|
|
||||||
|
NodeDef *buffer_size_node =
|
||||||
|
graph_utils::AddScalarConstNode<int64>(128, &graph);
|
||||||
|
NodeDef *seed_generator_node =
|
||||||
|
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
|
||||||
|
std::vector<string> shuffle_inputs(3);
|
||||||
|
shuffle_inputs[0] = range_node->name();
|
||||||
|
shuffle_inputs[1] = buffer_size_node->name();
|
||||||
|
shuffle_inputs[2] = seed_generator_node->name();
|
||||||
|
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||||
|
"", "ShuffleDatasetV2", shuffle_inputs, common_attrs, &graph);
|
||||||
|
|
||||||
|
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||||
|
std::vector<string> repeat_inputs(2);
|
||||||
|
repeat_inputs[0] = shuffle_node->name();
|
||||||
|
repeat_inputs[1] = count_node->name();
|
||||||
|
NodeDef *repeat_node = graph_utils::AddNode(
|
||||||
|
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
|
||||||
|
|
||||||
|
ShuffleAndRepeatFusion optimizer;
|
||||||
|
GraphDef output;
|
||||||
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
|
EXPECT_FALSE(
|
||||||
|
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
|
||||||
|
EXPECT_FALSE(
|
||||||
|
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"),
|
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||||
repeat_node->attr().at("output_shapes")));
|
NodeDef shuffle_and_repeat_node = output.node(
|
||||||
|
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(2));
|
||||||
|
for (const auto &attr : {kOutputShapes, kOutputTypes}) {
|
||||||
|
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||||
|
shuffle_node->attr().at(attr)));
|
||||||
|
}
|
||||||
|
EXPECT_TRUE(shuffle_and_repeat_node.attr().at(kReshuffleEachIteration).b());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV3AndRepeat) {
|
||||||
|
GrapplerItem item;
|
||||||
|
MutableGraphView graph(&item.graph);
|
||||||
|
|
||||||
|
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||||
|
AttrValue shapes_attr;
|
||||||
|
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||||
|
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||||
|
AttrValue types_attr;
|
||||||
|
SetAttrValue(kOutputTypes, &types_attr);
|
||||||
|
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||||
|
|
||||||
|
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||||
|
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||||
|
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||||
|
|
||||||
|
std::vector<string> range_inputs(3);
|
||||||
|
range_inputs[0] = start_node->name();
|
||||||
|
range_inputs[1] = stop_node->name();
|
||||||
|
range_inputs[2] = step_node->name();
|
||||||
|
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
|
||||||
|
common_attrs, &graph);
|
||||||
|
|
||||||
|
NodeDef *buffer_size_node =
|
||||||
|
graph_utils::AddScalarConstNode<int64>(128, &graph);
|
||||||
|
NodeDef *seed_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||||
|
NodeDef *seed2_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||||
|
NodeDef *seed_generator_node =
|
||||||
|
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
|
||||||
|
std::vector<string> shuffle_inputs(5);
|
||||||
|
shuffle_inputs[0] = range_node->name();
|
||||||
|
shuffle_inputs[1] = buffer_size_node->name();
|
||||||
|
shuffle_inputs[2] = seed_node->name();
|
||||||
|
shuffle_inputs[3] = seed2_node->name();
|
||||||
|
shuffle_inputs[4] = seed_generator_node->name();
|
||||||
|
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||||
|
"", "ShuffleDatasetV3", shuffle_inputs, common_attrs, &graph);
|
||||||
|
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||||
|
|
||||||
|
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||||
|
std::vector<string> repeat_inputs(2);
|
||||||
|
repeat_inputs[0] = shuffle_node->name();
|
||||||
|
repeat_inputs[1] = count_node->name();
|
||||||
|
NodeDef *repeat_node = graph_utils::AddNode(
|
||||||
|
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
|
||||||
|
|
||||||
|
ShuffleAndRepeatFusion optimizer;
|
||||||
|
GraphDef output;
|
||||||
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
|
EXPECT_FALSE(
|
||||||
|
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
|
||||||
|
EXPECT_FALSE(
|
||||||
|
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"),
|
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||||
repeat_node->attr().at("output_types")));
|
NodeDef shuffle_and_repeat_node = output.node(
|
||||||
|
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||||
|
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(4));
|
||||||
|
for (const auto &attr :
|
||||||
|
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||||
|
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||||
|
shuffle_node->attr().at(attr)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
||||||
@ -99,11 +240,11 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
|||||||
|
|
||||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||||
AttrValue shapes_attr;
|
AttrValue shapes_attr;
|
||||||
SetAttrValue("output_shapes", &shapes_attr);
|
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||||
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
|
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||||
AttrValue types_attr;
|
AttrValue types_attr;
|
||||||
SetAttrValue("output_types", &types_attr);
|
SetAttrValue(kOutputTypes, &types_attr);
|
||||||
common_attrs[1] = std::make_pair("output_types", types_attr);
|
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||||
|
|
||||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||||
|
@ -44,10 +44,10 @@ namespace data {
|
|||||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
|
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
|
||||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
|
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
|
||||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
|
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
ShuffleDatasetOpBase::kReshuffleEachIteration;
|
||||||
|
|
||||||
/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
|
/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
|
||||||
/* static */ constexpr const char* const
|
|
||||||
ShuffleDatasetOp::kReshuffleEachIteration;
|
|
||||||
|
|
||||||
/* static */ constexpr const char* const
|
/* static */ constexpr const char* const
|
||||||
ShuffleAndRepeatDatasetOp::kDatasetType;
|
ShuffleAndRepeatDatasetOp::kDatasetType;
|
||||||
@ -72,6 +72,8 @@ constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
|
|||||||
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
|
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
|
||||||
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||||
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
||||||
|
// Op name of the original (V1) fused kernel. It carries no version suffix, so
// the constant must match "ShuffleAndRepeatDataset" for op-name comparisons to
// recognise it.
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
|
||||||
|
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
|
||||||
|
|
||||||
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
|
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
|
||||||
: UnaryDatasetOpKernel(ctx) {}
|
: UnaryDatasetOpKernel(ctx) {}
|
||||||
@ -225,6 +227,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
while (!slices_.empty() &&
|
while (!slices_.empty() &&
|
||||||
slices_.front()->start == slices_.front()->end) {
|
slices_.front()->start == slices_.front()->end) {
|
||||||
slices_.pop_front();
|
slices_.pop_front();
|
||||||
|
// Reinitialize the RNG state for the next epoch.
|
||||||
|
num_random_samples_ = 0;
|
||||||
|
seed_generator_->GenerateSeeds(&seed_, &seed2_);
|
||||||
|
ResetRngs();
|
||||||
}
|
}
|
||||||
DCHECK(!slices_.empty());
|
DCHECK(!slices_.empty());
|
||||||
// Choose an element to produce uniformly at random from the first
|
// Choose an element to produce uniformly at random from the first
|
||||||
@ -663,6 +669,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|||||||
RandomSeeds seeds(seed, seed2);
|
RandomSeeds seeds(seed, seed2);
|
||||||
bool owns_resource = false;
|
bool owns_resource = false;
|
||||||
if (errors::IsNotFound(s)) {
|
if (errors::IsNotFound(s)) {
|
||||||
|
owns_resource = true;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx,
|
ctx,
|
||||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||||
@ -679,7 +686,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||||
owns_resource = true;
|
|
||||||
} else {
|
} else {
|
||||||
OP_REQUIRES_OK(ctx, s);
|
OP_REQUIRES_OK(ctx, s);
|
||||||
}
|
}
|
||||||
@ -695,6 +701,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|||||||
handle.container(), handle.name(), &manager);
|
handle.container(), handle.name(), &manager);
|
||||||
bool owns_resource = false;
|
bool owns_resource = false;
|
||||||
if (errors::IsNotFound(s)) {
|
if (errors::IsNotFound(s)) {
|
||||||
|
owns_resource = true;
|
||||||
LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
|
LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
|
||||||
"using a non-deterministically seeded generator and "
|
"using a non-deterministically seeded generator and "
|
||||||
"reshuffling each iteration.";
|
"reshuffling each iteration.";
|
||||||
@ -708,7 +715,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||||
owns_resource = true;
|
|
||||||
} else {
|
} else {
|
||||||
OP_REQUIRES_OK(ctx, s);
|
OP_REQUIRES_OK(ctx, s);
|
||||||
}
|
}
|
||||||
@ -790,9 +796,13 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
|||||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
|
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
|
||||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
|
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
|
||||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
|
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
|
||||||
|
AttrValue reshuffle_each_iteration;
|
||||||
|
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
|
||||||
|
&reshuffle_each_iteration);
|
||||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||||
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
|
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
|
||||||
{}, // Attrs
|
{std::make_pair(kReshuffleEachIteration,
|
||||||
|
reshuffle_each_iteration)}, // Attrs
|
||||||
output));
|
output));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -804,8 +814,83 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
|||||||
const RandomSeeds seeds_;
|
const RandomSeeds seeds_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Dataset variant whose seed generator is obtained from a
// `SeedGeneratorManager` resource, identified by `resource_handle`.
//
// The dataset holds a reference on `manager` for its whole lifetime and, when
// `owns_resource` is true, also removes the resource from the resource manager
// on destruction.
class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
 public:
  DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
            int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
            ResourceHandle&& resource_handle, bool owns_resource)
      : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
        manager_(manager),
        owns_resource_(owns_resource),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()),
        seeds_(std::move(seeds)) {}

  ~DatasetV2() override {
    // Drop the reference taken on the manager when this dataset was built.
    manager_->Unref();
    if (!owns_resource_) {
      return;
    }
    // This dataset is responsible for the resource: remove it, and only warn
    // on failure since a destructor cannot report an error.
    const Status delete_status = resource_mgr_->Delete<SeedGeneratorManager>(
        resource_handle_.container(), resource_handle_.name());
    if (!delete_status.ok()) {
      LOG(WARNING) << "Failed to delete RNG resource: "
                   << delete_status.ToString();
    }
  }

  string op_type() const override { return kDatasetType; }

 protected:
  // Serializes this dataset as a node with six inputs (input dataset,
  // buffer size, seed, seed2, count, seed-generator resource handle) and a
  // `reshuffle_each_iteration` attribute read from the seed generator.
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_dataset = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_dataset));

    // Scalar inputs, added in the same order as they appear in the op.
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    Node* seed = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
    Node* seed2 = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
    Node* count = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));

    // The resource handle is passed as a scalar DT_RESOURCE tensor.
    Tensor handle_tensor(DT_RESOURCE, TensorShape({}));
    handle_tensor.scalar<ResourceHandle>()() = resource_handle_;
    Node* seed_generator = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(handle_tensor, &seed_generator));

    AttrValue reshuffle_attr;
    b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
                      &reshuffle_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {input_dataset, buffer_size, seed, seed2, count,
         seed_generator},                                           // Inputs
        {std::make_pair(kReshuffleEachIteration, reshuffle_attr)},  // Attrs
        output));
    return Status::OK();
  }

 private:
  SeedGeneratorManager* const manager_;  // Owned; unreffed in the destructor.
  const bool owns_resource_;
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
  const RandomSeeds seeds_;
};
|
||||||
|
|
||||||
// Kernel constructor. Determines which op version this kernel instance serves
// from the op name in the node definition, and reads the optional
// `reshuffle_each_iteration` attr.
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
|
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
|
||||||
: ShuffleDatasetOpBase(ctx) {}
|
: ShuffleDatasetOpBase(ctx) {
|
||||||
|
auto& op_name = ctx->def().op();
|
||||||
|
// If the op name matches neither constant, `op_version_` is not assigned
// here and keeps whatever value it was initialized with.
if (op_name == kShuffleAndRepeatDatasetV2) {
|
||||||
|
op_version_ = 2;
|
||||||
|
} else if (op_name == kShuffleAndRepeatDatasetV1) {
|
||||||
|
op_version_ = 1;
|
||||||
|
}
|
||||||
|
// The attr is only read when the node actually carries it; otherwise
// `reshuffle_each_iteration_` is left untouched.
if (ctx->HasAttr(kReshuffleEachIteration)) {
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Builds the output dataset for this kernel invocation.
//
// For op version 2 the seed generator resource is looked up through the handle
// passed as input 5 and only created (and then owned by the dataset) when that
// lookup reports NotFound. For any other version a new seed generator resource
// is always created under a freshly generated name.
// NOTE(review): part of this function (the rest of the signature and the
// parsing of the earlier arguments) is not visible in this view.
void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||||
DatasetBase* input,
|
DatasetBase* input,
|
||||||
@ -826,29 +911,76 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
|||||||
int64 count;
|
int64 count;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
|
||||||
|
|
||||||
RandomSeeds seeds(seed, seed2);
|
|
||||||
|
|
||||||
OP_REQUIRES(ctx, count > 0 || count == -1,
|
// Only a positive count or exactly -1 is accepted; 0 and other negative
// values are rejected with InvalidArgument.
OP_REQUIRES(ctx, count > 0 || count == -1,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"count must be greater than zero or equal to -1."));
|
"count must be greater than zero or equal to -1."));
|
||||||
|
|
||||||
|
RandomSeeds seeds(seed, seed2);
|
||||||
|
|
||||||
static std::atomic<int64> resource_id_counter(0);
|
// The function-local static counter is appended to the kernel name so that
// each call produces a different resource name.
static std::atomic<int64> resource_id_counter(0);
|
||||||
const string& container = ctx->resource_manager()->default_container();
|
const string& container = ctx->resource_manager()->default_container();
|
||||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
|
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
|
||||||
resource_id_counter.fetch_add(1));
|
resource_id_counter.fetch_add(1));
|
||||||
|
if (op_version_ == 2) {
|
||||||
|
// Try to reuse the seed generator resource referenced by input 5.
auto handle = HandleFromInput(ctx, 5);
|
||||||
|
SeedGeneratorManager* manager = nullptr;
|
||||||
|
Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
|
||||||
|
handle.container(), handle.name(), &manager);
|
||||||
|
bool owns_resource = false;
|
||||||
|
// NotFound means the resource has to be created here; the dataset is then
// told it owns it. Any other lookup error fails the op (see the else
// branch below).
if (errors::IsNotFound(s)) {
|
||||||
|
owns_resource = true;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx,
|
||||||
|
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||||
|
container, name, &manager,
|
||||||
|
[reshuffle = reshuffle_each_iteration_,
|
||||||
|
&seeds](SeedGeneratorManager** manager) {
|
||||||
|
// A RandomSeedGenerator is used when reshuffling is requested,
// a FixedSeedGenerator otherwise.
if (reshuffle) {
|
||||||
|
*manager =
|
||||||
|
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||||
|
} else {
|
||||||
|
*manager =
|
||||||
|
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
// NOTE(review): the handle is built for type `SeedGenerator`, while the
// resource above is created and looked up as `SeedGeneratorManager`, and
// the version 1 branch below uses
// MakeResourceHandle<SeedGeneratorManager>. Confirm whether this type
// mismatch is intentional.
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(ctx, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ownership of manager is transferred onto `DatasetV2`.
|
||||||
|
*output = new ShuffleAndRepeatDatasetOp::DatasetV2(
|
||||||
|
ctx, input, buffer_size, count, std::move(seeds), manager,
|
||||||
|
std::move(handle), owns_resource);
|
||||||
|
} else {
|
||||||
|
// Every version other than 2 takes this path; versions other than 1 only
// produce a warning and are then handled exactly like version 1.
if (op_version_ != 1) {
|
||||||
|
LOG(WARNING) << "Unsupported version of shuffle dataset op: "
|
||||||
|
<< op_version_ << ". Defaulting to version 1.";
|
||||||
|
}
|
||||||
SeedGeneratorManager* manager;
|
SeedGeneratorManager* manager;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx,
|
ctx,
|
||||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||||
container, name, &manager, [&seeds](SeedGeneratorManager** manager) {
|
container, name, &manager,
|
||||||
*manager = new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
[reshuffle = reshuffle_each_iteration_,
|
||||||
|
&seeds](SeedGeneratorManager** manager) {
|
||||||
|
// Same generator selection as in the version 2 branch.
if (reshuffle) {
|
||||||
|
*manager =
|
||||||
|
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||||
|
} else {
|
||||||
|
*manager =
|
||||||
|
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
auto handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
auto handle =
|
||||||
|
MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||||
|
|
||||||
// Ownership of manager is transferred onto `Dataset`.
|
// Ownership of manager is transferred onto `Dataset`.
|
||||||
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
|
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
|
||||||
count, std::move(handle));
|
count, std::move(handle));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -863,6 +995,9 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
|
|||||||
|
|
||||||
// Both the "ShuffleAndRepeatDataset" and "ShuffleAndRepeatDatasetV2" ops are
// registered on CPU with the same kernel class, ShuffleAndRepeatDatasetOp.
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
|
||||||
ShuffleAndRepeatDatasetOp);
|
ShuffleAndRepeatDatasetOp);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
|
||||||
|
ShuffleAndRepeatDatasetOp);
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -28,6 +28,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
|||||||
static constexpr const char* const kSeed2 = "seed2";
|
static constexpr const char* const kSeed2 = "seed2";
|
||||||
static constexpr const char* const kOutputTypes = "output_types";
|
static constexpr const char* const kOutputTypes = "output_types";
|
||||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||||
|
static constexpr const char* const kReshuffleEachIteration =
|
||||||
|
"reshuffle_each_iteration";
|
||||||
|
|
||||||
explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx);
|
explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
@ -38,8 +40,6 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
|||||||
class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
||||||
public:
|
public:
|
||||||
static constexpr const char* const kDatasetType = "Shuffle";
|
static constexpr const char* const kDatasetType = "Shuffle";
|
||||||
static constexpr const char* const kReshuffleEachIteration =
|
|
||||||
"reshuffle_each_iteration";
|
|
||||||
|
|
||||||
explicit ShuffleDatasetOp(OpKernelConstruction* ctx);
|
explicit ShuffleDatasetOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
|||||||
class DatasetV2;
|
class DatasetV2;
|
||||||
class DatasetV3;
|
class DatasetV3;
|
||||||
int op_version_ = 0;
|
int op_version_ = 0;
|
||||||
bool reshuffle_each_iteration_;
|
bool reshuffle_each_iteration_ = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
||||||
@ -68,6 +68,9 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
class Dataset;
|
class Dataset;
|
||||||
|
class DatasetV2;
|
||||||
|
int op_version_ = 0;
|
||||||
|
bool reshuffle_each_iteration_ = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
@ -72,10 +72,8 @@ class ShuffleDatasetParams : public DatasetParams {
|
|||||||
output_dtypes_);
|
output_dtypes_);
|
||||||
attr_vector->emplace_back(ShuffleDatasetOpBase::kOutputShapes,
|
attr_vector->emplace_back(ShuffleDatasetOpBase::kOutputShapes,
|
||||||
output_shapes_);
|
output_shapes_);
|
||||||
if (count_ == 1) {
|
|
||||||
attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration,
|
attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration,
|
||||||
reshuffle_each_iteration_);
|
reshuffle_each_iteration_);
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,23 +295,23 @@ std::vector<GetNextTestCase<ShuffleDatasetParams>> GetNextTestCases() {
|
|||||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||||
/*expected_shuffle_outputs=*/
|
/*expected_shuffle_outputs=*/
|
||||||
CreateTensors<int64>(TensorShape({}),
|
CreateTensors<int64>(TensorShape({}),
|
||||||
{{2}, {6}, {1}, {3}, {9}, {5}, {0}, {8}, {7}, {4},
|
{{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5},
|
||||||
{0}, {5}, {1}, {7}, {2}, {9}, {8}, {4}, {6}, {3}}),
|
{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5}}),
|
||||||
/*expected_reshuffle_outputs=*/
|
/*expected_reshuffle_outputs=*/
|
||||||
CreateTensors<int64>(TensorShape({}), {{1}, {6}, {0}, {5}, {2}, {7}, {4},
|
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||||
{3}, {9}, {8}, {6}, {5}, {0}, {9},
|
{2}, {4}, {5}, {9}, {0}, {8}, {6},
|
||||||
{4}, {7}, {2}, {8}, {1}, {3}})},
|
{1}, {3}, {7}, {2}, {4}, {5}})},
|
||||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||||
/*expected_shuffle_outputs=*/
|
/*expected_shuffle_outputs=*/
|
||||||
CreateTensors<int64>(
|
CreateTensors<int64>(
|
||||||
TensorShape({}),
|
TensorShape({}),
|
||||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}}),
|
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}}),
|
||||||
/*expected_reshuffle_outputs=*/
|
/*expected_reshuffle_outputs=*/
|
||||||
CreateTensors<int64>(
|
CreateTensors<int64>(
|
||||||
TensorShape({}),
|
TensorShape({}),
|
||||||
{{1}, {0}, {2}, {0}, {1}, {2}, {2}, {1}, {0}, {0}, {1},
|
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {0}, {2}})}};
|
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParameterizedGetNextTest : public ShuffleDatasetOpTest,
|
class ParameterizedGetNextTest : public ShuffleDatasetOpTest,
|
||||||
@ -496,16 +494,16 @@ IteratorSaveAndRestoreTestCases() {
|
|||||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||||
/*breakpoints=*/{0, 5, 22},
|
/*breakpoints=*/{0, 5, 22},
|
||||||
/*expected_shuffle_outputs=*/
|
/*expected_shuffle_outputs=*/
|
||||||
CreateTensors<int64>(TensorShape({}), {{2}, {6}, {1}, {3}, {9}, {5}, {0},
|
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||||
{8}, {7}, {4}, {0}, {5}, {1}, {7},
|
{2}, {4}, {5}, {9}, {0}, {8}, {6},
|
||||||
{2}, {9}, {8}, {4}, {6}, {3}})},
|
{1}, {3}, {7}, {2}, {4}, {5}})},
|
||||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||||
/*breakpoints=*/{0, 5, 20},
|
/*breakpoints=*/{0, 5, 20},
|
||||||
/*expected_shuffle_outputs=*/
|
/*expected_shuffle_outputs=*/
|
||||||
CreateTensors<int64>(
|
CreateTensors<int64>(
|
||||||
TensorShape({}),
|
TensorShape({}),
|
||||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}})}};
|
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParameterizedIteratorSaveAndRestoreTest
|
class ParameterizedIteratorSaveAndRestoreTest
|
||||||
|
@ -507,6 +507,7 @@ REGISTER_OP("ShuffleAndRepeatDataset")
|
|||||||
.Output("handle: variant")
|
.Output("handle: variant")
|
||||||
.Attr("output_types: list(type) >= 1")
|
.Attr("output_types: list(type) >= 1")
|
||||||
.Attr("output_shapes: list(shape) >= 1")
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
|
.Attr("reshuffle_each_iteration: bool = true")
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
shape_inference::ShapeHandle unused;
|
shape_inference::ShapeHandle unused;
|
||||||
// buffer_size, seed, seed2, and count should be scalars.
|
// buffer_size, seed, seed2, and count should be scalars.
|
||||||
@ -517,6 +518,28 @@ REGISTER_OP("ShuffleAndRepeatDataset")
|
|||||||
return shape_inference::ScalarShape(c);
|
return shape_inference::ScalarShape(c);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Op definition for "ShuffleAndRepeatDatasetV2". It takes six inputs (the
// input dataset, buffer_size, seed, seed2, count and a `seed_generator`
// resource) and produces a single variant `handle` output.
// `reshuffle_each_iteration` defaults to true.
REGISTER_OP("ShuffleAndRepeatDatasetV2")
|
||||||
|
.Input("input_dataset: variant")
|
||||||
|
.Input("buffer_size: int64")
|
||||||
|
.Input("seed: int64")
|
||||||
|
.Input("seed2: int64")
|
||||||
|
.Input("count: int64")
|
||||||
|
.Input("seed_generator: resource")
|
||||||
|
.Output("handle: variant")
|
||||||
|
.Attr("reshuffle_each_iteration: bool = true")
|
||||||
|
.Attr("output_types: list(type) >= 1")
|
||||||
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
|
// Shape function: inputs 1 through 5 must have rank 0; input 0 (the input
// dataset) is not checked here. The output shape is scalar.
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
shape_inference::ShapeHandle unused;
|
||||||
|
// buffer_size, seed, seed2, count, and seed_generator should be scalars.
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
|
||||||
|
return shape_inference::ScalarShape(c);
|
||||||
|
});
|
||||||
|
|
||||||
REGISTER_OP("AnonymousMemoryCache")
|
REGISTER_OP("AnonymousMemoryCache")
|
||||||
.Output("handle: resource")
|
.Output("handle: resource")
|
||||||
.Output("deleter: variant")
|
.Output("deleter: variant")
|
||||||
|
@ -19,11 +19,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
|
||||||
from tensorflow.python.data.experimental.ops import testing
|
from tensorflow.python.data.experimental.ops import testing
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -34,11 +32,7 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
|
|||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testShuffleAndRepeatFusion(self):
|
def testShuffleAndRepeatFusion(self):
|
||||||
if tf2.enabled() and context.executing_eagerly():
|
|
||||||
expected = "Shuffle"
|
|
||||||
else:
|
|
||||||
expected = "ShuffleAndRepeat"
|
expected = "ShuffleAndRepeat"
|
||||||
|
|
||||||
dataset = dataset_ops.Dataset.range(10).apply(
|
dataset = dataset_ops.Dataset.range(10).apply(
|
||||||
testing.assert_next([expected])).shuffle(10).repeat(2)
|
testing.assert_next([expected])).shuffle(10).repeat(2)
|
||||||
options = dataset_ops.Options()
|
options = dataset_ops.Options()
|
||||||
|
@ -3898,7 +3898,11 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ShuffleAndRepeatDataset"
|
name: "ShuffleAndRepeatDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ShuffleAndRepeatDatasetV2"
|
||||||
|
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ShuffleDataset"
|
name: "ShuffleDataset"
|
||||||
|
@ -3898,7 +3898,11 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ShuffleAndRepeatDataset"
|
name: "ShuffleAndRepeatDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ShuffleAndRepeatDatasetV2"
|
||||||
|
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ShuffleDataset"
|
name: "ShuffleDataset"
|
||||||
|
Loading…
Reference in New Issue
Block a user