diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 5f807746ef3..49f36fc7c55 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -522,6 +522,7 @@ tf_cc_test(
     name = "map_vectorization_test",
     srcs = ["map_vectorization_test.cc"],
     deps = [
+        ":function_utils",
         ":graph_utils",
         ":map_vectorization",
         "//tensorflow/core:array_ops_op_lib",
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index bb71033a5dc..983b0436338 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -50,7 +50,7 @@ constexpr char kBatchV2Op[] = "BatchDatasetV2";
 constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
 constexpr char kMapOp[] = "MapDataset";
 constexpr char kParallelMapOp[] = "ParallelMapDataset";
-constexpr char kChooseFastestOp[] = "ExperimentalChooseFastestDataset";
+constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
 constexpr char kPrefetchOp[] = "PrefetchDataset";
 
 constexpr int kAutotune = -1;
@@ -317,23 +317,123 @@ Status AddNewPrefetchNode(const NodeDef& old_prefetch_node,
   return Status::OK();
 }
 
-Status AddNewChooseFastestNode(gtl::ArraySlice<NodeDef> input_nodes,
+Status AddBranch(gtl::ArraySlice<const NodeDef*> branch,
+                 NodeDef* choose_fastest_node, DataTypeVector* t_arguments,
+                 std::vector<NameAttrList>* branches,
+                 std::vector<int32>* other_arguments_lengths,
+                 FunctionDefLibrary* library) {
+  FunctionDef* branch_func = library->add_function();
+  auto* signature = branch_func->mutable_signature();
+  graph_utils::SetUniqueGraphFunctionName("branch", library, branch_func);
+
+  // Input dataset.
+  string prev_node_output = "args_0";
+  auto* input_arg_0 = signature->add_input_arg();
+  input_arg_0->set_name(prev_node_output);
+  input_arg_0->set_type(DT_VARIANT);
+
+  auto* output_arg = signature->add_output_arg();
+  output_arg->set_name("output");
+  output_arg->set_type(DT_VARIANT);
+
+  int32 captured_arg_lengths = 0;
+
+  // For each node in the branch, copy it to the function def. Add the
+  // corresponding non-0th inputs as captured arguments, modifying the function
+  // input signature, node input names, other_arguments_lengths, and t_arguments
+  // accordingly.
+  for (const NodeDef* node : branch) {
+    // Copy the node to the function
+    auto function_node = branch_func->add_node_def();
+    *function_node = *node;
+    function_utils::SetUniqueFunctionNodeName(node->name(), branch_func,
+                                              function_node);
+    function_node->clear_input();
+    function_node->add_input(prev_node_output);
+
+    // Every input besides the 0th (dataset) becomes a captured argument.
+    int input_size = node->input_size();
+    DataTypeVector input_types;
+    const OpDef* op_def;
+    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
+    TF_RETURN_IF_ERROR(InputTypesForNode(*node, *op_def, &input_types));
+    DCHECK_EQ(input_types.size(), input_size);
+
+    for (int i = 1; i < input_size; ++i) {
+      // Capture input in `other_arguments`
+      choose_fastest_node->add_input(node->input(i));
+      // Add type to function signature
+      auto* input_arg = signature->add_input_arg();
+
+      string input_arg_name = strings::StrCat(function_node->name(), "_", i);
+      input_arg->set_name(input_arg_name);
+      input_arg->set_type(input_types[i]);
+      function_node->add_input(input_arg_name);
+    }
+    // Add to `Targuments`
+    t_arguments->reserve(t_arguments->size() + input_types.size() - 1);
+    t_arguments->insert(t_arguments->end(), input_types.begin() + 1,
+                        input_types.end());
+    captured_arg_lengths += input_size - 1;
+    prev_node_output = strings::StrCat(function_node->name(), ":handle:0");
+  }
+
+  // Add to `other_arguments_lengths`
+  other_arguments_lengths->push_back(captured_arg_lengths);
+  (*branch_func->mutable_ret())["output"] = prev_node_output;
+
+  // Add to `branches`
+  NameAttrList func_attr;
+  func_attr.set_name(branch_func->signature().name());
+  branches->push_back(std::move(func_attr));
+  return Status::OK();
+}
+
+Status AddNewChooseFastestNode(const NodeDef* input_dataset_node,
+                               const string& ratio_numerator_name,
+                               std::vector<const NodeDef*> original_branch,
+                               std::vector<const NodeDef*> vectorized_branch,
                                MutableGraphView* graph,
+                               FunctionDefLibrary* library,
                                NodeDef** new_choose_fastest_node) {
   NodeDef choose_fastest_node;
   choose_fastest_node.set_op(kChooseFastestOp);
   graph_utils::SetUniqueGraphNodeName(choose_fastest_node.op(), graph->graph(),
                                       &choose_fastest_node);
-  // Set the `input_datasets` input argument.
-  for (const auto& node_def : input_nodes) {
-    choose_fastest_node.add_input(node_def.name());
-  }
-  AddNodeAttr("N", static_cast<int>(input_nodes.size()), &choose_fastest_node);
-  AddNodeAttr("num_experiments", 10, &choose_fastest_node);
+  // input_dataset
+  choose_fastest_node.add_input(input_dataset_node->name());
+  choose_fastest_node.add_input(ratio_numerator_name);
+  // ratio_denominator == 1
+  auto ratio_denominator =
+      graph_utils::AddScalarConstNode(static_cast<int64>(1), graph);
+  choose_fastest_node.add_input(ratio_denominator->name());
+
+  DataTypeVector t_arguments;
+  std::vector<NameAttrList> branches;
+  std::vector<int32> other_arguments_lengths;
+  // Branch 0: vectorized branch
+  TF_RETURN_IF_ERROR(AddBranch(vectorized_branch, &choose_fastest_node,
+                               &t_arguments, &branches,
+                               &other_arguments_lengths, library));
+  // Branch 1: original branch
+  TF_RETURN_IF_ERROR(AddBranch(original_branch, &choose_fastest_node,
+                               &t_arguments, &branches,
+                               &other_arguments_lengths, library));
+
+  DCHECK_EQ(t_arguments.size(), choose_fastest_node.input_size() - 3);
+  DCHECK_EQ(branches.size(), other_arguments_lengths.size());
+
+  AddNodeAttr("Targuments", t_arguments, &choose_fastest_node);
+  AddNodeAttr("num_elements_per_branch", 10, &choose_fastest_node);
+  AddNodeAttr("branches", branches, &choose_fastest_node);
+  AddNodeAttr("other_arguments_lengths", other_arguments_lengths,
+              &choose_fastest_node);
 
   for (auto key : {"output_shapes", "output_types"}) {
-    graph_utils::CopyAttribute(key, input_nodes[0], &choose_fastest_node);
+    graph_utils::CopyAttribute(key,
+                               *vectorized_branch[vectorized_branch.size() - 1],
+                               &choose_fastest_node);
   }
 
   *new_choose_fastest_node = graph->AddNode(std::move(choose_fastest_node));
@@ -434,14 +534,16 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
       AddVectorizedFunction(*map_node, *map_func, library);
   CHECK_NOTNULL(vectorized_func);
 
+  std::vector<const NodeDef*> vectorized_branch;
   NodeDef* new_batch_node;
   TF_RETURN_IF_ERROR(AddNewBatchNode(
       *batch_node, *input_node, *vectorized_func, &graph, &new_batch_node));
+  vectorized_branch.push_back(new_batch_node);
 
   NodeDef* new_map_node;
   TF_RETURN_IF_ERROR(AddNewMapNode(*map_node, *batch_node, *new_batch_node,
                                    *vectorized_func, &graph, &new_map_node));
-  NodeDef* final_node = new_map_node;
+  vectorized_branch.push_back(new_map_node);
 
   if (optional_prefetch_node) {
     // If the original pipeline was .map().prefetch().batch(), the new
@@ -450,13 +552,22 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
     TF_RETURN_IF_ERROR(AddNewPrefetchNode(*optional_prefetch_node, *batch_node,
                                           *new_map_node, &graph,
                                           &new_prefetch_node));
+    vectorized_branch.push_back(new_prefetch_node);
+  }
-    final_node = new_prefetch_node;
+
+  std::vector<const NodeDef*> original_branch({map_node});
+  if (optional_prefetch_node) {
+    original_branch.push_back(optional_prefetch_node);
+  }
+  if (map_node->op() != kExperimentalMapAndBatchOp) {
+    original_branch.push_back(batch_node);
   }
 
   NodeDef* new_choose_fastest_node;
   TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
-      {*final_node, *batch_node}, &graph, &new_choose_fastest_node));
+      input_node, /*ratio_numerator_name=*/new_batch_node->input(1),
+      std::move(original_branch), std::move(vectorized_branch), &graph,
+      library, &new_choose_fastest_node));
 
   // Make output of Batch point to ChooseFastest instead.
   TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index 606f5b992cd..884bc17d98f 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -38,7 +39,7 @@ constexpr char kBatchV2Op[] = "BatchDatasetV2";
 constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
 constexpr char kMapOp[] = "MapDataset";
 constexpr char kParallelMapOp[] = "ParallelMapDataset";
-constexpr char kChooseFastestOp[] = "ExperimentalChooseFastestDataset";
+constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
 constexpr char kPrefetchOp[] = "PrefetchDataset";
 constexpr char kAttrNameF[] = "f";
 constexpr char kAttrNameTarguments[] = "Targuments";
@@ -171,110 +172,65 @@ void CheckNotVectorized(const GraphDef& output, const string& map_op,
   EXPECT_EQ(batch_node.input(0), map_node.name());
 }
 
-void CheckBranch(const GraphDef& graph, string input_name,
-                 gtl::ArraySlice<string> ops, const string& terminal_input) {
+void CheckBranch(const FunctionDef& function, gtl::ArraySlice<string> ops) {
   for (int i = 0, size = ops.size(); i < size; ++i) {
-    const NodeDef& input_node =
-        graph.node(graph_utils::FindGraphNodeWithName(input_name, graph));
-    EXPECT_EQ(input_node.op(), ops[size - i - 1]);
-    input_name = input_node.input(0);
+    EXPECT_EQ(function.node_def(i).op(), ops[i]);
   }
-  EXPECT_EQ(input_name, terminal_input);
 }
 
-// Checks that a graph has undergone the map_vectorization transformation
-// successfully, whereby the new graph has the shape:
-//
-// input_node --> new batch --> new map -------+
-//     |                                       |
-//     |                                       v
-//     +-------> old map --> old batch ---> choose_fastest
-//
-void CheckVectorized(const GraphDef& output, const string& map_op,
-                     const string& batch_op, const string& map_input_name,
-                     bool fused = false, bool prefetch = false) {
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(map_op, output).size(), 2);
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(batch_op, output).size(), 2);
-  ASSERT_EQ(
-      graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
-  const NodeDef& choose_fastest_node =
-      output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
-
-  // Branch 0: vectorized
-  std::vector<string> vectorized_ops({batch_op, map_op});
-  std::vector<string> unvectorized_ops({map_op, batch_op});
-  if (prefetch) {
-    vectorized_ops.push_back(kPrefetchOp);
-    unvectorized_ops.insert(unvectorized_ops.begin() + 1, kPrefetchOp);
-  }
-  CheckBranch(output, choose_fastest_node.input(0), vectorized_ops,
-              map_input_name);
-
-  // Branch 1: original
-  CheckBranch(output, choose_fastest_node.input(1), unvectorized_ops,
-              map_input_name);
-
-  const NodeDef* vectorized_map_node = nullptr;
-  auto tmp_node = &output.node(
-      graph_utils::FindGraphNodeWithName(choose_fastest_node.input(0), output));
-  if (prefetch) {
-    vectorized_map_node = &output.node(
-        graph_utils::FindGraphNodeWithName(tmp_node->input(0), output));
-  } else {
-    vectorized_map_node = tmp_node;
-  }
-  // Check that the function is actually vectorized.
-  // The vectorization of the identity function is itself.
-  string function_name =
-      vectorized_map_node->attr().at(kAttrNameF).func().name();
+const FunctionDef* GetFunction(const GraphDef& graph,
+                               const string& function_name) {
   int found =
-      graph_utils::FindGraphFunctionWithName(function_name, output.library());
-  ASSERT_NE(found, -1);
-  const auto& function = output.library().function(found);
-  EXPECT_EQ(function.node_def(0).op(), "Identity");
+      graph_utils::FindGraphFunctionWithName(function_name, graph.library());
+  if (found == -1) {
+    return nullptr;
+  }
+  return &graph.library().function(found);
 }
 
 // Checks that a graph has undergone the map_vectorization transformation
 // successfully, whereby the new graph has the shape:
 //
-// input_node --> new batch -> new map --------+
-//     |                                       |
-//     |                                       v
-//     +-------> old map_and_batch ---> choose_fastest
+//   input_node -------------> choose_fastest --> ...
+//               |f0        |f1
+//               |          |
+//               |          +---> new batch --> new map
+//               |
+//               +--> old map --> old batch
 //
-void CheckVectorizedFused(const GraphDef& output,
-                          const string& map_input_name) {
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output).size(),
-            1);
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output).size(), 1);
-  ASSERT_EQ(
-      graph_utils::FindAllGraphNodesWithOp(kExperimentalMapAndBatchOp, output)
-          .size(),
-      1);
+void CheckVectorized(const GraphDef& output,
+                     gtl::ArraySlice<string> expected_vectorized_branch,
+                     gtl::ArraySlice<string> expected_original_branch,
+                     const string& input_name) {
   ASSERT_EQ(
       graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
   const NodeDef& choose_fastest_node =
       output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
+  ASSERT_EQ(choose_fastest_node.input(0), input_name);
+
+  const auto& functions_list = choose_fastest_node.attr().at("branches").list();
 
   // Branch 0: vectorized
-  CheckBranch(output, choose_fastest_node.input(0),
-              {kBatchV2Op, kParallelMapOp}, map_input_name);
+  const FunctionDef* branch_0 =
+      GetFunction(output, functions_list.func(0).name());
+  ASSERT_NE(branch_0, nullptr);
+  CheckBranch(*branch_0, expected_vectorized_branch);
 
   // Branch 1: original
-  CheckBranch(output, choose_fastest_node.input(1),
-              {kExperimentalMapAndBatchOp}, map_input_name);
+  const FunctionDef* branch_1 =
+      GetFunction(output, functions_list.func(1).name());
+  ASSERT_NE(branch_1, nullptr);
+  CheckBranch(*branch_1, expected_original_branch);
 
-  const NodeDef& vectorized_map_node = output.node(
-      graph_utils::FindGraphNodeWithName(choose_fastest_node.input(0), output));
-  // Check that the function is actually vectorized.
-  // The vectorization of the identity function is itself.
+  const NodeDef& vectorized_map_node =
+      branch_0->node_def(function_utils::FindFunctionNodeWithOp(
+          expected_vectorized_branch[1], *branch_0));
   string function_name =
       vectorized_map_node.attr().at(kAttrNameF).func().name();
-  int found =
-      graph_utils::FindGraphFunctionWithName(function_name, output.library());
-  ASSERT_NE(found, -1);
-  const auto& function = output.library().function(found);
-  EXPECT_EQ(function.node_def(0).op(), "Identity");
+
+  const FunctionDef* function = GetFunction(output, function_name);
+  ASSERT_NE(function, nullptr);
+  EXPECT_EQ(function->node_def(0).op(), "Identity");
 }
 
 class MapThenBatchTest
@@ -298,9 +254,26 @@ TEST_P(MapThenBatchTest, IsVectorized) {
   MapVectorization optimizer;
   GraphDef output;
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorized(output, num_parallel_calls > 0 ? kParallelMapOp : kMapOp,
-                  use_batch_v2 ? kBatchV2Op : kBatchOp, range_dataset->name(),
-                  /*fused=*/false, prefetch);
+
+  std::vector<string> expected_original_branch;
+  expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
+                                                            : kMapOp);
+  if (prefetch) {
+    expected_original_branch.push_back(kPrefetchOp);
+  }
+  expected_original_branch.push_back(use_batch_v2 > 0 ? kBatchV2Op : kBatchOp);
+
+  std::vector<string> expected_vectorized_branch;
+  expected_vectorized_branch.push_back(use_batch_v2 > 0 ? kBatchV2Op
+                                                        : kBatchOp);
+  expected_vectorized_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
+                                                              : kMapOp);
+  if (prefetch) {
+    expected_vectorized_branch.push_back(kPrefetchOp);
+  }
+
+  CheckVectorized(output, expected_vectorized_branch, expected_original_branch,
+                  range_dataset->name());
 }
 
 INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
@@ -346,15 +319,9 @@ TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
   MapVectorization optimizer;
   GraphDef output;
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorizedFused(output, "range");
-}
 
-void EvaluateNodes(const GraphDef& graph,
-                   const std::vector<string>& output_tensor_names,
-                   std::vector<Tensor>* output_tensors) {
-  std::unique_ptr<Session> session(NewSession(SessionOptions()));
-  TF_CHECK_OK(session->Create(graph));
-  TF_CHECK_OK(session->Run({}, output_tensor_names, {}, output_tensors));
+  CheckVectorized(output, {kBatchV2Op, kParallelMapOp},
+                  {kExperimentalMapAndBatchOp}, range_node->name());
 }
 
 class ChainedMapAndBatchTest
@@ -403,17 +370,30 @@ TEST_P(ChainedMapAndBatchTest, IsVectorized) {
   const NodeDef& range_node =
       output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
   const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
-  CheckBranch(output, choose_fastest_0.input(0), {kBatchV2Op, kParallelMapOp},
-              range_node.name());
-  CheckBranch(output, choose_fastest_0.input(1),
-              fuse_0 ? fused_sequence : unfused_sequence, range_node.name());
-
+  ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
   const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
-  CheckBranch(output, choose_fastest_1.input(0), {kBatchV2Op, kParallelMapOp},
-              choose_fastest_0.name());
-  CheckBranch(output, choose_fastest_1.input(1),
-              fuse_1 ? fused_sequence : unfused_sequence,
-              choose_fastest_0.name());
+  ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());
+
+  auto check_branches = [&output](const NodeDef& choose_fastest_node,
+                                  gtl::ArraySlice<string> original_ops) {
+    const auto& functions_list =
+        choose_fastest_node.attr().at("branches").list();
+
+    // Branch 0: vectorized
+    const FunctionDef* branch_0 =
+        GetFunction(output, functions_list.func(0).name());
+    ASSERT_NE(branch_0, nullptr);
+    CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});
+
+    // Branch 1: original
+    const FunctionDef* branch_1 =
+        GetFunction(output, functions_list.func(1).name());
+    ASSERT_NE(branch_1, nullptr);
+    CheckBranch(*branch_1, original_ops);
+  };
+
+  check_branches(choose_fastest_0, fuse_0 ? fused_sequence : unfused_sequence);
+  check_branches(choose_fastest_1, fuse_1 ? fused_sequence : unfused_sequence);
 }
 
 INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
@@ -516,7 +496,10 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
   MapVectorization optimizer;
   GraphDef output;
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorized(output, map_node->op(), batch_node->op(), input_node->name());
+  CheckVectorized(
+      output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
+      /*expected_original_branch=*/{map_node->op(), batch_node->op()},
+      input_node->name());
 }
 
 // TODO(rachelim): Add test that has a polymorphic function.
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 82cdbcf85fe..8dfcdc7e4b5 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -359,8 +359,8 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     unoptimized = _make_dataset([map_node_name, "Batch"])
     # Note that because of the `ChooseDataset` fork, we can't use `assert_next`
    # to verify the optimization result.
-    optimized = _make_dataset(
-        [] if expect_optimized else [map_node_name, "Batch"])
+    optimized = _make_dataset(["ChooseFastestBranch"]
+                              if expect_optimized else [map_node_name, "Batch"])
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.map_vectorization = True
@@ -422,7 +422,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
       return dataset
 
     unoptimized = _make_dataset(["MapAndBatch"])
-    optimized = _make_dataset([])
+    optimized = _make_dataset(["ChooseFastestBranch"])
     options = dataset_ops.Options()
     options.experimental_optimization.map_vectorization = True
     optimized = optimized.with_options(options)
@@ -473,7 +473,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
      return dataset
 
    unoptimized = make_dataset(unoptimized_seq)
-    optimized = make_dataset([])
+    optimized = make_dataset(["ChooseFastestBranch", "ChooseFastestBranch"])
     options = dataset_ops.Options()
     options.experimental_optimization.map_vectorization = True
     optimized = optimized.with_options(options)
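
For context, the Python tests above only toggle the optimization through tf.data options. Below is a minimal usage sketch of the same switch at the user level; it is not part of the change, and it assumes the public `tf.data.Options` surface that backs the `dataset_ops.Options()` calls and the `apply_default_optimizations` / `map_vectorization` fields used in the tests.

    import tensorflow as tf

    # Build a map-then-batch pipeline, the pattern that map_vectorization rewrites.
    dataset = tf.data.Dataset.range(1000)
    dataset = dataset.map(lambda x: 2 * x).batch(32)

    # Opt in to map_vectorization only, mirroring the test setup.
    options = tf.data.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.map_vectorization = True
    dataset = dataset.with_options(options)

With this change applied, the rewritten graph contains a ChooseFastestBranchDataset whose two branch functions correspond to the vectorized batch-then-map pipeline (branch 0) and the original map-then-batch pipeline (branch 1), rather than the previous ExperimentalChooseFastestDataset over two full dataset subgraphs.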