[tf.data] Supporting shuffle + repeat fusion in TF 2 (and with correct reshuffle_each_iteration semantics).

PiperOrigin-RevId: 309135252
Change-Id: I157a09027d43d8943e764573b359755bc9380167
This commit is contained in:
Jiri Simsa 2020-04-29 17:56:56 -07:00 committed by TensorFlower Gardener
parent 3dc9712013
commit 36c49a6013
10 changed files with 517 additions and 114 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "ShuffleAndRepeatDatasetV2"
visibility: HIDDEN
}

View File

@ -26,12 +26,128 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
namespace tensorflow {
namespace grappler {
namespace {
constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleDataset[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
constexpr char kRepeatDataset[] = "RepeatDataset";
constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
// Populates `fused_node` as a `ShuffleAndRepeatDataset` node that replaces a
// `ShuffleDataset` node (`shuffle_node`) followed by a `RepeatDataset` node
// (`repeat_node`).
//
// The fused node receives the shuffle node's inputs 0..3 (`input`,
// `buffer_size`, `seed`, `seed2`), followed by the repeat node's `count`
// input, and copies `output_shapes`, `output_types` and
// `reshuffle_each_iteration` from the shuffle node. `output` is consulted only
// to pick a unique node name; `graph` is not used. Always returns OK.
Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
                              const NodeDef& repeat_node,
                              MutableGraphView* graph, GraphDef* output,
                              NodeDef* fused_node) {
  fused_node->set_op(kShuffleAndRepeatDataset);
  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
                                      fused_node);

  // `input`, `buffer_size`, `seed` and `seed2` occupy the same positions in
  // the shuffle node and in the fused node.
  constexpr int kNumForwardedShuffleInputs = 4;
  for (int index = 0; index < kNumForwardedShuffleInputs; ++index) {
    fused_node->add_input(shuffle_node.input(index));
  }

  // `count` is the second input of the repeat node.
  fused_node->add_input(repeat_node.input(1));

  for (const char* attr_name :
       {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
    graph_utils::CopyAttribute(attr_name, shuffle_node, fused_node);
  }

  return Status::OK();
}
// Populates `fused_node` as a `ShuffleAndRepeatDatasetV2` node that replaces a
// `ShuffleDatasetV2` node (`shuffle_node`) followed by a `RepeatDataset` node
// (`repeat_node`).
//
// Only inputs 0, 1 and 2 of the shuffle node are read here (`input`,
// `buffer_size` and `seed_generator`), so the fused node's `seed` and `seed2`
// inputs are both wired to a new scalar int64 constant 0 that this function
// adds to `graph`. `reshuffle_each_iteration` is set to true rather than
// copied. `output` is consulted only to pick a unique node name. Always
// returns OK.
Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
                              const NodeDef& repeat_node,
                              MutableGraphView* graph, GraphDef* output,
                              NodeDef* fused_node) {
  fused_node->set_op(kShuffleAndRepeatDatasetV2);
  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
                                      fused_node);

  // Only the name of the constant is needed, so keep the pointer returned by
  // the graph utility instead of copying the whole NodeDef.
  const NodeDef* zero_node = graph_utils::AddScalarConstNode<int64>(0, graph);

  // Set the `input` input argument.
  fused_node->add_input(shuffle_node.input(0));

  // Set the `buffer_size` input argument.
  fused_node->add_input(shuffle_node.input(1));

  // Default the `seed` input argument to 0.
  fused_node->add_input(zero_node->name());

  // Default the `seed2` input argument to 0.
  fused_node->add_input(zero_node->name());

  // Set the `count` input argument.
  fused_node->add_input(repeat_node.input(1));

  // Set the `seed_generator` input argument.
  fused_node->add_input(shuffle_node.input(2));

  // Set `output_types` and `output_shapes` attributes.
  for (auto key : {kOutputShapes, kOutputTypes}) {
    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
  }

  // Default the `reshuffle_each_iteration` attribute to true.
  (*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);

  return Status::OK();
}
// Populates `fused_node` as a `ShuffleAndRepeatDatasetV2` node that replaces a
// `ShuffleDatasetV3` node (`shuffle_node`) followed by a `RepeatDataset` node
// (`repeat_node`).
//
// The fused node receives the shuffle node's inputs 0..3 (`input`,
// `buffer_size`, `seed`, `seed2`), then the repeat node's `count` input, then
// the shuffle node's input 4 (`seed_generator`). `output_shapes`,
// `output_types` and `reshuffle_each_iteration` are copied from the shuffle
// node. `output` is consulted only to pick a unique node name; `graph` is not
// used. Always returns OK.
Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
                              const NodeDef& repeat_node,
                              MutableGraphView* graph, GraphDef* output,
                              NodeDef* fused_node) {
  fused_node->set_op(kShuffleAndRepeatDatasetV2);
  // Derive the node name from the op that is actually emitted (V2), the same
  // way the ShuffleDatasetV2 fusion does.
  graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
                                      fused_node);

  // Set the `input` input argument.
  fused_node->add_input(shuffle_node.input(0));

  // Set the `buffer_size` input argument.
  fused_node->add_input(shuffle_node.input(1));

  // Set the `seed` input argument.
  fused_node->add_input(shuffle_node.input(2));

  // Set the `seed2` input argument.
  fused_node->add_input(shuffle_node.input(3));

  // Set the `count` input argument.
  fused_node->add_input(repeat_node.input(1));

  // Set the `seed_generator` input argument.
  fused_node->add_input(shuffle_node.input(4));

  // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
  // attributes.
  for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
    graph_utils::CopyAttribute(key, shuffle_node, fused_node);
  }

  return Status::OK();
}
} // namespace
@ -42,65 +158,46 @@ Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
MutableGraphView graph(output);
absl::flat_hash_set<string> nodes_to_delete;
auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node,
const NodeDef& repeat_node) {
NodeDef new_node;
new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node);
// Set the `input` input argument.
new_node.add_input(shuffle_node.input(0));
// Set the `buffer_size` input argument.
new_node.add_input(shuffle_node.input(1));
// Set the `seed` input argument.
new_node.add_input(shuffle_node.input(2));
// Set the `seed2` input argument.
new_node.add_input(shuffle_node.input(3));
// Set the `count` input argument.
new_node.add_input(repeat_node.input(1));
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};
for (const NodeDef& node : item.graph.node()) {
if (node.op() != "RepeatDataset") {
for (const NodeDef& repeat_node : item.graph.node()) {
if (repeat_node.op() != kRepeatDataset) {
continue;
}
// Use a more descriptive variable name now that we know the node type.
const NodeDef& repeat_node = node;
NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
const NodeDef& shuffle_node =
*graph_utils::GetInputNode(repeat_node, graph);
if (node2->op() != "ShuffleDataset") {
NodeDef fused_node;
if (shuffle_node.op() == kShuffleDataset) {
TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node,
&graph, output, &fused_node));
} else if (shuffle_node.op() == kShuffleDatasetV2) {
TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node,
&graph, output, &fused_node));
} else if (shuffle_node.op() == kShuffleDatasetV3) {
TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node,
&graph, output, &fused_node));
} else {
continue;
}
// Use a more descriptive variable name now that we know the node type.
const NodeDef& shuffle_node = *node2;
// TODO(b/129712758): Remove when the fused kernel supports disabling
// reshuffling for each iteration.
if (HasNodeAttr(shuffle_node, "reshuffle_each_iteration") &&
!shuffle_node.attr().at("reshuffle_each_iteration").b()) {
continue;
}
NodeDef* shuffle_and_repeat_node =
graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node));
TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
shuffle_and_repeat_node->name()));
shuffle_and_repeat_node.name()));
// Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of
// control dependencies.
TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(),
shuffle_and_repeat_node.name()));
// Mark the `Shuffle` and `Repeat` nodes for removal.
// Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of
// them needs to be preserved).
const auto nodes_to_preserve = item.NodesToPreserve();
if (nodes_to_preserve.find(shuffle_node.name()) ==
nodes_to_preserve.end() &&
nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) {
nodes_to_delete.insert(shuffle_node.name());
nodes_to_delete.insert(repeat_node.name());
}
stats->num_changes++;
}

View File

@ -25,17 +25,21 @@ namespace tensorflow {
namespace grappler {
namespace {
TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV1AndRepeat) {
GrapplerItem item;
MutableGraphView graph(&item.graph);
std::vector<std::pair<string, AttrValue>> common_attrs(2);
AttrValue shapes_attr;
SetAttrValue("output_shapes", &shapes_attr);
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
SetAttrValue(kOutputShapes, &shapes_attr);
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
AttrValue types_attr;
SetAttrValue("output_types", &types_attr);
common_attrs[1] = std::make_pair("output_types", types_attr);
SetAttrValue(kOutputTypes, &types_attr);
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
@ -59,6 +63,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
shuffle_inputs[3] = seed2_node->name();
NodeDef *shuffle_node = graph_utils::AddNode(
"", "ShuffleDataset", shuffle_inputs, common_attrs, &graph);
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
std::vector<string> repeat_inputs(2);
@ -85,12 +90,148 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
for (const auto &attr :
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
shuffle_node->attr().at(attr)));
}
}
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV2AndRepeat) {
GrapplerItem item;
MutableGraphView graph(&item.graph);
std::vector<std::pair<string, AttrValue>> common_attrs(2);
AttrValue shapes_attr;
SetAttrValue(kOutputShapes, &shapes_attr);
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
AttrValue types_attr;
SetAttrValue(kOutputTypes, &types_attr);
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
common_attrs, &graph);
NodeDef *buffer_size_node =
graph_utils::AddScalarConstNode<int64>(128, &graph);
NodeDef *seed_generator_node =
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
std::vector<string> shuffle_inputs(3);
shuffle_inputs[0] = range_node->name();
shuffle_inputs[1] = buffer_size_node->name();
shuffle_inputs[2] = seed_generator_node->name();
NodeDef *shuffle_node = graph_utils::AddNode(
"", "ShuffleDatasetV2", shuffle_inputs, common_attrs, &graph);
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
std::vector<string> repeat_inputs(2);
repeat_inputs[0] = shuffle_node->name();
repeat_inputs[1] = count_node->name();
NodeDef *repeat_node = graph_utils::AddNode(
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
ShuffleAndRepeatFusion optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
EXPECT_TRUE(
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"),
repeat_node->attr().at("output_shapes")));
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
NodeDef shuffle_and_repeat_node = output.node(
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(2));
for (const auto &attr : {kOutputShapes, kOutputTypes}) {
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
shuffle_node->attr().at(attr)));
}
EXPECT_TRUE(shuffle_and_repeat_node.attr().at(kReshuffleEachIteration).b());
}
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV3AndRepeat) {
GrapplerItem item;
MutableGraphView graph(&item.graph);
std::vector<std::pair<string, AttrValue>> common_attrs(2);
AttrValue shapes_attr;
SetAttrValue(kOutputShapes, &shapes_attr);
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
AttrValue types_attr;
SetAttrValue(kOutputTypes, &types_attr);
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
common_attrs, &graph);
NodeDef *buffer_size_node =
graph_utils::AddScalarConstNode<int64>(128, &graph);
NodeDef *seed_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
NodeDef *seed2_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
NodeDef *seed_generator_node =
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
std::vector<string> shuffle_inputs(5);
shuffle_inputs[0] = range_node->name();
shuffle_inputs[1] = buffer_size_node->name();
shuffle_inputs[2] = seed_node->name();
shuffle_inputs[3] = seed2_node->name();
shuffle_inputs[4] = seed_generator_node->name();
NodeDef *shuffle_node = graph_utils::AddNode(
"", "ShuffleDatasetV3", shuffle_inputs, common_attrs, &graph);
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
std::vector<string> repeat_inputs(2);
repeat_inputs[0] = shuffle_node->name();
repeat_inputs[1] = count_node->name();
NodeDef *repeat_node = graph_utils::AddNode(
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
ShuffleAndRepeatFusion optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
EXPECT_TRUE(
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"),
repeat_node->attr().at("output_types")));
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
NodeDef shuffle_and_repeat_node = output.node(
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(4));
for (const auto &attr :
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
shuffle_node->attr().at(attr)));
}
}
TEST(ShuffleAndRepeatFusionTest, NoChange) {
@ -99,11 +240,11 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
std::vector<std::pair<string, AttrValue>> common_attrs(2);
AttrValue shapes_attr;
SetAttrValue("output_shapes", &shapes_attr);
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
SetAttrValue(kOutputShapes, &shapes_attr);
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
AttrValue types_attr;
SetAttrValue("output_types", &types_attr);
common_attrs[1] = std::make_pair("output_types", types_attr);
SetAttrValue(kOutputTypes, &types_attr);
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);

View File

@ -44,10 +44,10 @@ namespace data {
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
/* static */ constexpr const char* const
ShuffleDatasetOpBase::kReshuffleEachIteration;
/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
/* static */ constexpr const char* const
ShuffleDatasetOp::kReshuffleEachIteration;
/* static */ constexpr const char* const
ShuffleAndRepeatDatasetOp::kDatasetType;
@ -72,6 +72,8 @@ constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
// Op names that the kernel constructors compare against the node's op name to
// select `op_version_`. Each string must match the name the corresponding
// kernel is registered under; the V1 fused op is registered as
// "ShuffleAndRepeatDataset" (no version suffix).
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
@ -225,6 +227,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
while (!slices_.empty() &&
slices_.front()->start == slices_.front()->end) {
slices_.pop_front();
// Reinitialize the RNG state for the next epoch.
num_random_samples_ = 0;
seed_generator_->GenerateSeeds(&seed_, &seed2_);
ResetRngs();
}
DCHECK(!slices_.empty());
// Choose an element to produce uniformly at random from the first
@ -663,6 +669,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
RandomSeeds seeds(seed, seed2);
bool owns_resource = false;
if (errors::IsNotFound(s)) {
owns_resource = true;
OP_REQUIRES_OK(
ctx,
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
@ -679,7 +686,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
return Status::OK();
}));
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
owns_resource = true;
} else {
OP_REQUIRES_OK(ctx, s);
}
@ -695,6 +701,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
handle.container(), handle.name(), &manager);
bool owns_resource = false;
if (errors::IsNotFound(s)) {
owns_resource = true;
LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
"using a non-deterministically seeded generator and "
"reshuffling each iteration.";
@ -708,7 +715,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
return Status::OK();
}));
handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
owns_resource = true;
} else {
OP_REQUIRES_OK(ctx, s);
}
@ -790,9 +796,13 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
AttrValue reshuffle_each_iteration;
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
&reshuffle_each_iteration);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
{}, // Attrs
{std::make_pair(kReshuffleEachIteration,
reshuffle_each_iteration)}, // Attrs
output));
return Status::OK();
}
@ -804,8 +814,83 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
const RandomSeeds seeds_;
};
// Dataset produced by the `ShuffleAndRepeatDatasetV2` code path of
// `ShuffleAndRepeatDatasetOp::MakeDataset`. In addition to the shuffle/repeat
// state handled by `ShuffleDatasetBase`, it keeps the seed generator manager
// and the resource handle that identifies it, and serializes that handle as
// the op's sixth input.
class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
public:
// `manager` must be non-null: `manager->get()` is passed to the base class as
// the seed generator. When `owns_resource` is true, the destructor also
// deletes the resource named by `resource_handle` from the resource manager
// obtained from `ctx`.
DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
ResourceHandle&& resource_handle, bool owns_resource)
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
manager_(manager),
owns_resource_(owns_resource),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()),
seeds_(std::move(seeds)) {}
// Drops this dataset's reference to the manager and, if this dataset owns the
// resource, removes it from the resource manager. A failed deletion is only
// logged.
~DatasetV2() override {
manager_->Unref();
if (owns_resource_) {
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
}
}
}
string op_type() const override { return kDatasetType; }
protected:
// Serializes this dataset as a node with six inputs, in order: input dataset,
// buffer size, seed, seed2, count, and the seed generator resource handle,
// plus the `reshuffle_each_iteration` attribute read from the seed generator.
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* buffer_size_node = nullptr;
Node* seed_node = nullptr;
Node* seed2_node = nullptr;
Node* count_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
// The seeds written out are the `input_seed()` / `input_seed2()` values of
// the stored `RandomSeeds`.
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
// Wrap the resource handle in a scalar DT_RESOURCE tensor so it can be added
// to the graph as an input node.
Node* resource_handle_node = nullptr;
Tensor handle(DT_RESOURCE, TensorShape({}));
handle.scalar<ResourceHandle>()() = resource_handle_;
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
AttrValue reshuffle_each_iteration;
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
&reshuffle_each_iteration);
TF_RETURN_IF_ERROR(
b->AddDataset(this,
{input_graph_node, buffer_size_node, seed_node,
seed2_node, count_node, resource_handle_node}, // Inputs
{std::make_pair(kReshuffleEachIteration,
reshuffle_each_iteration)}, // Attrs
output));
return Status::OK();
}
private:
// Released with `Unref()` in the destructor.
SeedGeneratorManager* const manager_; // Owned
// Whether the destructor should delete the resource named by
// `resource_handle_`.
const bool owns_resource_;
// Identifies the seed generator resource (container and name).
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
// Source of the `seed` / `seed2` values emitted during serialization.
const RandomSeeds seeds_;
};
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
: ShuffleDatasetOpBase(ctx) {}
: ShuffleDatasetOpBase(ctx) {
auto& op_name = ctx->def().op();
if (op_name == kShuffleAndRepeatDatasetV2) {
op_version_ = 2;
} else if (op_name == kShuffleAndRepeatDatasetV1) {
op_version_ = 1;
}
if (ctx->HasAttr(kReshuffleEachIteration)) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
}
}
void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase* input,
@ -826,30 +911,77 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
RandomSeeds seeds(seed, seed2);
OP_REQUIRES(ctx, count > 0 || count == -1,
errors::InvalidArgument(
"count must be greater than zero or equal to -1."));
RandomSeeds seeds(seed, seed2);
static std::atomic<int64> resource_id_counter(0);
const string& container = ctx->resource_manager()->default_container();
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
resource_id_counter.fetch_add(1));
if (op_version_ == 2) {
auto handle = HandleFromInput(ctx, 5);
SeedGeneratorManager* manager = nullptr;
Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
handle.container(), handle.name(), &manager);
bool owns_resource = false;
if (errors::IsNotFound(s)) {
owns_resource = true;
OP_REQUIRES_OK(
ctx,
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
container, name, &manager,
[reshuffle = reshuffle_each_iteration_,
&seeds](SeedGeneratorManager** manager) {
if (reshuffle) {
*manager =
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
} else {
*manager =
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
}
return Status::OK();
}));
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
} else {
OP_REQUIRES_OK(ctx, s);
}
// Ownership of manager is transferred onto `DatasetV2`.
*output = new ShuffleAndRepeatDatasetOp::DatasetV2(
ctx, input, buffer_size, count, std::move(seeds), manager,
std::move(handle), owns_resource);
} else {
if (op_version_ != 1) {
LOG(WARNING) << "Unsupported version of shuffle dataset op: "
<< op_version_ << ". Defaulting to version 1.";
}
SeedGeneratorManager* manager;
OP_REQUIRES_OK(
ctx,
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
container, name, &manager, [&seeds](SeedGeneratorManager** manager) {
*manager = new SeedGeneratorManager(new RandomSeedGenerator(seeds));
container, name, &manager,
[reshuffle = reshuffle_each_iteration_,
&seeds](SeedGeneratorManager** manager) {
if (reshuffle) {
*manager =
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
} else {
*manager =
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
}
return Status::OK();
}));
auto handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
auto handle =
MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
// Ownership of manager is transferred onto `Dataset`.
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
count, std::move(handle));
}
}
namespace {
REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
@ -863,6 +995,9 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
ShuffleAndRepeatDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
ShuffleAndRepeatDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -28,6 +28,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
static constexpr const char* const kSeed2 = "seed2";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kReshuffleEachIteration =
"reshuffle_each_iteration";
explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx);
@ -38,8 +40,6 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
class ShuffleDatasetOp : public ShuffleDatasetOpBase {
public:
static constexpr const char* const kDatasetType = "Shuffle";
static constexpr const char* const kReshuffleEachIteration =
"reshuffle_each_iteration";
explicit ShuffleDatasetOp(OpKernelConstruction* ctx);
@ -52,7 +52,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
class DatasetV2;
class DatasetV3;
int op_version_ = 0;
bool reshuffle_each_iteration_;
bool reshuffle_each_iteration_ = true;
};
class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
@ -68,6 +68,9 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
private:
class Dataset;
class DatasetV2;
int op_version_ = 0;
bool reshuffle_each_iteration_ = true;
};
} // namespace data

View File

@ -72,10 +72,8 @@ class ShuffleDatasetParams : public DatasetParams {
output_dtypes_);
attr_vector->emplace_back(ShuffleDatasetOpBase::kOutputShapes,
output_shapes_);
if (count_ == 1) {
attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration,
reshuffle_each_iteration_);
}
return Status::OK();
}
@ -297,23 +295,23 @@ std::vector<GetNextTestCase<ShuffleDatasetParams>> GetNextTestCases() {
{/*dataset_params=*/ShuffleDatasetParams7(),
/*expected_shuffle_outputs=*/
CreateTensors<int64>(TensorShape({}),
{{2}, {6}, {1}, {3}, {9}, {5}, {0}, {8}, {7}, {4},
{0}, {5}, {1}, {7}, {2}, {9}, {8}, {4}, {6}, {3}}),
{{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5},
{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5}}),
/*expected_reshuffle_outputs=*/
CreateTensors<int64>(TensorShape({}), {{1}, {6}, {0}, {5}, {2}, {7}, {4},
{3}, {9}, {8}, {6}, {5}, {0}, {9},
{4}, {7}, {2}, {8}, {1}, {3}})},
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
{2}, {4}, {5}, {9}, {0}, {8}, {6},
{1}, {3}, {7}, {2}, {4}, {5}})},
{/*dataset_params=*/ShuffleDatasetParams8(),
/*expected_shuffle_outputs=*/
CreateTensors<int64>(
TensorShape({}),
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}}),
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}}),
/*expected_reshuffle_outputs=*/
CreateTensors<int64>(
TensorShape({}),
{{1}, {0}, {2}, {0}, {1}, {2}, {2}, {1}, {0}, {0}, {1},
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {0}, {2}})}};
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
}
class ParameterizedGetNextTest : public ShuffleDatasetOpTest,
@ -496,16 +494,16 @@ IteratorSaveAndRestoreTestCases() {
{/*dataset_params=*/ShuffleDatasetParams7(),
/*breakpoints=*/{0, 5, 22},
/*expected_shuffle_outputs=*/
CreateTensors<int64>(TensorShape({}), {{2}, {6}, {1}, {3}, {9}, {5}, {0},
{8}, {7}, {4}, {0}, {5}, {1}, {7},
{2}, {9}, {8}, {4}, {6}, {3}})},
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
{2}, {4}, {5}, {9}, {0}, {8}, {6},
{1}, {3}, {7}, {2}, {4}, {5}})},
{/*dataset_params=*/ShuffleDatasetParams8(),
/*breakpoints=*/{0, 5, 20},
/*expected_shuffle_outputs=*/
CreateTensors<int64>(
TensorShape({}),
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}})}};
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
}
class ParameterizedIteratorSaveAndRestoreTest

View File

@ -507,6 +507,7 @@ REGISTER_OP("ShuffleAndRepeatDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("reshuffle_each_iteration: bool = true")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// buffer_size, seed, seed2, and count should be scalars.
@ -517,6 +518,28 @@ REGISTER_OP("ShuffleAndRepeatDataset")
return shape_inference::ScalarShape(c);
});
// Fused shuffle + repeat dataset op that additionally takes the seed
// generator as a resource input. Produces a scalar variant dataset handle.
REGISTER_OP("ShuffleAndRepeatDatasetV2")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Input("count: int64")
    .Input("seed_generator: resource")
    .Output("handle: variant")
    .Attr("reshuffle_each_iteration: bool = true")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Inputs 1 through 5 (buffer_size, seed, seed2, count, seed_generator)
      // must all be scalars; they are validated in input order.
      constexpr int kFirstScalarInput = 1;
      constexpr int kLastScalarInput = 5;
      shape_inference::ShapeHandle ignored;
      for (int index = kFirstScalarInput; index <= kLastScalarInput; ++index) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &ignored));
      }
      return shape_inference::ScalarShape(c);
    });
REGISTER_OP("AnonymousMemoryCache")
.Output("handle: resource")
.Output("deleter: variant")

View File

@ -19,11 +19,9 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@ -34,11 +32,7 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
@combinations.generate(test_base.default_test_combinations())
def testShuffleAndRepeatFusion(self):
if tf2.enabled() and context.executing_eagerly():
expected = "Shuffle"
else:
expected = "ShuffleAndRepeat"
dataset = dataset_ops.Dataset.range(10).apply(
testing.assert_next([expected])).shuffle(10).repeat(2)
options = dataset_ops.Options()

View File

@ -3898,7 +3898,11 @@ tf_module {
}
member_method {
name: "ShuffleAndRepeatDataset"
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ShuffleAndRepeatDatasetV2"
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ShuffleDataset"

View File

@ -3898,7 +3898,11 @@ tf_module {
}
member_method {
name: "ShuffleAndRepeatDataset"
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ShuffleAndRepeatDatasetV2"
argspec: "args=[\'input_dataset\', \'buffer_size\', \'seed\', \'seed2\', \'count\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ShuffleDataset"