[tf.data] Supporting shuffle + repeat fusion in TF 2 (and with correct reshuffle_each_iteration
semantics).
PiperOrigin-RevId: 309135252 Change-Id: I157a09027d43d8943e764573b359755bc9380167
This commit is contained in:
parent
3dc9712013
commit
36c49a6013
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "ShuffleAndRepeatDatasetV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -26,12 +26,128 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
|
||||
constexpr char kShuffleDataset[] = "ShuffleDataset";
|
||||
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
||||
constexpr char kRepeatDataset[] = "RepeatDataset";
|
||||
constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";
|
||||
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
|
||||
|
||||
constexpr char kOutputShapes[] = "output_shapes";
|
||||
constexpr char kOutputTypes[] = "output_types";
|
||||
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||
|
||||
Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node,
|
||||
MutableGraphView* graph, GraphDef* output,
|
||||
NodeDef* fused_node) {
|
||||
fused_node->set_op(kShuffleAndRepeatDataset);
|
||||
graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
|
||||
fused_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
fused_node->add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
fused_node->add_input(shuffle_node.input(1));
|
||||
|
||||
// Set the `seed` input argument.
|
||||
fused_node->add_input(shuffle_node.input(2));
|
||||
|
||||
// Set the `seed2` input argument.
|
||||
fused_node->add_input(shuffle_node.input(3));
|
||||
|
||||
// Set the `count` input argument.
|
||||
fused_node->add_input(repeat_node.input(1));
|
||||
|
||||
// Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
|
||||
// attributes.
|
||||
for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
graph_utils::CopyAttribute(key, shuffle_node, fused_node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node,
|
||||
MutableGraphView* graph, GraphDef* output,
|
||||
NodeDef* fused_node) {
|
||||
fused_node->set_op(kShuffleAndRepeatDatasetV2);
|
||||
graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
|
||||
fused_node);
|
||||
|
||||
NodeDef zero_node = *graph_utils::AddScalarConstNode<int64>(0, graph);
|
||||
|
||||
// Set the `input` input argument.
|
||||
fused_node->add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
fused_node->add_input(shuffle_node.input(1));
|
||||
|
||||
// Default the `seed` input argument to 0.
|
||||
fused_node->add_input(zero_node.name());
|
||||
|
||||
// Default the `seed2` input argument to 0.
|
||||
fused_node->add_input(zero_node.name());
|
||||
|
||||
// Set the `count` input argument.
|
||||
fused_node->add_input(repeat_node.input(1));
|
||||
|
||||
// Set the `seed_generator` input argument.
|
||||
fused_node->add_input(shuffle_node.input(2));
|
||||
|
||||
// Set `output_types` and `output_shapes` attributes.
|
||||
for (auto key : {kOutputShapes, kOutputTypes}) {
|
||||
graph_utils::CopyAttribute(key, shuffle_node, fused_node);
|
||||
}
|
||||
|
||||
// Default the `reshuffle_each_iteration` attribute to true.
|
||||
(*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node,
|
||||
MutableGraphView* graph, GraphDef* output,
|
||||
NodeDef* fused_node) {
|
||||
fused_node->set_op(kShuffleAndRepeatDatasetV2);
|
||||
graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
|
||||
fused_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
fused_node->add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
fused_node->add_input(shuffle_node.input(1));
|
||||
|
||||
// Set the `seed` input argument.
|
||||
fused_node->add_input(shuffle_node.input(2));
|
||||
|
||||
// Set the `seed2` input argument.
|
||||
fused_node->add_input(shuffle_node.input(3));
|
||||
|
||||
// Set the `count` input argument.
|
||||
fused_node->add_input(repeat_node.input(1));
|
||||
|
||||
// Set the `seed_generator` input argument.
|
||||
fused_node->add_input(shuffle_node.input(4));
|
||||
|
||||
// Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
|
||||
// attributes.
|
||||
for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
graph_utils::CopyAttribute(key, shuffle_node, fused_node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -42,65 +158,46 @@ Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
|
||||
MutableGraphView graph(output);
|
||||
absl::flat_hash_set<string> nodes_to_delete;
|
||||
|
||||
auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node) {
|
||||
NodeDef new_node;
|
||||
new_node.set_op(kFusedOpName);
|
||||
graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
new_node.add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
new_node.add_input(shuffle_node.input(1));
|
||||
|
||||
// Set the `seed` input argument.
|
||||
new_node.add_input(shuffle_node.input(2));
|
||||
|
||||
// Set the `seed2` input argument.
|
||||
new_node.add_input(shuffle_node.input(3));
|
||||
|
||||
// Set the `count` input argument.
|
||||
new_node.add_input(repeat_node.input(1));
|
||||
|
||||
// Set `output_types` and `output_shapes` attributes.
|
||||
for (auto key : {"output_shapes", "output_types"}) {
|
||||
graph_utils::CopyAttribute(key, repeat_node, &new_node);
|
||||
}
|
||||
return new_node;
|
||||
};
|
||||
|
||||
for (const NodeDef& node : item.graph.node()) {
|
||||
if (node.op() != "RepeatDataset") {
|
||||
for (const NodeDef& repeat_node : item.graph.node()) {
|
||||
if (repeat_node.op() != kRepeatDataset) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Use a more descriptive variable name now that we know the node type.
|
||||
const NodeDef& repeat_node = node;
|
||||
NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
|
||||
const NodeDef& shuffle_node =
|
||||
*graph_utils::GetInputNode(repeat_node, graph);
|
||||
|
||||
if (node2->op() != "ShuffleDataset") {
|
||||
NodeDef fused_node;
|
||||
if (shuffle_node.op() == kShuffleDataset) {
|
||||
TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node,
|
||||
&graph, output, &fused_node));
|
||||
} else if (shuffle_node.op() == kShuffleDatasetV2) {
|
||||
TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node,
|
||||
&graph, output, &fused_node));
|
||||
|
||||
} else if (shuffle_node.op() == kShuffleDatasetV3) {
|
||||
TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node,
|
||||
&graph, output, &fused_node));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Use a more descriptive variable name now that we know the node type.
|
||||
const NodeDef& shuffle_node = *node2;
|
||||
|
||||
// TODO(b/129712758): Remove when the fused kernel supports disabling
|
||||
// reshuffling for each iteration.
|
||||
if (HasNodeAttr(shuffle_node, "reshuffle_each_iteration") &&
|
||||
!shuffle_node.attr().at("reshuffle_each_iteration").b()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeDef* shuffle_and_repeat_node =
|
||||
graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
|
||||
NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node));
|
||||
TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
|
||||
shuffle_and_repeat_node->name()));
|
||||
shuffle_and_repeat_node.name()));
|
||||
// Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of
|
||||
// control dependencies.
|
||||
TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(),
|
||||
shuffle_and_repeat_node.name()));
|
||||
|
||||
// Mark the `Shuffle` and `Repeat` nodes for removal.
|
||||
// Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of
|
||||
// them needs to be preserved).
|
||||
const auto nodes_to_preserve = item.NodesToPreserve();
|
||||
if (nodes_to_preserve.find(shuffle_node.name()) ==
|
||||
nodes_to_preserve.end() &&
|
||||
nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) {
|
||||
nodes_to_delete.insert(shuffle_node.name());
|
||||
nodes_to_delete.insert(repeat_node.name());
|
||||
}
|
||||
stats->num_changes++;
|
||||
}
|
||||
|
||||
|
@ -25,17 +25,21 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
constexpr char kOutputShapes[] = "output_shapes";
|
||||
constexpr char kOutputTypes[] = "output_types";
|
||||
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV1AndRepeat) {
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue("output_shapes", &shapes_attr);
|
||||
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue("output_types", &types_attr);
|
||||
common_attrs[1] = std::make_pair("output_types", types_attr);
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
@ -59,6 +63,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
shuffle_inputs[3] = seed2_node->name();
|
||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||
"", "ShuffleDataset", shuffle_inputs, common_attrs, &graph);
|
||||
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
|
||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
std::vector<string> repeat_inputs(2);
|
||||
@ -85,12 +90,148 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||
for (const auto &attr :
|
||||
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||
shuffle_node->attr().at(attr)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV2AndRepeat) {
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
|
||||
std::vector<string> range_inputs(3);
|
||||
range_inputs[0] = start_node->name();
|
||||
range_inputs[1] = stop_node->name();
|
||||
range_inputs[2] = step_node->name();
|
||||
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
|
||||
common_attrs, &graph);
|
||||
|
||||
NodeDef *buffer_size_node =
|
||||
graph_utils::AddScalarConstNode<int64>(128, &graph);
|
||||
NodeDef *seed_generator_node =
|
||||
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
|
||||
std::vector<string> shuffle_inputs(3);
|
||||
shuffle_inputs[0] = range_node->name();
|
||||
shuffle_inputs[1] = buffer_size_node->name();
|
||||
shuffle_inputs[2] = seed_generator_node->name();
|
||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||
"", "ShuffleDatasetV2", shuffle_inputs, common_attrs, &graph);
|
||||
|
||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
std::vector<string> repeat_inputs(2);
|
||||
repeat_inputs[0] = shuffle_node->name();
|
||||
repeat_inputs[1] = count_node->name();
|
||||
NodeDef *repeat_node = graph_utils::AddNode(
|
||||
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
|
||||
|
||||
ShuffleAndRepeatFusion optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
|
||||
EXPECT_TRUE(
|
||||
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"),
|
||||
repeat_node->attr().at("output_shapes")));
|
||||
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
NodeDef shuffle_and_repeat_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(2));
|
||||
for (const auto &attr : {kOutputShapes, kOutputTypes}) {
|
||||
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||
shuffle_node->attr().at(attr)));
|
||||
}
|
||||
EXPECT_TRUE(shuffle_and_repeat_node.attr().at(kReshuffleEachIteration).b());
|
||||
}
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV3AndRepeat) {
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
|
||||
std::vector<string> range_inputs(3);
|
||||
range_inputs[0] = start_node->name();
|
||||
range_inputs[1] = stop_node->name();
|
||||
range_inputs[2] = step_node->name();
|
||||
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
|
||||
common_attrs, &graph);
|
||||
|
||||
NodeDef *buffer_size_node =
|
||||
graph_utils::AddScalarConstNode<int64>(128, &graph);
|
||||
NodeDef *seed_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
NodeDef *seed2_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
NodeDef *seed_generator_node =
|
||||
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
|
||||
std::vector<string> shuffle_inputs(5);
|
||||
shuffle_inputs[0] = range_node->name();
|
||||
shuffle_inputs[1] = buffer_size_node->name();
|
||||
shuffle_inputs[2] = seed_node->name();
|
||||
shuffle_inputs[3] = seed2_node->name();
|
||||
shuffle_inputs[4] = seed_generator_node->name();
|
||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||
"", "ShuffleDatasetV3", shuffle_inputs, common_attrs, &graph);
|
||||
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
|
||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
std::vector<string> repeat_inputs(2);
|
||||
repeat_inputs[0] = shuffle_node->name();
|
||||
repeat_inputs[1] = count_node->name();
|
||||
NodeDef *repeat_node = graph_utils::AddNode(
|
||||
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
|
||||
|
||||
ShuffleAndRepeatFusion optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
|
||||
EXPECT_TRUE(
|
||||
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"),
|
||||
repeat_node->attr().at("output_types")));
|
||||
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
NodeDef shuffle_and_repeat_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(4));
|
||||
for (const auto &attr :
|
||||
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||
shuffle_node->attr().at(attr)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
||||
@ -99,11 +240,11 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue("output_shapes", &shapes_attr);
|
||||
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue("output_types", &types_attr);
|
||||
common_attrs[1] = std::make_pair("output_types", types_attr);
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
|
@ -44,10 +44,10 @@ namespace data {
|
||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
|
||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
|
||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
|
||||
/* static */ constexpr const char* const
|
||||
ShuffleDatasetOpBase::kReshuffleEachIteration;
|
||||
|
||||
/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
|
||||
/* static */ constexpr const char* const
|
||||
ShuffleDatasetOp::kReshuffleEachIteration;
|
||||
|
||||
/* static */ constexpr const char* const
|
||||
ShuffleAndRepeatDatasetOp::kDatasetType;
|
||||
@ -72,6 +72,8 @@ constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
|
||||
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
|
||||
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
||||
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDatasetV1";
|
||||
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
|
||||
|
||||
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {}
|
||||
@ -225,6 +227,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
while (!slices_.empty() &&
|
||||
slices_.front()->start == slices_.front()->end) {
|
||||
slices_.pop_front();
|
||||
// Reinitialize the RNG state for the next epoch.
|
||||
num_random_samples_ = 0;
|
||||
seed_generator_->GenerateSeeds(&seed_, &seed2_);
|
||||
ResetRngs();
|
||||
}
|
||||
DCHECK(!slices_.empty());
|
||||
// Choose an element to produce uniformly at random from the first
|
||||
@ -663,6 +669,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
owns_resource = true;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
@ -679,7 +686,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
@ -695,6 +701,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
handle.container(), handle.name(), &manager);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
owns_resource = true;
|
||||
LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
|
||||
"using a non-deterministically seeded generator and "
|
||||
"reshuffling each iteration.";
|
||||
@ -708,7 +715,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
@ -790,9 +796,13 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
|
||||
AttrValue reshuffle_each_iteration;
|
||||
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
|
||||
&reshuffle_each_iteration);
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
|
||||
{}, // Attrs
|
||||
{std::make_pair(kReshuffleEachIteration,
|
||||
reshuffle_each_iteration)}, // Attrs
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -804,8 +814,83 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
const RandomSeeds seeds_;
|
||||
};
|
||||
|
||||
class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
|
||||
public:
|
||||
DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
|
||||
int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
|
||||
ResourceHandle&& resource_handle, bool owns_resource)
|
||||
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
|
||||
manager_(manager),
|
||||
owns_resource_(owns_resource),
|
||||
resource_handle_(std::move(resource_handle)),
|
||||
resource_mgr_(ctx->resource_manager()),
|
||||
seeds_(std::move(seeds)) {}
|
||||
|
||||
~DatasetV2() override {
|
||||
manager_->Unref();
|
||||
if (owns_resource_) {
|
||||
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
|
||||
resource_handle_.container(), resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string op_type() const override { return kDatasetType; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
Node* buffer_size_node = nullptr;
|
||||
Node* seed_node = nullptr;
|
||||
Node* seed2_node = nullptr;
|
||||
Node* count_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
|
||||
Node* resource_handle_node = nullptr;
|
||||
Tensor handle(DT_RESOURCE, TensorShape({}));
|
||||
handle.scalar<ResourceHandle>()() = resource_handle_;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
|
||||
AttrValue reshuffle_each_iteration;
|
||||
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
|
||||
&reshuffle_each_iteration);
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this,
|
||||
{input_graph_node, buffer_size_node, seed_node,
|
||||
seed2_node, count_node, resource_handle_node}, // Inputs
|
||||
{std::make_pair(kReshuffleEachIteration,
|
||||
reshuffle_each_iteration)}, // Attrs
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
SeedGeneratorManager* const manager_; // Owned
|
||||
const bool owns_resource_;
|
||||
const ResourceHandle resource_handle_;
|
||||
ResourceMgr* const resource_mgr_; // Not owned.
|
||||
const RandomSeeds seeds_;
|
||||
};
|
||||
|
||||
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
|
||||
: ShuffleDatasetOpBase(ctx) {}
|
||||
: ShuffleDatasetOpBase(ctx) {
|
||||
auto& op_name = ctx->def().op();
|
||||
if (op_name == kShuffleAndRepeatDatasetV2) {
|
||||
op_version_ = 2;
|
||||
} else if (op_name == kShuffleAndRepeatDatasetV1) {
|
||||
op_version_ = 1;
|
||||
}
|
||||
if (ctx->HasAttr(kReshuffleEachIteration)) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
|
||||
}
|
||||
}
|
||||
|
||||
void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
DatasetBase* input,
|
||||
@ -826,30 +911,77 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
int64 count;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
|
||||
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
|
||||
OP_REQUIRES(ctx, count > 0 || count == -1,
|
||||
errors::InvalidArgument(
|
||||
"count must be greater than zero or equal to -1."));
|
||||
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
|
||||
static std::atomic<int64> resource_id_counter(0);
|
||||
const string& container = ctx->resource_manager()->default_container();
|
||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
|
||||
resource_id_counter.fetch_add(1));
|
||||
if (op_version_ == 2) {
|
||||
auto handle = HandleFromInput(ctx, 5);
|
||||
SeedGeneratorManager* manager = nullptr;
|
||||
Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
|
||||
handle.container(), handle.name(), &manager);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
owns_resource = true;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager,
|
||||
[reshuffle = reshuffle_each_iteration_,
|
||||
&seeds](SeedGeneratorManager** manager) {
|
||||
if (reshuffle) {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
} else {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
|
||||
// Ownership of manager is transferred onto `DatasetV2`.
|
||||
*output = new ShuffleAndRepeatDatasetOp::DatasetV2(
|
||||
ctx, input, buffer_size, count, std::move(seeds), manager,
|
||||
std::move(handle), owns_resource);
|
||||
} else {
|
||||
if (op_version_ != 1) {
|
||||
LOG(WARNING) << "Unsupported version of shuffle dataset op: "
|
||||
<< op_version_ << ". Defaulting to version 1.";
|
||||
}
|
||||
SeedGeneratorManager* manager;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager, [&seeds](SeedGeneratorManager** manager) {
|
||||
*manager = new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
container, name, &manager,
|
||||
[reshuffle = reshuffle_each_iteration_,
|
||||
&seeds](SeedGeneratorManager** manager) {
|
||||
if (reshuffle) {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
} else {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
auto handle =
|
||||
MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
|
||||
// Ownership of manager is transferred onto `Dataset`.
|
||||
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
|
||||
count, std::move(handle));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
|
||||
@ -863,6 +995,9 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
|
||||
ShuffleAndRepeatDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
|
||||
ShuffleAndRepeatDatasetOp);
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
||||
static constexpr const char* const kSeed2 = "seed2";
|
||||
static constexpr const char* const kOutputTypes = "output_types";
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kReshuffleEachIteration =
|
||||
"reshuffle_each_iteration";
|
||||
|
||||
explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx);
|
||||
|
||||
@ -38,8 +40,6 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
||||
class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
||||
public:
|
||||
static constexpr const char* const kDatasetType = "Shuffle";
|
||||
static constexpr const char* const kReshuffleEachIteration =
|
||||
"reshuffle_each_iteration";
|
||||
|
||||
explicit ShuffleDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
@ -52,7 +52,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
||||
class DatasetV2;
|
||||
class DatasetV3;
|
||||
int op_version_ = 0;
|
||||
bool reshuffle_each_iteration_;
|
||||
bool reshuffle_each_iteration_ = true;
|
||||
};
|
||||
|
||||
class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
||||
@ -68,6 +68,9 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
||||
|
||||
private:
|
||||
class Dataset;
|
||||
class DatasetV2;
|
||||
int op_version_ = 0;
|
||||
bool reshuffle_each_iteration_ = true;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -72,10 +72,8 @@ class ShuffleDatasetParams : public DatasetParams {
|
||||
output_dtypes_);
|
||||
attr_vector->emplace_back(ShuffleDatasetOpBase::kOutputShapes,
|
||||
output_shapes_);
|
||||
if (count_ == 1) {
|
||||
attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration,
|
||||
reshuffle_each_iteration_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -297,23 +295,23 @@ std::vector<GetNextTestCase<ShuffleDatasetParams>> GetNextTestCases() {
|
||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}),
|
||||
{{2}, {6}, {1}, {3}, {9}, {5}, {0}, {8}, {7}, {4},
|
||||
{0}, {5}, {1}, {7}, {2}, {9}, {8}, {4}, {6}, {3}}),
|
||||
{{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5},
|
||||
{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5}}),
|
||||
/*expected_reshuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {6}, {0}, {5}, {2}, {7}, {4},
|
||||
{3}, {9}, {8}, {6}, {5}, {0}, {9},
|
||||
{4}, {7}, {2}, {8}, {1}, {3}})},
|
||||
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||
{2}, {4}, {5}, {9}, {0}, {8}, {6},
|
||||
{1}, {3}, {7}, {2}, {4}, {5}})},
|
||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}}),
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}}),
|
||||
/*expected_reshuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{1}, {0}, {2}, {0}, {1}, {2}, {2}, {1}, {0}, {0}, {1},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {0}, {2}})}};
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
|
||||
}
|
||||
|
||||
class ParameterizedGetNextTest : public ShuffleDatasetOpTest,
|
||||
@ -496,16 +494,16 @@ IteratorSaveAndRestoreTestCases() {
|
||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||
/*breakpoints=*/{0, 5, 22},
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}), {{2}, {6}, {1}, {3}, {9}, {5}, {0},
|
||||
{8}, {7}, {4}, {0}, {5}, {1}, {7},
|
||||
{2}, {9}, {8}, {4}, {6}, {3}})},
|
||||
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||
{2}, {4}, {5}, {9}, {0}, {8}, {6},
|
||||
{1}, {3}, {7}, {2}, {4}, {5}})},
|
||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||
/*breakpoints=*/{0, 5, 20},
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}})}};
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
|
||||
}
|
||||
|
||||
class ParameterizedIteratorSaveAndRestoreTest
|
||||
|
@ -507,6 +507,7 @@ REGISTER_OP("ShuffleAndRepeatDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("reshuffle_each_iteration: bool = true")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, seed2, and count should be scalars.
|
||||
@ -517,6 +518,28 @@ REGISTER_OP("ShuffleAndRepeatDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ShuffleAndRepeatDatasetV2")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("buffer_size: int64")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
.Input("count: int64")
|
||||
.Input("seed_generator: resource")
|
||||
.Output("handle: variant")
|
||||
.Attr("reshuffle_each_iteration: bool = true")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, seed2, count, and seed_generator should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("AnonymousMemoryCache")
|
||||
.Output("handle: resource")
|
||||
.Output("deleter: variant")
|
||||
|
@ -19,11 +19,9 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
@ -34,11 +32,7 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShuffleAndRepeatFusion(self):
|
||||
if tf2.enabled() and context.executing_eagerly():
|
||||
expected = "Shuffle"
|
||||
else:
|
||||
expected = "ShuffleAndRepeat"
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
testing.assert_next([expected])).shuffle(10).repeat(2)
|
||||
options = dataset_ops.Options()
|
||||
|
@ -3898,7 +3898,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleAndRepeatDataset"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleAndRepeatDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleDataset"
|
||||
|
@ -3898,7 +3898,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleAndRepeatDataset"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleAndRepeatDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ShuffleDataset"
|
||||
|
Loading…
Reference in New Issue
Block a user