[tf.data] Modify the optimization autotune_buffer_sizes so that it will inject autotuned PrefetchDatasets after non-prefetched asynchronous datasets. The optimization will also rewrite those existing non-autotuned PrefetchDatasets into autotuned with fixed start value and minimal value for the tunable parameter buffer_size.

PiperOrigin-RevId: 342347611
Change-Id: I1d4fe8f00944595a6fb9b5b99a1679a493b32edf
This commit is contained in:
Jay Shi 2020-11-13 15:14:42 -08:00 committed by TensorFlower Gardener
parent 5f1da7ab8c
commit 6c838fcbf8
10 changed files with 222 additions and 21 deletions

View File

@ -31,6 +31,7 @@ namespace grappler {
namespace {
constexpr char kLegacyAutotune[] = "legacy_autotune";
constexpr char kBufferSizeMin[] = "buffer_size_min";
constexpr char kPrefetchDataset[] = "PrefetchDataset";
constexpr std::array<const char*, 7> kAsyncDatasetOps = {
@ -54,8 +55,47 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
}
MutableGraphView graph(output);
// Add a const node with value kAutotune
NodeDef* autotune_value =
graph_utils::AddScalarConstNode(data::model::kAutotune, &graph);
absl::flat_hash_set<string> already_prefetched;
// 1) Collect about all existing `PrefetchDataset` nodes, replacing
// `prefetch(N)` with `prefetch(AUTOTUNE, buffer_size_min=N)` for all N !=-1.
for (NodeDef& node : *output->mutable_node()) {
if (node.op() == kPrefetchDataset) {
NodeDef* buffer_size_node = graph.GetNode(node.input(1));
// We only consider to rewrite if `buffer_size` is constant.
if (buffer_size_node->op() == "Const") {
int64 initial_buffer_size =
buffer_size_node->attr().at("value").tensor().int64_val(0);
if (initial_buffer_size != data::model::kAutotune) {
TF_RETURN_IF_ERROR(graph.UpdateFanin(node.name(),
{buffer_size_node->name(), 0},
{autotune_value->name(), 0}));
node.mutable_attr()->at(kBufferSizeMin).set_i(initial_buffer_size);
stats->num_changes++;
}
} else {
return errors::FailedPrecondition(
"The autotune_buffer_sizes rewrite does not currently support "
"non-constant buffer_size input.");
}
NodeDef* prefetched_node = graph_utils::GetInputNode(node, graph);
if (prefetched_node) {
already_prefetched.insert(prefetched_node->name());
}
}
}
std::vector<const NodeDef*> async_datasets;
// 2) Insert `prefetch(AUTOTUNE)` after all asynchronous transformations that
// are not followed by a `prefetch` yet.
for (const NodeDef& node : item.graph.node()) {
if (already_prefetched.find(node.name()) != already_prefetched.end()) {
continue;
}
for (const auto& async_dataset_op : kAsyncDatasetOps) {
if (node.op() == async_dataset_op) {
async_datasets.push_back(&node);
@ -67,10 +107,6 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
if (async_datasets.empty()) return Status::OK();
// Add a const node with value kAutotune
NodeDef* autotune_value =
graph_utils::AddScalarConstNode(data::model::kAutotune, &graph);
for (const NodeDef* async_dataset_node : async_datasets) {
NodeDef prefetch_node;
graph_utils::SetUniqueGraphNodeName(

View File

@ -24,8 +24,15 @@ namespace grappler {
constexpr char kAutotune[] = "autotune";
// This optimization adds `prefetch(AUTOTUNE)` after all asynchronous tf.data
// transformations (e.g. parallel map, parallel interleave, and map + batch).
// This optimization does the following:
//
// 1. Adds `prefetch(AUTOTUNE)` after all asynchronous tf.data transformations
// (e.g. parallel map, parallel interleave, and map + batch) if they are not
// followed by a `prefetch` yet.
//
// 2. If there exists any `prefetch(buffer_size=N)` for `N>=0`, it will replace
// the transformation with autotunable version of `prefetch` which uses N as
// the minimum size of the buffer.
class AutotuneBufferSizes : public TFDataOptimizerBase {
public:
AutotuneBufferSizes() = default;

View File

@ -148,6 +148,116 @@ TEST_P(AutotuneSetting, AutotuneBufferSizesTest) {
autotune);
}
class MultipleNodes : public ::testing::TestWithParam<std::tuple<bool, int64>> {
};
TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
const bool legacy_autotune = std::get<0>(GetParam());
const int64 initial_buffer_size = std::get<1>(GetParam());
GrapplerItem item;
MutableGraphView graph(&item.graph);
NodeDef *start_val = graph_utils::AddScalarConstNode<int64>(0, &graph);
NodeDef *stop_val = graph_utils::AddScalarConstNode<int64>(10, &graph);
NodeDef *step_val = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_val->name();
range_inputs[1] = stop_val->name();
range_inputs[2] = step_val->name();
std::vector<std::pair<string, AttrValue>> range_attrs;
NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset",
range_inputs, range_attrs, &graph);
NodeDef *parallelism_val = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> map_inputs1(2);
map_inputs1[0] = range_node->name();
map_inputs1[1] = parallelism_val->name();
std::vector<std::pair<string, AttrValue>> map_attrs(4);
AttrValue attr_val;
SetAttrValue("value", &attr_val);
map_attrs[0] = std::make_pair("f", attr_val);
map_attrs[1] = std::make_pair("Targuments", attr_val);
map_attrs[2] = std::make_pair("output_types", attr_val);
map_attrs[3] = std::make_pair("output_shapes", attr_val);
NodeDef *map_node1 = graph_utils::AddNode("map1", "ParallelMapDatasetV2",
map_inputs1, map_attrs, &graph);
NodeDef *buffer_size_val =
graph_utils::AddScalarConstNode<int64>(initial_buffer_size, &graph);
std::vector<string> prefetch_inputs(2);
prefetch_inputs[0] = map_node1->name();
prefetch_inputs[1] = buffer_size_val->name();
std::vector<std::pair<string, AttrValue>> prefetch_attrs(4);
AttrValue legacy_autotune_attr;
SetAttrValue(legacy_autotune, &legacy_autotune_attr);
AttrValue buffer_size_min_attr;
SetAttrValue(0, &buffer_size_min_attr);
prefetch_attrs[0] = std::make_pair("legacy_autotune", legacy_autotune_attr);
prefetch_attrs[1] = std::make_pair("buffer_size_min", buffer_size_min_attr);
prefetch_attrs[2] = std::make_pair("output_types", attr_val);
prefetch_attrs[3] = std::make_pair("output_shapes", attr_val);
NodeDef *prefetch_node = graph_utils::AddNode(
"prefetch", "PrefetchDataset", prefetch_inputs, prefetch_attrs, &graph);
std::vector<string> map_inputs2(2);
map_inputs2[0] = prefetch_node->name();
map_inputs2[1] = parallelism_val->name();
NodeDef *map_node2 = graph_utils::AddNode("map2", "ParallelMapDatasetV2",
map_inputs2, map_attrs, &graph);
std::vector<string> map_inputs3(1);
map_inputs3[0] = map_node2->name();
graph_utils::AddNode("map3", "MapDataset", map_inputs3, map_attrs, &graph);
GraphDef output;
TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true));
std::vector<int> prefetch_indices =
graph_utils::FindAllGraphNodesWithOp("PrefetchDataset", output);
EXPECT_EQ(prefetch_indices.size(), 2);
NodeDef new_map_node3 =
output.node(graph_utils::FindGraphNodeWithName("map3", output));
NodeDef new_prefetch_node2 = output.node(
graph_utils::FindGraphNodeWithName(new_map_node3.input(0), output));
EXPECT_EQ(new_prefetch_node2.op(), "PrefetchDataset");
EXPECT_EQ(new_prefetch_node2.input_size(), 2);
EXPECT_TRUE(new_prefetch_node2.attr().find("legacy_autotune") ==
new_prefetch_node2.attr().end());
EXPECT_TRUE(new_prefetch_node2.attr().find("buffer_size_min") ==
new_prefetch_node2.attr().end());
NodeDef new_buffer_size_val2 = output.node(
graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(1), output));
EXPECT_EQ(new_buffer_size_val2.attr().at("value").tensor().int64_val(0), -1);
NodeDef new_map_node2 = output.node(
graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(0), output));
EXPECT_EQ(new_map_node2.name(), "map2");
NodeDef new_prefetch_node1 = output.node(
graph_utils::FindGraphNodeWithName(new_map_node2.input(0), output));
EXPECT_EQ(new_prefetch_node1.op(), "PrefetchDataset");
EXPECT_EQ(new_prefetch_node1.input_size(), 2);
EXPECT_EQ(new_prefetch_node1.attr().at("legacy_autotune").b(),
legacy_autotune);
EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(),
(initial_buffer_size == -1 ? 0 : initial_buffer_size));
NodeDef new_buffer_size_val1 = output.node(
graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(1), output));
EXPECT_EQ(new_buffer_size_val1.attr().at("value").tensor().int64_val(0), -1);
NodeDef new_map_node1 = output.node(
graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(0), output));
EXPECT_EQ(new_map_node1.name(), "map1");
}
INSTANTIATE_TEST_SUITE_P(Test, MultipleNodes,
::testing::Combine(::testing::Values(true, false),
::testing::Values(-1, 3)));
INSTANTIATE_TEST_SUITE_P(Test, AutotuneSetting, ::testing::Values(false, true));
} // namespace

View File

@ -20,11 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace data {
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size,
int64 buffer_size_min)
: buffer_limit_(initial_buffer_size) {
if (initial_buffer_size == model::kAutotune) {
mode_ = Mode::kUpswing;
buffer_limit_ = 1;
buffer_limit_ = std::max(int64{1}, buffer_size_min);
}
}

View File

@ -39,7 +39,7 @@ namespace data {
// PrefetchAutotuner is NOT thread safe.
class PrefetchAutotuner {
public:
explicit PrefetchAutotuner(int64 initial_buffer_size);
explicit PrefetchAutotuner(int64 initial_buffer_size, int64 buffer_size_min);
int64 buffer_limit() const { return buffer_limit_; }

View File

@ -23,7 +23,7 @@ namespace data {
namespace {
TEST(PrefetchAutotuner, Disabled) {
PrefetchAutotuner t(2);
PrefetchAutotuner t(2, 0);
EXPECT_EQ(2, t.buffer_limit());
t.RecordConsumption(0);
t.RecordConsumption(2);
@ -33,7 +33,7 @@ TEST(PrefetchAutotuner, Disabled) {
}
TEST(PrefetchAutotuner, Enabled) {
PrefetchAutotuner t(model::kAutotune);
PrefetchAutotuner t(model::kAutotune, 0);
EXPECT_EQ(1, t.buffer_limit());
t.RecordConsumption(0); // Expect buffer limit to stay the same.
EXPECT_EQ(1, t.buffer_limit());
@ -58,7 +58,7 @@ TEST(PrefetchAutotuner, Enabled) {
}
TEST(PrefetchAutotuner, EnabledSteady) {
PrefetchAutotuner t(model::kAutotune);
PrefetchAutotuner t(model::kAutotune, 0);
EXPECT_EQ(1, t.buffer_limit());
t.RecordConsumption(0); // Expect buffer limit to stay the same!
EXPECT_EQ(1, t.buffer_limit());
@ -80,6 +80,29 @@ TEST(PrefetchAutotuner, EnabledSteady) {
}
}
TEST(PrefetchAutotuner, StartWithMin) {
PrefetchAutotuner t(model::kAutotune, 2);
EXPECT_EQ(2, t.buffer_limit());
t.RecordConsumption(0); // Expect buffer limit to stay the same!
EXPECT_EQ(2, t.buffer_limit());
t.RecordConsumption(2); // Expect buffer limit to stay the same!
EXPECT_EQ(2, t.buffer_limit());
t.RecordConsumption(0); // Expect buffer limit to increase.
EXPECT_EQ(4, t.buffer_limit());
t.RecordConsumption(4); // Expect buffer limit to stay the same!
EXPECT_EQ(4, t.buffer_limit());
t.RecordConsumption(0); // Expect buffer limit to increase.
EXPECT_EQ(8, t.buffer_limit());
// Never reach zero again.
std::vector<size_t> consumption_values = {3, 5, 7, 1, 4, 6, 8, 3, 5, 1, 2, 4};
for (int i = 0; i < consumption_values.size(); ++i) {
t.RecordConsumption(consumption_values[i]);
EXPECT_EQ(8, t.buffer_limit())
<< "Failed at index " << i << " with value: " << consumption_values[i];
}
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -46,6 +46,7 @@ namespace data {
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;
namespace {
@ -62,12 +63,13 @@ constexpr char kErrorMessageSuffix[] = ".error_message";
class PrefetchDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
int64 slack_period, bool legacy_autotune)
int64 slack_period, bool legacy_autotune, int64 buffer_size_min)
: DatasetBase(DatasetContext(ctx)),
input_(input),
buffer_size_(buffer_size),
slack_period_(slack_period),
legacy_autotune_(legacy_autotune) {
legacy_autotune_(legacy_autotune),
buffer_size_min_(buffer_size_min) {
input_->Ref();
}
@ -114,10 +116,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
b->BuildAttrValue(slack_period_, &slack_period_attr);
AttrValue legacy_autotune_attr;
b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
AttrValue buffer_size_min_attr;
b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node, buffer_size},
{std::make_pair(kSlackPeriod, slack_period_attr),
std::make_pair(kLegacyAutotune, legacy_autotune_attr)},
std::make_pair(kLegacyAutotune, legacy_autotune_attr),
std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
output));
return Status::OK();
}
@ -129,8 +135,12 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
auto_tuner_(params.dataset->buffer_size_),
buffer_size_min_(params.dataset->buffer_size_min_),
auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
legacy_autotune_(params.dataset->legacy_autotune_),
// If `legacy_autotune_`, initialize the `buffer_size_` value to be 0
// to avoid the created node to be collected as tunable nodes in the
// autotuning optimization.
buffer_size_(std::make_shared<model::SharedState>(
legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
cond_var_)) {
@ -145,7 +155,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (buffer_size_->value == model::kAutotune) {
buffer_size_->value = 0;
buffer_size_->value = buffer_size_min_;
}
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(), [this]() { CancelThreads(); },
@ -218,7 +228,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/1,
{model::MakeParameter(kBufferSize, buffer_size_, /*min=*/0,
{model::MakeParameter(kBufferSize, buffer_size_,
/*min=*/buffer_size_min_,
/*max=*/std::numeric_limits<int64>::max())});
}
@ -536,6 +547,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
const std::shared_ptr<condition_variable> cond_var_;
const int64 buffer_size_min_;
PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
@ -561,6 +573,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// Determines whether legacy autotuning should be used.
const bool legacy_autotune_ = true;
// If autotune is enabled, determines the minimal value of `buffer_size`
// parameter.
const int64 buffer_size_min_ = 0;
TraceMeMetadata traceme_metadata_;
};
@ -572,6 +588,9 @@ PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
if (ctx->HasAttr(kLegacyAutotune)) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
}
if (ctx->HasAttr(kBufferSizeMin)) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
}
}
void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
@ -588,8 +607,8 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
metrics::RecordTFDataAutotune(kDatasetType);
}
*output =
new Dataset(ctx, input, buffer_size, slack_period_, legacy_autotune_);
*output = new Dataset(ctx, input, buffer_size, slack_period_,
legacy_autotune_, buffer_size_min_);
}
namespace {

View File

@ -31,6 +31,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kSlackPeriod = "slack_period";
static constexpr const char* const kLegacyAutotune = "legacy_autotune";
static constexpr const char* const kBufferSizeMin = "buffer_size_min";
explicit PrefetchDatasetOp(OpKernelConstruction* ctx);
@ -42,6 +43,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
class Dataset;
int64 slack_period_ = 0;
bool legacy_autotune_ = true;
int64 buffer_size_min_ = 0;
};
} // namespace data

View File

@ -60,7 +60,8 @@ class PrefetchDatasetParams : public DatasetParams {
attr_vector->emplace_back(PrefetchDatasetOp::kSlackPeriod, slack_period_);
attr_vector->emplace_back(PrefetchDatasetOp::kLegacyAutotune,
legacy_autotune_);
attr_vector->emplace_back("buffer_size_min", buffer_size_min_);
attr_vector->emplace_back(PrefetchDatasetOp::kBufferSizeMin,
buffer_size_min_);
return Status::OK();
}

View File

@ -81,10 +81,12 @@ class AutotuneBufferSizesTest(test_base.DatasetTestBase,
]))
dataset = dataset.map(
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=3)
dataset = dataset.map(
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.map(
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
dataset = dataset.interleave(
lambda x: dataset_ops.Dataset.from_tensors(x + 1),
num_parallel_calls=dataset_ops.AUTOTUNE)