[tf.data] Use ChooseFastestBranch
dataset in MapVectorization optimization
PiperOrigin-RevId: 238452420
This commit is contained in:
parent
1020739e17
commit
aaa0ea6191
@ -522,6 +522,7 @@ tf_cc_test(
|
|||||||
name = "map_vectorization_test",
|
name = "map_vectorization_test",
|
||||||
srcs = ["map_vectorization_test.cc"],
|
srcs = ["map_vectorization_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":function_utils",
|
||||||
":graph_utils",
|
":graph_utils",
|
||||||
":map_vectorization",
|
":map_vectorization",
|
||||||
"//tensorflow/core:array_ops_op_lib",
|
"//tensorflow/core:array_ops_op_lib",
|
||||||
|
@ -50,7 +50,7 @@ constexpr char kBatchV2Op[] = "BatchDatasetV2";
|
|||||||
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
|
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
|
||||||
constexpr char kMapOp[] = "MapDataset";
|
constexpr char kMapOp[] = "MapDataset";
|
||||||
constexpr char kParallelMapOp[] = "ParallelMapDataset";
|
constexpr char kParallelMapOp[] = "ParallelMapDataset";
|
||||||
constexpr char kChooseFastestOp[] = "ExperimentalChooseFastestDataset";
|
constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
|
||||||
constexpr char kPrefetchOp[] = "PrefetchDataset";
|
constexpr char kPrefetchOp[] = "PrefetchDataset";
|
||||||
constexpr int kAutotune = -1;
|
constexpr int kAutotune = -1;
|
||||||
|
|
||||||
@ -317,23 +317,123 @@ Status AddNewPrefetchNode(const NodeDef& old_prefetch_node,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AddNewChooseFastestNode(gtl::ArraySlice<NodeDef> input_nodes,
|
Status AddBranch(gtl::ArraySlice<const NodeDef*> branch,
|
||||||
|
NodeDef* choose_fastest_node, DataTypeVector* t_arguments,
|
||||||
|
std::vector<NameAttrList>* branches,
|
||||||
|
std::vector<int>* other_arguments_lengths,
|
||||||
|
FunctionDefLibrary* library) {
|
||||||
|
FunctionDef* branch_func = library->add_function();
|
||||||
|
auto* signature = branch_func->mutable_signature();
|
||||||
|
graph_utils::SetUniqueGraphFunctionName("branch", library, branch_func);
|
||||||
|
|
||||||
|
// Input dataset.
|
||||||
|
string prev_node_output = "args_0";
|
||||||
|
auto* input_arg_0 = signature->add_input_arg();
|
||||||
|
input_arg_0->set_name(prev_node_output);
|
||||||
|
input_arg_0->set_type(DT_VARIANT);
|
||||||
|
|
||||||
|
auto* output_arg = signature->add_output_arg();
|
||||||
|
output_arg->set_name("output");
|
||||||
|
output_arg->set_type(DT_VARIANT);
|
||||||
|
|
||||||
|
int32 captured_arg_lengths = 0;
|
||||||
|
|
||||||
|
// For each node in the branch, copy it to the function def. Add the
|
||||||
|
// corresponding non-0th inputs as captured arguments, modifying the function
|
||||||
|
// input signature, node input names, other_arguments_lengths, and t_arguments
|
||||||
|
// accordingly.
|
||||||
|
for (const NodeDef* node : branch) {
|
||||||
|
// Copy the node to the function
|
||||||
|
auto function_node = branch_func->add_node_def();
|
||||||
|
*function_node = *node;
|
||||||
|
function_utils::SetUniqueFunctionNodeName(node->name(), branch_func,
|
||||||
|
function_node);
|
||||||
|
function_node->clear_input();
|
||||||
|
function_node->add_input(prev_node_output);
|
||||||
|
|
||||||
|
// Every input besides the 0th (dataset) becomes a captured argument.
|
||||||
|
int input_size = node->input_size();
|
||||||
|
DataTypeVector input_types;
|
||||||
|
const OpDef* op_def;
|
||||||
|
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
|
||||||
|
TF_RETURN_IF_ERROR(InputTypesForNode(*node, *op_def, &input_types));
|
||||||
|
DCHECK_EQ(input_types.size(), input_size);
|
||||||
|
|
||||||
|
for (int i = 1; i < input_size; ++i) {
|
||||||
|
// Capture input in `other_arguments`
|
||||||
|
choose_fastest_node->add_input(node->input(i));
|
||||||
|
// Add type to function signature
|
||||||
|
auto* input_arg = signature->add_input_arg();
|
||||||
|
|
||||||
|
string input_arg_name = strings::StrCat(function_node->name(), "_", i);
|
||||||
|
input_arg->set_name(input_arg_name);
|
||||||
|
input_arg->set_type(input_types[i]);
|
||||||
|
function_node->add_input(input_arg_name);
|
||||||
|
}
|
||||||
|
// Add to `Targuments`
|
||||||
|
t_arguments->reserve(t_arguments->size() + input_types.size() - 1);
|
||||||
|
t_arguments->insert(t_arguments->end(), input_types.begin() + 1,
|
||||||
|
input_types.end());
|
||||||
|
captured_arg_lengths += input_size - 1;
|
||||||
|
prev_node_output = strings::StrCat(function_node->name(), ":handle:0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to `other_arguments_lengths`
|
||||||
|
other_arguments_lengths->push_back(captured_arg_lengths);
|
||||||
|
(*branch_func->mutable_ret())["output"] = prev_node_output;
|
||||||
|
|
||||||
|
// Add to `branches`
|
||||||
|
NameAttrList func_attr;
|
||||||
|
func_attr.set_name(branch_func->signature().name());
|
||||||
|
branches->push_back(std::move(func_attr));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AddNewChooseFastestNode(const NodeDef* input_dataset_node,
|
||||||
|
const string& ratio_numerator_name,
|
||||||
|
std::vector<const NodeDef*> original_branch,
|
||||||
|
std::vector<const NodeDef*> vectorized_branch,
|
||||||
MutableGraphView* graph,
|
MutableGraphView* graph,
|
||||||
|
FunctionDefLibrary* library,
|
||||||
NodeDef** new_choose_fastest_node) {
|
NodeDef** new_choose_fastest_node) {
|
||||||
NodeDef choose_fastest_node;
|
NodeDef choose_fastest_node;
|
||||||
choose_fastest_node.set_op(kChooseFastestOp);
|
choose_fastest_node.set_op(kChooseFastestOp);
|
||||||
graph_utils::SetUniqueGraphNodeName(choose_fastest_node.op(), graph->graph(),
|
graph_utils::SetUniqueGraphNodeName(choose_fastest_node.op(), graph->graph(),
|
||||||
&choose_fastest_node);
|
&choose_fastest_node);
|
||||||
|
|
||||||
// Set the `input_datasets` input argument.
|
// input_dataset
|
||||||
for (const auto& node_def : input_nodes) {
|
choose_fastest_node.add_input(input_dataset_node->name());
|
||||||
choose_fastest_node.add_input(node_def.name());
|
choose_fastest_node.add_input(ratio_numerator_name);
|
||||||
}
|
// ratio_denominator == 1
|
||||||
AddNodeAttr("N", static_cast<int>(input_nodes.size()), &choose_fastest_node);
|
auto ratio_denominator =
|
||||||
AddNodeAttr("num_experiments", 10, &choose_fastest_node);
|
graph_utils::AddScalarConstNode(static_cast<int64>(1), graph);
|
||||||
|
choose_fastest_node.add_input(ratio_denominator->name());
|
||||||
|
|
||||||
|
DataTypeVector t_arguments;
|
||||||
|
std::vector<NameAttrList> branches;
|
||||||
|
std::vector<int32> other_arguments_lengths;
|
||||||
|
// Branch 0: vectorized branch
|
||||||
|
TF_RETURN_IF_ERROR(AddBranch(vectorized_branch, &choose_fastest_node,
|
||||||
|
&t_arguments, &branches,
|
||||||
|
&other_arguments_lengths, library));
|
||||||
|
// Branch 1: original branch
|
||||||
|
TF_RETURN_IF_ERROR(AddBranch(original_branch, &choose_fastest_node,
|
||||||
|
&t_arguments, &branches,
|
||||||
|
&other_arguments_lengths, library));
|
||||||
|
|
||||||
|
DCHECK_EQ(t_arguments.size(), choose_fastest_node.input_size() - 3);
|
||||||
|
DCHECK_EQ(branches.size(), other_arguments_lengths.size());
|
||||||
|
|
||||||
|
AddNodeAttr("Targuments", t_arguments, &choose_fastest_node);
|
||||||
|
AddNodeAttr("num_elements_per_branch", 10, &choose_fastest_node);
|
||||||
|
AddNodeAttr("branches", branches, &choose_fastest_node);
|
||||||
|
AddNodeAttr("other_arguments_lengths", other_arguments_lengths,
|
||||||
|
&choose_fastest_node);
|
||||||
|
|
||||||
for (auto key : {"output_shapes", "output_types"}) {
|
for (auto key : {"output_shapes", "output_types"}) {
|
||||||
graph_utils::CopyAttribute(key, input_nodes[0], &choose_fastest_node);
|
graph_utils::CopyAttribute(key,
|
||||||
|
*vectorized_branch[vectorized_branch.size() - 1],
|
||||||
|
&choose_fastest_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
*new_choose_fastest_node = graph->AddNode(std::move(choose_fastest_node));
|
*new_choose_fastest_node = graph->AddNode(std::move(choose_fastest_node));
|
||||||
@ -434,14 +534,16 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
|
|||||||
AddVectorizedFunction(*map_node, *map_func, library);
|
AddVectorizedFunction(*map_node, *map_func, library);
|
||||||
CHECK_NOTNULL(vectorized_func);
|
CHECK_NOTNULL(vectorized_func);
|
||||||
|
|
||||||
|
std::vector<const NodeDef*> vectorized_branch;
|
||||||
NodeDef* new_batch_node;
|
NodeDef* new_batch_node;
|
||||||
TF_RETURN_IF_ERROR(AddNewBatchNode(
|
TF_RETURN_IF_ERROR(AddNewBatchNode(
|
||||||
*batch_node, *input_node, *vectorized_func, &graph, &new_batch_node));
|
*batch_node, *input_node, *vectorized_func, &graph, &new_batch_node));
|
||||||
|
vectorized_branch.push_back(new_batch_node);
|
||||||
|
|
||||||
NodeDef* new_map_node;
|
NodeDef* new_map_node;
|
||||||
TF_RETURN_IF_ERROR(AddNewMapNode(*map_node, *batch_node, *new_batch_node,
|
TF_RETURN_IF_ERROR(AddNewMapNode(*map_node, *batch_node, *new_batch_node,
|
||||||
*vectorized_func, &graph, &new_map_node));
|
*vectorized_func, &graph, &new_map_node));
|
||||||
NodeDef* final_node = new_map_node;
|
vectorized_branch.push_back(new_map_node);
|
||||||
|
|
||||||
if (optional_prefetch_node) {
|
if (optional_prefetch_node) {
|
||||||
// If the original pipeline was .map().prefetch().batch(), the new
|
// If the original pipeline was .map().prefetch().batch(), the new
|
||||||
@ -450,13 +552,22 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
|
|||||||
TF_RETURN_IF_ERROR(AddNewPrefetchNode(*optional_prefetch_node,
|
TF_RETURN_IF_ERROR(AddNewPrefetchNode(*optional_prefetch_node,
|
||||||
*batch_node, *new_map_node, &graph,
|
*batch_node, *new_map_node, &graph,
|
||||||
&new_prefetch_node));
|
&new_prefetch_node));
|
||||||
|
vectorized_branch.push_back(new_prefetch_node);
|
||||||
|
}
|
||||||
|
|
||||||
final_node = new_prefetch_node;
|
std::vector<const NodeDef*> original_branch({map_node});
|
||||||
|
if (optional_prefetch_node) {
|
||||||
|
original_branch.push_back(optional_prefetch_node);
|
||||||
|
}
|
||||||
|
if (map_node->op() != kExperimentalMapAndBatchOp) {
|
||||||
|
original_branch.push_back(batch_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeDef* new_choose_fastest_node;
|
NodeDef* new_choose_fastest_node;
|
||||||
TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
|
TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
|
||||||
{*final_node, *batch_node}, &graph, &new_choose_fastest_node));
|
input_node, /*ratio_numerator_name=*/new_batch_node->input(1),
|
||||||
|
std::move(original_branch), std::move(vectorized_branch), &graph,
|
||||||
|
library, &new_choose_fastest_node));
|
||||||
|
|
||||||
// Make output of Batch point to ChooseFastest instead.
|
// Make output of Batch point to ChooseFastest instead.
|
||||||
TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
|
TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -38,7 +39,7 @@ constexpr char kBatchV2Op[] = "BatchDatasetV2";
|
|||||||
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
|
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
|
||||||
constexpr char kMapOp[] = "MapDataset";
|
constexpr char kMapOp[] = "MapDataset";
|
||||||
constexpr char kParallelMapOp[] = "ParallelMapDataset";
|
constexpr char kParallelMapOp[] = "ParallelMapDataset";
|
||||||
constexpr char kChooseFastestOp[] = "ExperimentalChooseFastestDataset";
|
constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
|
||||||
constexpr char kPrefetchOp[] = "PrefetchDataset";
|
constexpr char kPrefetchOp[] = "PrefetchDataset";
|
||||||
constexpr char kAttrNameF[] = "f";
|
constexpr char kAttrNameF[] = "f";
|
||||||
constexpr char kAttrNameTarguments[] = "Targuments";
|
constexpr char kAttrNameTarguments[] = "Targuments";
|
||||||
@ -171,110 +172,65 @@ void CheckNotVectorized(const GraphDef& output, const string& map_op,
|
|||||||
EXPECT_EQ(batch_node.input(0), map_node.name());
|
EXPECT_EQ(batch_node.input(0), map_node.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckBranch(const GraphDef& graph, string input_name,
|
void CheckBranch(const FunctionDef& function, gtl::ArraySlice<string> ops) {
|
||||||
gtl::ArraySlice<string> ops, const string& terminal_input) {
|
|
||||||
for (int i = 0, size = ops.size(); i < size; ++i) {
|
for (int i = 0, size = ops.size(); i < size; ++i) {
|
||||||
const NodeDef& input_node =
|
EXPECT_EQ(function.node_def(i).op(), ops[i]);
|
||||||
graph.node(graph_utils::FindGraphNodeWithName(input_name, graph));
|
|
||||||
EXPECT_EQ(input_node.op(), ops[size - i - 1]);
|
|
||||||
input_name = input_node.input(0);
|
|
||||||
}
|
}
|
||||||
EXPECT_EQ(input_name, terminal_input);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks that a graph has undergone the map_vectorization transformation
|
const FunctionDef* GetFunction(const GraphDef& graph,
|
||||||
// successfully, whereby the new graph has the shape:
|
const string& function_name) {
|
||||||
//
|
|
||||||
// input_node --> new batch --> new map -------+
|
|
||||||
// | |
|
|
||||||
// | v
|
|
||||||
// +-------> old map --> old batch ---> choose_fastest
|
|
||||||
//
|
|
||||||
void CheckVectorized(const GraphDef& output, const string& map_op,
|
|
||||||
const string& batch_op, const string& map_input_name,
|
|
||||||
bool fused = false, bool prefetch = false) {
|
|
||||||
ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(map_op, output).size(), 2);
|
|
||||||
ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(batch_op, output).size(), 2);
|
|
||||||
ASSERT_EQ(
|
|
||||||
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
|
|
||||||
const NodeDef& choose_fastest_node =
|
|
||||||
output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
|
|
||||||
|
|
||||||
// Branch 0: vectorized
|
|
||||||
std::vector<string> vectorized_ops({batch_op, map_op});
|
|
||||||
std::vector<string> unvectorized_ops({map_op, batch_op});
|
|
||||||
if (prefetch) {
|
|
||||||
vectorized_ops.push_back(kPrefetchOp);
|
|
||||||
unvectorized_ops.insert(unvectorized_ops.begin() + 1, kPrefetchOp);
|
|
||||||
}
|
|
||||||
CheckBranch(output, choose_fastest_node.input(0), vectorized_ops,
|
|
||||||
map_input_name);
|
|
||||||
|
|
||||||
// Branch 1: original
|
|
||||||
CheckBranch(output, choose_fastest_node.input(1), unvectorized_ops,
|
|
||||||
map_input_name);
|
|
||||||
|
|
||||||
const NodeDef* vectorized_map_node = nullptr;
|
|
||||||
auto tmp_node = &output.node(
|
|
||||||
graph_utils::FindGraphNodeWithName(choose_fastest_node.input(0), output));
|
|
||||||
if (prefetch) {
|
|
||||||
vectorized_map_node = &output.node(
|
|
||||||
graph_utils::FindGraphNodeWithName(tmp_node->input(0), output));
|
|
||||||
} else {
|
|
||||||
vectorized_map_node = tmp_node;
|
|
||||||
}
|
|
||||||
// Check that the function is actually vectorized.
|
|
||||||
// The vectorization of the identity function is itself.
|
|
||||||
string function_name =
|
|
||||||
vectorized_map_node->attr().at(kAttrNameF).func().name();
|
|
||||||
int found =
|
int found =
|
||||||
graph_utils::FindGraphFunctionWithName(function_name, output.library());
|
graph_utils::FindGraphFunctionWithName(function_name, graph.library());
|
||||||
ASSERT_NE(found, -1);
|
if (found == -1) {
|
||||||
const auto& function = output.library().function(found);
|
return nullptr;
|
||||||
EXPECT_EQ(function.node_def(0).op(), "Identity");
|
}
|
||||||
|
return &graph.library().function(found);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks that a graph has undergone the map_vectorization transformation
|
// Checks that a graph has undergone the map_vectorization transformation
|
||||||
// successfully, whereby the new graph has the shape:
|
// successfully, whereby the new graph has the shape:
|
||||||
//
|
//
|
||||||
// input_node --> new batch -> new map --------+
|
// input_node -------------> choose_fastest --> ...
|
||||||
// | |
|
// |f0 |f1
|
||||||
// | v
|
// | |
|
||||||
// +-------> old map_and_batch ---> choose_fastest
|
// | +---> new batch --> new map
|
||||||
|
// |
|
||||||
|
// +--> old map --> old batch
|
||||||
//
|
//
|
||||||
void CheckVectorizedFused(const GraphDef& output,
|
void CheckVectorized(const GraphDef& output,
|
||||||
const string& map_input_name) {
|
gtl::ArraySlice<string> expected_vectorized_branch,
|
||||||
ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output).size(),
|
gtl::ArraySlice<string> expected_original_branch,
|
||||||
1);
|
const string& input_name) {
|
||||||
ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output).size(), 1);
|
|
||||||
ASSERT_EQ(
|
|
||||||
graph_utils::FindAllGraphNodesWithOp(kExperimentalMapAndBatchOp, output)
|
|
||||||
.size(),
|
|
||||||
1);
|
|
||||||
ASSERT_EQ(
|
ASSERT_EQ(
|
||||||
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
|
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
|
||||||
const NodeDef& choose_fastest_node =
|
const NodeDef& choose_fastest_node =
|
||||||
output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
|
output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
|
||||||
|
ASSERT_EQ(choose_fastest_node.input(0), input_name);
|
||||||
|
|
||||||
|
const auto& functions_list = choose_fastest_node.attr().at("branches").list();
|
||||||
|
|
||||||
// Branch 0: vectorized
|
// Branch 0: vectorized
|
||||||
CheckBranch(output, choose_fastest_node.input(0),
|
const FunctionDef* branch_0 =
|
||||||
{kBatchV2Op, kParallelMapOp}, map_input_name);
|
GetFunction(output, functions_list.func(0).name());
|
||||||
|
ASSERT_NE(branch_0, nullptr);
|
||||||
|
CheckBranch(*branch_0, expected_vectorized_branch);
|
||||||
|
|
||||||
// Branch 1: original
|
// Branch 1: original
|
||||||
CheckBranch(output, choose_fastest_node.input(1),
|
const FunctionDef* branch_1 =
|
||||||
{kExperimentalMapAndBatchOp}, map_input_name);
|
GetFunction(output, functions_list.func(1).name());
|
||||||
|
ASSERT_NE(branch_1, nullptr);
|
||||||
|
CheckBranch(*branch_1, expected_original_branch);
|
||||||
|
|
||||||
const NodeDef& vectorized_map_node = output.node(
|
const NodeDef& vectorized_map_node =
|
||||||
graph_utils::FindGraphNodeWithName(choose_fastest_node.input(0), output));
|
branch_0->node_def(function_utils::FindFunctionNodeWithOp(
|
||||||
// Check that the function is actually vectorized.
|
expected_vectorized_branch[1], *branch_0));
|
||||||
// The vectorization of the identity function is itself.
|
|
||||||
string function_name =
|
string function_name =
|
||||||
vectorized_map_node.attr().at(kAttrNameF).func().name();
|
vectorized_map_node.attr().at(kAttrNameF).func().name();
|
||||||
int found =
|
|
||||||
graph_utils::FindGraphFunctionWithName(function_name, output.library());
|
const FunctionDef* function = GetFunction(output, function_name);
|
||||||
ASSERT_NE(found, -1);
|
ASSERT_NE(function, nullptr);
|
||||||
const auto& function = output.library().function(found);
|
EXPECT_EQ(function->node_def(0).op(), "Identity");
|
||||||
EXPECT_EQ(function.node_def(0).op(), "Identity");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class MapThenBatchTest
|
class MapThenBatchTest
|
||||||
@ -298,9 +254,26 @@ TEST_P(MapThenBatchTest, IsVectorized) {
|
|||||||
MapVectorization optimizer;
|
MapVectorization optimizer;
|
||||||
GraphDef output;
|
GraphDef output;
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
CheckVectorized(output, num_parallel_calls > 0 ? kParallelMapOp : kMapOp,
|
|
||||||
use_batch_v2 ? kBatchV2Op : kBatchOp, range_dataset->name(),
|
std::vector<string> expected_original_branch;
|
||||||
/*fused=*/false, prefetch);
|
expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
|
||||||
|
: kMapOp);
|
||||||
|
if (prefetch) {
|
||||||
|
expected_original_branch.push_back(kPrefetchOp);
|
||||||
|
}
|
||||||
|
expected_original_branch.push_back(use_batch_v2 > 0 ? kBatchV2Op : kBatchOp);
|
||||||
|
|
||||||
|
std::vector<string> expected_vectorized_branch;
|
||||||
|
expected_vectorized_branch.push_back(use_batch_v2 > 0 ? kBatchV2Op
|
||||||
|
: kBatchOp);
|
||||||
|
expected_vectorized_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
|
||||||
|
: kMapOp);
|
||||||
|
if (prefetch) {
|
||||||
|
expected_vectorized_branch.push_back(kPrefetchOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
CheckVectorized(output, expected_vectorized_branch, expected_original_branch,
|
||||||
|
range_dataset->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
|
INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
|
||||||
@ -346,15 +319,9 @@ TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
|
|||||||
MapVectorization optimizer;
|
MapVectorization optimizer;
|
||||||
GraphDef output;
|
GraphDef output;
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
CheckVectorizedFused(output, "range");
|
|
||||||
}
|
|
||||||
|
|
||||||
void EvaluateNodes(const GraphDef& graph,
|
CheckVectorized(output, {kBatchV2Op, kParallelMapOp},
|
||||||
const std::vector<string>& output_tensor_names,
|
{kExperimentalMapAndBatchOp}, range_node->name());
|
||||||
std::vector<Tensor>* output_tensors) {
|
|
||||||
std::unique_ptr<Session> session(NewSession(SessionOptions()));
|
|
||||||
TF_CHECK_OK(session->Create(graph));
|
|
||||||
TF_CHECK_OK(session->Run({}, output_tensor_names, {}, output_tensors));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChainedMapAndBatchTest
|
class ChainedMapAndBatchTest
|
||||||
@ -403,17 +370,30 @@ TEST_P(ChainedMapAndBatchTest, IsVectorized) {
|
|||||||
const NodeDef& range_node =
|
const NodeDef& range_node =
|
||||||
output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
|
output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
|
||||||
const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
|
const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
|
||||||
CheckBranch(output, choose_fastest_0.input(0), {kBatchV2Op, kParallelMapOp},
|
ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
|
||||||
range_node.name());
|
|
||||||
CheckBranch(output, choose_fastest_0.input(1),
|
|
||||||
fuse_0 ? fused_sequence : unfused_sequence, range_node.name());
|
|
||||||
|
|
||||||
const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
|
const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
|
||||||
CheckBranch(output, choose_fastest_1.input(0), {kBatchV2Op, kParallelMapOp},
|
ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());
|
||||||
choose_fastest_0.name());
|
|
||||||
CheckBranch(output, choose_fastest_1.input(1),
|
auto check_branches = [&output](const NodeDef& choose_fastest_node,
|
||||||
fuse_1 ? fused_sequence : unfused_sequence,
|
gtl::ArraySlice<string> original_ops) {
|
||||||
choose_fastest_0.name());
|
const auto& functions_list =
|
||||||
|
choose_fastest_node.attr().at("branches").list();
|
||||||
|
|
||||||
|
// Branch 0: vectorized
|
||||||
|
const FunctionDef* branch_0 =
|
||||||
|
GetFunction(output, functions_list.func(0).name());
|
||||||
|
ASSERT_NE(branch_0, nullptr);
|
||||||
|
CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});
|
||||||
|
|
||||||
|
// Branch 1: original
|
||||||
|
const FunctionDef* branch_1 =
|
||||||
|
GetFunction(output, functions_list.func(1).name());
|
||||||
|
ASSERT_NE(branch_1, nullptr);
|
||||||
|
CheckBranch(*branch_1, original_ops);
|
||||||
|
};
|
||||||
|
|
||||||
|
check_branches(choose_fastest_0, fuse_0 ? fused_sequence : unfused_sequence);
|
||||||
|
check_branches(choose_fastest_1, fuse_1 ? fused_sequence : unfused_sequence);
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
|
INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
|
||||||
@ -516,7 +496,10 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
|
|||||||
MapVectorization optimizer;
|
MapVectorization optimizer;
|
||||||
GraphDef output;
|
GraphDef output;
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
CheckVectorized(output, map_node->op(), batch_node->op(), input_node->name());
|
CheckVectorized(
|
||||||
|
output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
|
||||||
|
/*expected_original_branch=*/{map_node->op(), batch_node->op()},
|
||||||
|
input_node->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(rachelim): Add test that has a polymorphic function.
|
// TODO(rachelim): Add test that has a polymorphic function.
|
||||||
|
@ -359,8 +359,8 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
unoptimized = _make_dataset([map_node_name, "Batch"])
|
unoptimized = _make_dataset([map_node_name, "Batch"])
|
||||||
# Note that because of the `ChooseDataset` fork, we can't use `assert_next`
|
# Note that because of the `ChooseDataset` fork, we can't use `assert_next`
|
||||||
# to verify the optimization result.
|
# to verify the optimization result.
|
||||||
optimized = _make_dataset(
|
optimized = _make_dataset(["ChooseFastestBranch"]
|
||||||
[] if expect_optimized else [map_node_name, "Batch"])
|
if expect_optimized else [map_node_name, "Batch"])
|
||||||
options = dataset_ops.Options()
|
options = dataset_ops.Options()
|
||||||
options.experimental_optimization.apply_default_optimizations = False
|
options.experimental_optimization.apply_default_optimizations = False
|
||||||
options.experimental_optimization.map_vectorization = True
|
options.experimental_optimization.map_vectorization = True
|
||||||
@ -422,7 +422,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
unoptimized = _make_dataset(["MapAndBatch"])
|
unoptimized = _make_dataset(["MapAndBatch"])
|
||||||
optimized = _make_dataset([])
|
optimized = _make_dataset(["ChooseFastestBranch"])
|
||||||
options = dataset_ops.Options()
|
options = dataset_ops.Options()
|
||||||
options.experimental_optimization.map_vectorization = True
|
options.experimental_optimization.map_vectorization = True
|
||||||
optimized = optimized.with_options(options)
|
optimized = optimized.with_options(options)
|
||||||
@ -473,7 +473,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
unoptimized = make_dataset(unoptimized_seq)
|
unoptimized = make_dataset(unoptimized_seq)
|
||||||
optimized = make_dataset([])
|
optimized = make_dataset(["ChooseFastestBranch", "ChooseFastestBranch"])
|
||||||
options = dataset_ops.Options()
|
options = dataset_ops.Options()
|
||||||
options.experimental_optimization.map_vectorization = True
|
options.experimental_optimization.map_vectorization = True
|
||||||
optimized = optimized.with_options(options)
|
optimized = optimized.with_options(options)
|
||||||
|
Loading…
Reference in New Issue
Block a user