[tf.data] Add option to control whether vectorization is aggressive (i.e. always vectorizes) or safe (i.e. uses ChooseFastestBranchDataset)
PiperOrigin-RevId: 239203789
Parent: f6dfeeeccd
Commit: f9fbff63fb
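Not part of the commit itself, but as a quick orientation: a minimal sketch of how the new knob is meant to be driven from the Python API added below. The option names come from the optimization_options.py and map_vectorization_test.py changes in this diff; the pipeline itself is made up for illustration.

# Sketch only: assumes the MapVectorizationOptions fields added in this commit.
import tensorflow as tf

dataset = tf.data.Dataset.range(1000).map(lambda x: x * 2).batch(32)

options = tf.data.Options()
# Turn on the map_vectorization optimization.
options.experimental_optimization.map_vectorization.enabled = True
# "Safe" mode: wrap the original and vectorized branches in a
# ChooseFastestBranch dataset and pick the faster one at runtime.
# Set to False for "aggressive" mode (always use the vectorized branch).
options.experimental_optimization.map_vectorization.use_choose_fastest = True

dataset = dataset.with_options(options)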
@@ -548,6 +548,7 @@ cc_library(
     hdrs = ["meta_optimizer.h"],
     deps = [
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
         "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
@@ -534,44 +534,66 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
        AddVectorizedFunction(*map_node, *map_func, library);
    CHECK_NOTNULL(vectorized_func);

-    std::vector<const NodeDef*> vectorized_branch;
    NodeDef* new_batch_node;
    TF_RETURN_IF_ERROR(AddNewBatchNode(
        *batch_node, *input_node, *vectorized_func, &graph, &new_batch_node));
-    vectorized_branch.push_back(new_batch_node);

    NodeDef* new_map_node;
    TF_RETURN_IF_ERROR(AddNewMapNode(*map_node, *batch_node, *new_batch_node,
                                     *vectorized_func, &graph, &new_map_node));
-    vectorized_branch.push_back(new_map_node);

+    NodeDef* optional_new_prefetch_node = nullptr;
    if (optional_prefetch_node) {
      // If the original pipeline was .map().prefetch().batch(), the new
      // pipeline is .batch().map().prefetch()
-      NodeDef* new_prefetch_node;
      TF_RETURN_IF_ERROR(AddNewPrefetchNode(*optional_prefetch_node,
                                            *batch_node, *new_map_node, &graph,
-                                            &new_prefetch_node));
-      vectorized_branch.push_back(new_prefetch_node);
+                                            &optional_new_prefetch_node));
    }

+    std::vector<const NodeDef*> vectorized_branch(
+        {new_batch_node, new_map_node});
+
    std::vector<const NodeDef*> original_branch({map_node});
    if (optional_prefetch_node) {
      original_branch.push_back(optional_prefetch_node);
+      vectorized_branch.push_back(optional_new_prefetch_node);
    }
-    if (map_node->op() != kExperimentalMapAndBatchOp) {
+    if (batch_node->op() != kExperimentalMapAndBatchOp) {
      original_branch.push_back(batch_node);
    }

+    // Mark the original nodes for deletion.
+    for (const auto& n : original_branch) {
+      nodes_to_delete.insert(n->name());
+    }
+
+    if (use_choose_fastest_) {
+      // Optionally, use ChooseFastestBranch node to mitigate potential
+      // regressions caused by vectorization.
+      for (const auto& n : vectorized_branch) {
+        // Mark the vectorized nodes for deletion, since they will be added in
+        // the choose fastest dataset branch function separately.
+        nodes_to_delete.insert(n->name());
+      }
      NodeDef* new_choose_fastest_node;
      TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
          input_node, /*ratio_numerator_name=*/new_batch_node->input(1),
          std::move(original_branch), std::move(vectorized_branch), &graph,
          library, &new_choose_fastest_node));

      // Make output of Batch point to ChooseFastest instead.
      TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
                                             new_choose_fastest_node->name()));

+    } else {
+      // Make output of Batch point to the new Map (or Prefetch) node instead.
+      TF_RETURN_IF_ERROR(graph.UpdateFanouts(
+          batch_node->name(), optional_new_prefetch_node
+                                  ? optional_new_prefetch_node->name()
+                                  : new_map_node->name()));
+    }

+    TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
    stats->num_changes++;
  }
  return Status::OK();
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_

+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"

 namespace tensorflow {
@@ -33,10 +34,11 @@ namespace grappler {
 // (or map_and_batch)
 //
 // To:
-// input --> map --> batch --------+
-//            |  (or map_and_batch) |
-//            |                      v
-//            +-----> batch --> map --> choose_fastest --> output
+// input --> batch --> map --> output
+//
+// If the "ChooseFastest" configuration is enabled, it adds a
+// ChooseFastestBranch dataset node to pick between the original map->batch
+// branch and the vectorized batch->map branch.
 //
 class MapVectorization : public TFDataOptimizerBase {
  public:
@@ -47,6 +49,19 @@ class MapVectorization : public TFDataOptimizerBase {

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    if (!config) return Status::OK();
+
+    const string& choose_fastest_param =
+        config->parameter_map().at("use_choose_fastest").s();
+    if (choose_fastest_param == "true") {
+      use_choose_fastest_ = true;
+    } else if (choose_fastest_param == "false") {
+      use_choose_fastest_ = false;
+    } else {
+      return errors::Internal(
+          "Received an invalid value for parameter \"use_choose_fastest\"",
+          choose_fastest_param);
+    }
    return Status::OK();
  }

@@ -56,6 +71,9 @@ class MapVectorization : public TFDataOptimizerBase {

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override;
+
+ private:
+  bool use_choose_fastest_ = false;
 };

 }  // namespace grappler
@@ -53,6 +53,19 @@ constexpr char kAttrNameDtype[] = "dtype";

 using test::function::NDef;

+Status OptimizeWithMapVectorization(const GrapplerItem& item, GraphDef* output,
+                                    bool use_choose_fastest) {
+  MapVectorization optimizer;
+  RewriterConfig_CustomGraphOptimizer config;
+  if (use_choose_fastest) {
+    (*config.mutable_parameter_map())["use_choose_fastest"].set_s("true");
+  } else {
+    (*config.mutable_parameter_map())["use_choose_fastest"].set_s("false");
+  }
+  TF_RETURN_IF_ERROR(optimizer.Init(&config));
+  return optimizer.Optimize(nullptr, item, output);
+}
+
 // Adds a simple vectorizable map function that is akin to
 // dataset.map(lambda x: tf.identity(x))
 FunctionDef* AddMapFn(MutableGraphView* graph) {
@@ -188,6 +201,35 @@ const FunctionDef* GetFunction(const GraphDef& graph,
  return &graph.library().function(found);
}

+void CheckVectorizedWithoutChooseFastest(
+    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
+    const string& input_name) {
+  std::vector<const NodeDef*> vectorized_branch;
+  for (const auto& op : expected_vectorized_branch) {
+    // This assumes that vectorized op is the only one that exists in the graph.
+    // For our test cases, this is true (we don't have superfluous map/batch
+    // nodes in other parts of the pipeline).
+    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 1);
+    vectorized_branch.push_back(
+        &output.node(graph_utils::FindGraphNodeWithOp(op, output)));
+  }
+
+  for (int i = 1; i < vectorized_branch.size() - 1; ++i) {
+    const NodeDef* node = vectorized_branch[i];
+    const NodeDef* next_node = vectorized_branch[i + 1];
+    ASSERT_EQ(next_node->input(0), node->name());
+  }
+  ASSERT_EQ(vectorized_branch[0]->input(0), input_name);
+
+  const NodeDef* vectorized_map_node = vectorized_branch[1];
+  string function_name =
+      vectorized_map_node->attr().at(kAttrNameF).func().name();
+
+  const FunctionDef* function = GetFunction(output, function_name);
+  ASSERT_NE(function, nullptr);
+  EXPECT_EQ(function->node_def(0).op(), "Identity");
+}
+
 // Checks that a graph has undergone the map_vectorization transformation
 // successfully, whereby the new graph has the shape:
 //
@@ -198,10 +240,15 @@ const FunctionDef* GetFunction(const GraphDef& graph,
 //          |
 //          +--> old map --> old batch
 //
-void CheckVectorized(const GraphDef& output,
-                     gtl::ArraySlice<string> expected_vectorized_branch,
+void CheckVectorizedWithChooseFastest(
+    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
    gtl::ArraySlice<string> expected_original_branch,
    const string& input_name) {
+  for (const auto& op : {kBatchOp, kBatchV2Op, kMapOp, kParallelMapOp,
+                         kExperimentalMapAndBatchOp}) {
+    // Check that the dataset nodes have been removed from the main graph.
+    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 0);
+  }
  ASSERT_EQ(
      graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
  const NodeDef& choose_fastest_node =
@@ -234,12 +281,13 @@ void CheckVectorized(const GraphDef& output,
}

class MapThenBatchTest
-    : public ::testing::TestWithParam<std::tuple<int, bool, int>> {};
+    : public ::testing::TestWithParam<std::tuple<int, bool, int, bool>> {};

TEST_P(MapThenBatchTest, IsVectorized) {
  int num_parallel_calls = std::get<0>(GetParam());
  bool use_batch_v2 = std::get<1>(GetParam());
  int prefetch = std::get<2>(GetParam());
+  bool use_choose_fastest = std::get<3>(GetParam());
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto range_dataset = AddRangeNode(&graph);
@@ -251,9 +299,8 @@ TEST_P(MapThenBatchTest, IsVectorized) {
    dataset = AddPrefetchNode(&graph, dataset->name(), prefetch);
  }
  dataset = AddBatchNode(&graph, dataset->name(), use_batch_v2);
-  MapVectorization optimizer;
  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));

  std::vector<string> expected_original_branch;
  expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
@@ -272,14 +319,24 @@ TEST_P(MapThenBatchTest, IsVectorized) {
    expected_vectorized_branch.push_back(kPrefetchOp);
  }

-  CheckVectorized(output, expected_vectorized_branch, expected_original_branch,
-                  range_dataset->name());
+  if (use_choose_fastest) {
+    CheckVectorizedWithChooseFastest(output, expected_vectorized_branch,
+                                     expected_original_branch,
+                                     range_dataset->name());
+  } else {
+    CheckVectorizedWithoutChooseFastest(output, expected_vectorized_branch,
+                                        range_dataset->name());
+  }
}

INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
                         ::testing::Combine(::testing::Values(0, 12),
                                            ::testing::Bool(),
-                                            ::testing::Values(0, 20)));
+                                            ::testing::Values(0, 20),
+                                            ::testing::Bool()));
+
+class MapAndBatchTest : public ::testing::TestWithParam<bool> {};

NodeDef* AddMapAndBatchNode(MutableGraphView* graph,
                            const string& input_dataset, const string& map_fn,
@@ -307,7 +364,7 @@ NodeDef* AddMapAndBatchNode(MutableGraphView* graph,
  return graph->AddNode(std::move(result));
}

-TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
+TEST_P(MapAndBatchTest, VectorizeExperimentalMapAndBatch) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto range_node = AddRangeNode(&graph);
@@ -316,16 +373,24 @@ TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
                                                  map_fn->signature().name());
  ASSERT_NE(map_and_batch_node, nullptr);

-  MapVectorization optimizer;
  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  bool use_choose_fastest = GetParam();

-  CheckVectorized(output, {kBatchV2Op, kParallelMapOp},
-                  {kExperimentalMapAndBatchOp}, range_node->name());
+  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
+  if (use_choose_fastest) {
+    CheckVectorizedWithChooseFastest(output, {kBatchV2Op, kParallelMapOp},
+                                     {kExperimentalMapAndBatchOp},
+                                     range_node->name());
+  } else {
+    CheckVectorizedWithoutChooseFastest(output, {kBatchV2Op, kParallelMapOp},
+                                        range_node->name());
+  }
}

+INSTANTIATE_TEST_SUITE_P(MapAndBatchTest, MapAndBatchTest, ::testing::Bool());
+
class ChainedMapAndBatchTest
-    : public ::testing::TestWithParam<std::tuple<bool, bool>> {};
+    : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {};

// Tests:
// 1) map.batch.map.batch
|||||||
|
|
||||||
bool fuse_0 = std::get<0>(GetParam());
|
bool fuse_0 = std::get<0>(GetParam());
|
||||||
bool fuse_1 = std::get<1>(GetParam());
|
bool fuse_1 = std::get<1>(GetParam());
|
||||||
|
bool use_choose_fastest = std::get<2>(GetParam());
|
||||||
auto map_and_batch_0 = make_map_and_batch(input_node, fuse_0);
|
auto map_and_batch_0 = make_map_and_batch(input_node, fuse_0);
|
||||||
auto map_and_batch_1 = make_map_and_batch(map_and_batch_0, fuse_1);
|
auto map_and_batch_1 = make_map_and_batch(map_and_batch_0, fuse_1);
|
||||||
ASSERT_NE(map_and_batch_1, nullptr);
|
ASSERT_NE(map_and_batch_1, nullptr);
|
||||||
|
|
||||||
MapVectorization optimizer;
|
|
||||||
GraphDef output;
|
GraphDef output;
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
|
||||||
TF_ASSERT_OK(TopologicalSort(&output));
|
TF_ASSERT_OK(TopologicalSort(&output));
|
||||||
|
|
||||||
|
if (use_choose_fastest) {
|
||||||
std::vector<int> choose_fastest_nodes =
|
std::vector<int> choose_fastest_nodes =
|
||||||
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output);
|
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output);
|
||||||
ASSERT_EQ(choose_fastest_nodes.size(), 2);
|
ASSERT_EQ(choose_fastest_nodes.size(), 2);
|
||||||
@@ -392,12 +458,35 @@ TEST_P(ChainedMapAndBatchTest, IsVectorized) {
      CheckBranch(*branch_1, original_ops);
    };

-  check_branches(choose_fastest_0, fuse_0 ? fused_sequence : unfused_sequence);
-  check_branches(choose_fastest_1, fuse_1 ? fused_sequence : unfused_sequence);
+    check_branches(choose_fastest_0,
+                   fuse_0 ? fused_sequence : unfused_sequence);
+    check_branches(choose_fastest_1,
+                   fuse_1 ? fused_sequence : unfused_sequence);
+  } else {
+    std::vector<int> map_nodes =
+        graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output);
+    std::vector<int> batch_nodes =
+        graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output);
+    ASSERT_EQ(map_nodes.size(), 2);
+    ASSERT_EQ(batch_nodes.size(), 2);
+
+    const NodeDef& range_node =
+        output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
+
+    const NodeDef& batch_node_0 = output.node(batch_nodes[0]);
+    EXPECT_EQ(batch_node_0.input(0), range_node.name());
+    const NodeDef& map_node_0 = output.node(map_nodes[0]);
+    EXPECT_EQ(map_node_0.input(0), batch_node_0.name());
+    const NodeDef& batch_node_1 = output.node(batch_nodes[1]);
+    EXPECT_EQ(batch_node_1.input(0), map_node_0.name());
+    const NodeDef& map_node_1 = output.node(map_nodes[1]);
+    EXPECT_EQ(map_node_1.input(0), batch_node_1.name());
+  }
}

INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
                         ::testing::Combine(::testing::Bool(),
+                                            ::testing::Bool(),
                                            ::testing::Bool()));

// Not all dataset types have "output_shapes" and "output_types"
@@ -434,9 +523,8 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShapes) {
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
-  MapVectorization optimizer;
  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}
@@ -454,9 +542,8 @@ TEST(MapVectorizationTest, VectorizeWithUnknownRank) {
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
-  MapVectorization optimizer;
  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}
@@ -474,9 +561,8 @@ TEST(MapVectorizationTest, VectorizeWithUnknownDim) {
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
-  MapVectorization optimizer;
  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}
@@ -493,10 +579,9 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
-  MapVectorization optimizer;
  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorized(
+  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
+  CheckVectorizedWithChooseFastest(
      output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
      /*expected_original_branch=*/{map_node->op(), batch_node->op()},
      input_node->name());
@@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"

+#include "absl/strings/str_split.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
@@ -29,6 +30,50 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

+namespace {
+
+using ConfigMap =
+    std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;
+
+// Parses a list of string optimizer configurations into a map from
+// optimizer name -> rewriter config for that optimizer.
+Status ToConfigMap(
+    const tensorflow::RewriterConfig_CustomGraphOptimizer* config,
+    ConfigMap* result) {
+  auto found = gtl::FindOrNull(config->parameter_map(), "optimizer_configs");
+  if (!found) return Status::OK();
+
+  auto& options = found->list().s();
+  for (const auto& option_string : options) {
+    // The option string has the format
+    // <optimizer_name>:<config_key>:<config_value>
+    std::vector<string> split = absl::StrSplit(option_string, ':');
+    if (split.size() != 3) {
+      return errors::Internal(
+          "Wrong format for optimizer options. Expect <optimizer name>:<config "
+          "key>:<config value>, received: ",
+          option_string);
+    }
+
+    const string& optimizer_name = split[0];
+    const string& config_key = split[1];
+    const string& config_value = split[2];
+
+    auto optimizer_config = gtl::FindOrNull(*result, optimizer_name);
+    if (!optimizer_config) {
+      (*result)[optimizer_name] =
+          tensorflow::RewriterConfig_CustomGraphOptimizer();
+      optimizer_config = gtl::FindOrNull(*result, optimizer_name);
+    }
+    (*optimizer_config->mutable_parameter_map())[config_key].set_s(
+        config_value);
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                     GraphDef* output) {
  // Stores the optimized item so far.
@@ -86,13 +131,16 @@ Status TFDataMetaOptimizer::Init(

  // Initialize custom tf.data optimizers based on config.
  auto& optimizers = config->parameter_map().at("optimizers").list().s();
+  ConfigMap optimizer_configs;
+  TF_RETURN_IF_ERROR(ToConfigMap(config, &optimizer_configs));
+
  for (const auto& optimizer_name : optimizers) {
    auto optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
    if (optimizer) {
-      // None of our data optimizers implement a meaningful Init function.
-      // This returns an error in case any of them does.
-      TF_RETURN_IF_ERROR(optimizer->Init());
+      TF_RETURN_IF_ERROR(
+          optimizer->Init(gtl::FindOrNull(optimizer_configs, optimizer_name)));
      enabled_optimizers_[optimizer_name] = std::move(optimizer);
    } else {
      // This should never happen.
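For reference, the "optimizer_configs" entries parsed by ToConfigMap above travel as plain strings in the shape <optimizer_name>:<config_key>:<config_value>. A minimal sketch of what the tf.data Python layer hands down for this commit's new knob (values taken from the optimization_options.py change later in this diff; the variable name here is illustrative only):

# Illustrative only: the list ultimately lands in the "optimizer_configs"
# parameter of the tf.data meta optimizer's rewriter config.
optimizer_configs = ["map_vectorization:use_choose_fastest:true"]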
@@ -36,6 +36,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
        graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("optimization_configs", &optimizer_configs_));
  }

 protected:
@@ -44,8 +46,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
    std::vector<string> optimizations;
    OP_REQUIRES_OK(
        ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
-    Dataset* dataset =
-        new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
+    Dataset* dataset = new Dataset(ctx, input, optimizations, output_types_,
+                                   output_shapes_, optimizer_configs_);
    Status s = dataset->Optimize(ctx);
    if (s.ok()) {
      *output = dataset;
@@ -61,9 +63,11 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const std::vector<string>& optimizations,
            const DataTypeVector& output_types,
-            const std::vector<PartialTensorShape>& output_shapes)
+            const std::vector<PartialTensorShape>& output_shapes,
+            const std::vector<string>& optimizer_configs)
        : GraphRewriteDataset(ctx, input, output_types, output_shapes),
-          optimizations_(optimizations) {}
+          optimizations_(optimizations),
+          optimizer_configs_(optimizer_configs) {}

    string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }

@@ -81,15 +85,23 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
      for (const auto& opt : optimizations_) {
        custom_optimizations_list->add_s(opt);
      }
+      auto* config_list =
+          (*custom_optimizer->mutable_parameter_map())["optimizer_configs"]
+              .mutable_list();
+      for (const auto& config : optimizer_configs_) {
+        config_list->add_s(config);
+      }
      return rewriter_config;
    }

    const std::vector<string> optimizations_;
+    const std::vector<string> optimizer_configs_;
  };

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
+  std::vector<string> optimizer_configs_;
};

REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
@@ -624,6 +624,7 @@ REGISTER_OP("OptimizeDataset")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("optimization_configs: list(string) = []")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("OptionalFromValue")
@@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@CheckpointInputPipelineHook
@@CsvDataset
@@DatasetStructure
+@@MapVectorizationOptions
@@NestedStructure
@@OptimizationOptions
@@Optional
@@ -102,6 +103,7 @@ from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_d
from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+from tensorflow.python.data.experimental.ops.optimization_options import MapVectorizationOptions
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
@@ -321,6 +321,13 @@ def _generate_optimization_test_cases():
@test_util.run_all_in_graph_and_eager_modes
class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):

+  def _enable_map_vectorization(self, dataset, use_choose=True):
+    options = dataset_ops.Options()
+    opt_options = options.experimental_optimization
+    opt_options.map_vectorization.enabled = True
+    opt_options.map_vectorization.use_choose_fastest = use_choose
+    return dataset.with_options(options)
+
  def _get_test_datasets(self,
                         base_dataset,
                         map_fn,
@@ -361,10 +368,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
    # to verify the optimization result.
    optimized = _make_dataset(["ChooseFastestBranch"]
                              if expect_optimized else [map_node_name, "Batch"])
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.map_vectorization = True
-    optimized = optimized.with_options(options)
+    optimized = self._enable_map_vectorization(optimized)
    return unoptimized, optimized

  @parameterized.named_parameters(_generate_optimization_test_cases())
@@ -404,16 +408,12 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testOptimizationWithMapAndBatchFusion(self):
    # Tests that vectorization works on fused map and batch.
-    y = constant_op.constant(1, shape=(2,))
-    z = constant_op.constant(2, shape=(2,))

    def map_fn(x):
-      return x, y, z
+      return x**2

+    base_dataset = dataset_ops.Dataset.range(1000)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
-    base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
-                                                           [3, 4]]).repeat(5)
    base_dataset = base_dataset.with_options(options)

    def _make_dataset(node_names):
@@ -423,9 +423,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):

    unoptimized = _make_dataset(["MapAndBatch"])
    optimized = _make_dataset(["ChooseFastestBranch"])
-    options = dataset_ops.Options()
-    options.experimental_optimization.map_vectorization = True
-    optimized = optimized.with_options(options)
+    optimized = self._enable_map_vectorization(optimized)
    self.assertDatasetsEqual(optimized, unoptimized)

  @parameterized.named_parameters(
@@ -474,10 +472,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):

    unoptimized = make_dataset(unoptimized_seq)
    optimized = make_dataset(["ChooseFastestBranch", "ChooseFastestBranch"])
-    options = dataset_ops.Options()
-    options.experimental_optimization.map_vectorization = True
-    optimized = optimized.with_options(options)
+    optimized = self._enable_map_vectorization(optimized)

    self.assertDatasetsEqual(optimized, unoptimized)

  def testOptimizationIgnoreStateful(self):
@@ -536,9 +531,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
    options.experimental_optimization.apply_default_optimizations = False
    unoptimized = unoptimized.with_options(options)

-    options = dataset_ops.Options()
-    options.experimental_optimization.map_vectorization = True
-    optimized = unoptimized.with_options(options)
+    optimized = self._enable_map_vectorization(unoptimized)
    self.assertDatasetsEqual(unoptimized, optimized)

  def testOptimizationWithSparseTensor(self):
@@ -554,10 +547,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    unoptimized = unoptimized.with_options(options)
-
-    options = dataset_ops.Options()
-    options.experimental_optimization.map_vectorization = True
-    optimized = unoptimized.with_options(options)
+    optimized = self._enable_map_vectorization(unoptimized)
    self.assertDatasetsEqual(unoptimized, optimized)

  def testOptimizationWithPrefetch(self):
@@ -565,11 +555,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
    dataset = dataset.map(lambda x: x)
    dataset = dataset.prefetch(1)
    dataset = dataset.batch(10)
-    options = dataset_ops.Options()
-    options.experimental_optimization.map_vectorization = True
-    dataset = dataset.with_options(options)
+    dataset = self._enable_map_vectorization(dataset)
    self.assertDatasetProduces(dataset, [list(range(10))])

+  def testOptimizationWithoutChooseFastest(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.map(lambda x: x**2)
+    dataset = dataset.batch(10)
+    dataset = self._enable_map_vectorization(dataset, use_choose=False)
+    self.assertDatasetProduces(dataset, [[x**2 for x in range(10)]])
+

if __name__ == "__main__":
  test.main()
@@ -17,11 +17,42 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export


+@tf_export("data.experimental.MapVectorizationOptions")
+class MapVectorizationOptions(options.OptionsBase):
+  """Represents options for the MapVectorization optimization."""
+  # TODO(rachelim): Other configuration parameters can go here, for example,
+  # how many "experiments" to run with ChooseFastestBranchDataset.
+  enabled = options.create_option(
+      name="enabled",
+      ty=bool,
+      docstring=
+      "Whether to vectorize map transformations. If None, defaults to False."
+  )
+
+  use_choose_fastest = options.create_option(
+      name="use_choose_fastest",
+      ty=bool,
+      docstring="Whether to use ChooseFastestBranchDataset with this "
+      "transformation. If True, the pipeline picks between the vectorized and "
+      "original segment at runtime based on their iterations speed. If None, "
+      "defaults to False.")
+
+  def _static_optimizations(self):
+    if self.enabled:
+      return ["map_vectorization"]
+    return []
+
+  def _static_optimization_configs(self):
+    if self.use_choose_fastest:
+      return ["map_vectorization:use_choose_fastest:true"]
+    else:
+      return ["map_vectorization:use_choose_fastest:false"]
+
+
@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options.OptionsBase):
  """Represents options for dataset optimizations.
@@ -102,9 +133,11 @@ class OptimizationOptions(options.OptionsBase):

  map_vectorization = options.create_option(
      name="map_vectorization",
-      ty=bool,
+      ty=MapVectorizationOptions,
      docstring=
-      "Whether to vectorize map transformations. If None, defaults to False.")
+      "The map vectorization options associated with the dataset. See "
+      "`tf.data.experimental.MapVectorizationOptions` for more details.",
+      default_factory=MapVectorizationOptions)

  noop_elimination = options.create_option(
      name="noop_elimination",
@@ -128,7 +161,6 @@ class OptimizationOptions(options.OptionsBase):
        "map_and_filter_fusion",
        "map_parallelization",
        "map_fusion",
-        "map_vectorization",
        "noop_elimination",
        "shuffle_and_repeat_fusion",
    ]
@@ -147,4 +179,12 @@ class OptimizationOptions(options.OptionsBase):
    for optimization in optimizations_to_disable:
      if getattr(self, optimization) is not False:
        result.add(optimization)
+
+    if self.map_vectorization is not None:
+      result.update(self.map_vectorization._static_optimizations())  # pylint: disable=protected-access
    return sorted(list(result))
+
+  def _static_optimization_configs(self):
+    if self.map_vectorization is not None:
+      return self.map_vectorization._static_optimization_configs()  # pylint: disable=protected-access
+    return []
@@ -191,7 +191,8 @@ class DatasetV2(object):
              "`tf.enable_resource_variables()` at the start of the program." %
              ", ".join(static_optimizations))
        else:
-          dataset = _OptimizeDataset(dataset, static_optimizations)
+          dataset = _OptimizeDataset(dataset, static_optimizations,
+                                     options._static_optimization_configs())  # pylint: disable=protected-access

      autotune = True
      cpu_budget = 0  # Indicates that all CPU cores should be used.
@@ -2009,6 +2010,10 @@ class Options(options_lib.OptionsBase):
      result.append("latency_all_edges")
    return result

+  def _static_optimization_configs(self):
+    """Produces the list of configurations for enabled static optimizations."""
+    return self.experimental_optimization._static_optimization_configs()  # pylint: disable=protected-access
+
  def merge(self, options):
    """Merges itself with the given `tf.data.Options`.

@@ -3295,15 +3300,18 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
class _OptimizeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and applies optimizations."""

-  def __init__(self, input_dataset, optimizations):
+  def __init__(self, input_dataset, optimizations, optimization_configs=None):
    self._input_dataset = input_dataset
    if optimizations is None:
      optimizations = []
+    if optimization_configs is None:
+      optimization_configs = []
    self._optimizations = ops.convert_to_tensor(
        optimizations, dtype=dtypes.string, name="optimizations")
    variant_tensor = gen_dataset_ops.optimize_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._optimizations,
+        optimization_configs=optimization_configs,
        **flat_structure(self))
    super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)

@@ -0,0 +1,18 @@
+path: "tensorflow.data.experimental.MapVectorizationOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.ops.optimization_options.MapVectorizationOptions\'>"
+  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "enabled"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "use_choose_fastest"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
@@ -20,6 +20,10 @@ tf_module {
    name: "INFINITE_CARDINALITY"
    mtype: "<type \'int\'>"
  }
+  member {
+    name: "MapVectorizationOptions"
+    mtype: "<type \'type\'>"
+  }
  member {
    name: "NestedStructure"
    mtype: "<type \'type\'>"
@@ -2182,7 +2182,7 @@ tf_module {
  }
  member_method {
    name: "OptimizeDataset"
-    argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
  }
  member_method {
    name: "OptionalFromValue"
@@ -0,0 +1,18 @@
+path: "tensorflow.data.experimental.MapVectorizationOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.data.experimental.ops.optimization_options.MapVectorizationOptions\'>"
+  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "enabled"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "use_choose_fastest"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
@@ -20,6 +20,10 @@ tf_module {
    name: "INFINITE_CARDINALITY"
    mtype: "<type \'int\'>"
  }
+  member {
+    name: "MapVectorizationOptions"
+    mtype: "<type \'type\'>"
+  }
  member {
    name: "NestedStructure"
    mtype: "<type \'type\'>"
@@ -2182,7 +2182,7 @@ tf_module {
  }
  member_method {
    name: "OptimizeDataset"
-    argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
  }
  member_method {
    name: "OptionalFromValue"