[tf.data] Use ChooseFastestBranch dataset in MapVectorization optimization

PiperOrigin-RevId: 238452420
Rachel Lim 2019-03-14 09:00:40 -07:00 committed by TensorFlower Gardener
parent 1020739e17
commit aaa0ea6191
4 changed files with 216 additions and 121 deletions
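For context, map vectorization is opted into through tf.data options, which is also how the Python test changes at the bottom of this diff exercise it. A minimal sketch of such a pipeline (the dataset, map function, and batch size here are hypothetical, chosen only for illustration):

import tensorflow as tf

# A map-then-batch pipeline of the shape the optimization targets.
dataset = tf.data.Dataset.range(100).map(lambda x: x * 2).batch(10)

# Opt in to map vectorization. After this commit, the rewritten graph uses a
# ChooseFastestBranchDataset node instead of wiring two whole pipelines into
# an ExperimentalChooseFastestDataset.
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_vectorization = True
dataset = dataset.with_options(options)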

File: tensorflow/core/grappler/optimizers/data/BUILD

@@ -522,6 +522,7 @@ tf_cc_test(
     name = "map_vectorization_test",
     srcs = ["map_vectorization_test.cc"],
     deps = [
+        ":function_utils",
         ":graph_utils",
         ":map_vectorization",
         "//tensorflow/core:array_ops_op_lib",

File: tensorflow/core/grappler/optimizers/data/map_vectorization.cc

@@ -50,7 +50,7 @@ constexpr char kBatchV2Op[] = "BatchDatasetV2";
 constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
 constexpr char kMapOp[] = "MapDataset";
 constexpr char kParallelMapOp[] = "ParallelMapDataset";
-constexpr char kChooseFastestOp[] = "ExperimentalChooseFastestDataset";
+constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
 constexpr char kPrefetchOp[] = "PrefetchDataset";
 constexpr int kAutotune = -1;
@@ -317,23 +317,123 @@ Status AddNewPrefetchNode(const NodeDef& old_prefetch_node,
   return Status::OK();
 }
 
-Status AddNewChooseFastestNode(gtl::ArraySlice<NodeDef> input_nodes,
+Status AddBranch(gtl::ArraySlice<const NodeDef*> branch,
+                 NodeDef* choose_fastest_node, DataTypeVector* t_arguments,
+                 std::vector<NameAttrList>* branches,
+                 std::vector<int>* other_arguments_lengths,
+                 FunctionDefLibrary* library) {
+  FunctionDef* branch_func = library->add_function();
+  auto* signature = branch_func->mutable_signature();
+  graph_utils::SetUniqueGraphFunctionName("branch", library, branch_func);
+
+  // Input dataset.
+  string prev_node_output = "args_0";
+  auto* input_arg_0 = signature->add_input_arg();
+  input_arg_0->set_name(prev_node_output);
+  input_arg_0->set_type(DT_VARIANT);
+
+  auto* output_arg = signature->add_output_arg();
+  output_arg->set_name("output");
+  output_arg->set_type(DT_VARIANT);
+
+  int32 captured_arg_lengths = 0;
+
+  // For each node in the branch, copy it to the function def. Add the
+  // corresponding non-0th inputs as captured arguments, modifying the function
+  // input signature, node input names, other_arguments_lengths, and t_arguments
+  // accordingly.
+  for (const NodeDef* node : branch) {
+    // Copy the node to the function
+    auto function_node = branch_func->add_node_def();
+    *function_node = *node;
+    function_utils::SetUniqueFunctionNodeName(node->name(), branch_func,
+                                              function_node);
+    function_node->clear_input();
+    function_node->add_input(prev_node_output);
+
+    // Every input besides the 0th (dataset) becomes a captured argument.
+    int input_size = node->input_size();
+    DataTypeVector input_types;
+    const OpDef* op_def;
+    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
+    TF_RETURN_IF_ERROR(InputTypesForNode(*node, *op_def, &input_types));
+    DCHECK_EQ(input_types.size(), input_size);
+
+    for (int i = 1; i < input_size; ++i) {
+      // Capture input in `other_arguments`
+      choose_fastest_node->add_input(node->input(i));
+
+      // Add type to function signature
+      auto* input_arg = signature->add_input_arg();
+      string input_arg_name = strings::StrCat(function_node->name(), "_", i);
+      input_arg->set_name(input_arg_name);
+      input_arg->set_type(input_types[i]);
+      function_node->add_input(input_arg_name);
+    }
+
+    // Add to `Targuments`
+    t_arguments->reserve(t_arguments->size() + input_types.size() - 1);
+    t_arguments->insert(t_arguments->end(), input_types.begin() + 1,
+                        input_types.end());
+
+    captured_arg_lengths += input_size - 1;
+    prev_node_output = strings::StrCat(function_node->name(), ":handle:0");
+  }
+
+  // Add to `other_arguments_lengths`
+  other_arguments_lengths->push_back(captured_arg_lengths);
+  (*branch_func->mutable_ret())["output"] = prev_node_output;
+
+  // Add to `branches`
+  NameAttrList func_attr;
+  func_attr.set_name(branch_func->signature().name());
+  branches->push_back(std::move(func_attr));
+
+  return Status::OK();
+}
+
+Status AddNewChooseFastestNode(const NodeDef* input_dataset_node,
+                               const string& ratio_numerator_name,
+                               std::vector<const NodeDef*> original_branch,
+                               std::vector<const NodeDef*> vectorized_branch,
                                MutableGraphView* graph,
+                               FunctionDefLibrary* library,
                                NodeDef** new_choose_fastest_node) {
   NodeDef choose_fastest_node;
   choose_fastest_node.set_op(kChooseFastestOp);
   graph_utils::SetUniqueGraphNodeName(choose_fastest_node.op(), graph->graph(),
                                       &choose_fastest_node);
 
-  // Set the `input_datasets` input argument.
-  for (const auto& node_def : input_nodes) {
-    choose_fastest_node.add_input(node_def.name());
-  }
-  AddNodeAttr("N", static_cast<int>(input_nodes.size()), &choose_fastest_node);
-  AddNodeAttr("num_experiments", 10, &choose_fastest_node);
+  // input_dataset
+  choose_fastest_node.add_input(input_dataset_node->name());
+  choose_fastest_node.add_input(ratio_numerator_name);
+  // ratio_denominator == 1
+  auto ratio_denominator =
+      graph_utils::AddScalarConstNode(static_cast<int64>(1), graph);
+  choose_fastest_node.add_input(ratio_denominator->name());
+
+  DataTypeVector t_arguments;
+  std::vector<NameAttrList> branches;
+  std::vector<int32> other_arguments_lengths;
+
+  // Branch 0: vectorized branch
+  TF_RETURN_IF_ERROR(AddBranch(vectorized_branch, &choose_fastest_node,
+                               &t_arguments, &branches,
+                               &other_arguments_lengths, library));
+  // Branch 1: original branch
+  TF_RETURN_IF_ERROR(AddBranch(original_branch, &choose_fastest_node,
+                               &t_arguments, &branches,
+                               &other_arguments_lengths, library));
+
+  DCHECK_EQ(t_arguments.size(), choose_fastest_node.input_size() - 3);
+  DCHECK_EQ(branches.size(), other_arguments_lengths.size());
+
+  AddNodeAttr("Targuments", t_arguments, &choose_fastest_node);
+  AddNodeAttr("num_elements_per_branch", 10, &choose_fastest_node);
+  AddNodeAttr("branches", branches, &choose_fastest_node);
+  AddNodeAttr("other_arguments_lengths", other_arguments_lengths,
+              &choose_fastest_node);
 
   for (auto key : {"output_shapes", "output_types"}) {
-    graph_utils::CopyAttribute(key, input_nodes[0], &choose_fastest_node);
+    graph_utils::CopyAttribute(key,
+                               *vectorized_branch[vectorized_branch.size() - 1],
+                               &choose_fastest_node);
   }
 
   *new_choose_fastest_node = graph->AddNode(std::move(choose_fastest_node));
@@ -434,14 +534,16 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
         AddVectorizedFunction(*map_node, *map_func, library);
     CHECK_NOTNULL(vectorized_func);
 
+    std::vector<const NodeDef*> vectorized_branch;
+
     NodeDef* new_batch_node;
     TF_RETURN_IF_ERROR(AddNewBatchNode(
         *batch_node, *input_node, *vectorized_func, &graph, &new_batch_node));
+    vectorized_branch.push_back(new_batch_node);
 
     NodeDef* new_map_node;
     TF_RETURN_IF_ERROR(AddNewMapNode(*map_node, *batch_node, *new_batch_node,
                                      *vectorized_func, &graph, &new_map_node));
-    NodeDef* final_node = new_map_node;
+    vectorized_branch.push_back(new_map_node);
 
     if (optional_prefetch_node) {
       // If the original pipeline was .map().prefetch().batch(), the new
@@ -450,13 +552,22 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
       TF_RETURN_IF_ERROR(AddNewPrefetchNode(*optional_prefetch_node,
                                             *batch_node, *new_map_node, &graph,
                                             &new_prefetch_node));
-      final_node = new_prefetch_node;
+      vectorized_branch.push_back(new_prefetch_node);
+    }
+
+    std::vector<const NodeDef*> original_branch({map_node});
+    if (optional_prefetch_node) {
+      original_branch.push_back(optional_prefetch_node);
+    }
+    if (map_node->op() != kExperimentalMapAndBatchOp) {
+      original_branch.push_back(batch_node);
     }
 
     NodeDef* new_choose_fastest_node;
     TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
-        {*final_node, *batch_node}, &graph, &new_choose_fastest_node));
+        input_node, /*ratio_numerator_name=*/new_batch_node->input(1),
+        std::move(original_branch), std::move(vectorized_branch), &graph,
+        library, &new_choose_fastest_node));
 
     // Make output of Batch point to ChooseFastest instead.
     TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
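To make the two branches concrete: branch 0 is the vectorized pipeline (batch first, then the vectorized map function, plus the prefetch if one was present), and branch 1 replays the original per-element pipeline. Roughly, in user-level tf.data terms (the map function and batch size below are made up purely for illustration; the real rewrite operates on the GraphDef, not on Python datasets):

import tensorflow as tf

def fn(x):
  return x * 2  # stand-in for the user's per-element map function

source = tf.data.Dataset.range(100)

# Branch 0 (vectorized): batch first, then apply the vectorized function to
# whole batches at once.
vectorized = source.batch(10).map(fn)

# Branch 1 (original): apply the function per element, then batch.
original = source.map(fn).batch(10)

The ChooseFastestBranchDataset op measures both branch functions over the same input and keeps the faster one, which is why this change expresses each branch as a FunctionDef applied to the shared input dataset rather than wiring two separate pipelines into the graph.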

File: tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -38,7 +39,7 @@ constexpr char kBatchV2Op[] = "BatchDatasetV2";
 constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
 constexpr char kMapOp[] = "MapDataset";
 constexpr char kParallelMapOp[] = "ParallelMapDataset";
-constexpr char kChooseFastestOp[] = "ExperimentalChooseFastestDataset";
+constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
 constexpr char kPrefetchOp[] = "PrefetchDataset";
 constexpr char kAttrNameF[] = "f";
 constexpr char kAttrNameTarguments[] = "Targuments";
@@ -171,110 +172,65 @@ void CheckNotVectorized(const GraphDef& output, const string& map_op,
   EXPECT_EQ(batch_node.input(0), map_node.name());
 }
 
-void CheckBranch(const GraphDef& graph, string input_name,
-                 gtl::ArraySlice<string> ops, const string& terminal_input) {
+void CheckBranch(const FunctionDef& function, gtl::ArraySlice<string> ops) {
   for (int i = 0, size = ops.size(); i < size; ++i) {
-    const NodeDef& input_node =
-        graph.node(graph_utils::FindGraphNodeWithName(input_name, graph));
-    EXPECT_EQ(input_node.op(), ops[size - i - 1]);
-    input_name = input_node.input(0);
+    EXPECT_EQ(function.node_def(i).op(), ops[i]);
   }
-  EXPECT_EQ(input_name, terminal_input);
 }
 
-// Checks that a graph has undergone the map_vectorization transformation
-// successfully, whereby the new graph has the shape:
-//
-// input_node --> new batch --> new map -------+
-//     |                                       |
-//     |                                       v
-//     +-------> old map --> old batch ---> choose_fastest
-//
-void CheckVectorized(const GraphDef& output, const string& map_op,
-                     const string& batch_op, const string& map_input_name,
-                     bool fused = false, bool prefetch = false) {
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(map_op, output).size(), 2);
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(batch_op, output).size(), 2);
-  ASSERT_EQ(
-      graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
-  const NodeDef& choose_fastest_node =
-      output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
-  // Branch 0: vectorized
-  std::vector<string> vectorized_ops({batch_op, map_op});
-  std::vector<string> unvectorized_ops({map_op, batch_op});
-  if (prefetch) {
-    vectorized_ops.push_back(kPrefetchOp);
-    unvectorized_ops.insert(unvectorized_ops.begin() + 1, kPrefetchOp);
-  }
-  CheckBranch(output, choose_fastest_node.input(0), vectorized_ops,
-              map_input_name);
-  // Branch 1: original
-  CheckBranch(output, choose_fastest_node.input(1), unvectorized_ops,
-              map_input_name);
-  const NodeDef* vectorized_map_node = nullptr;
-  auto tmp_node = &output.node(
-      graph_utils::FindGraphNodeWithName(choose_fastest_node.input(0), output));
-  if (prefetch) {
-    vectorized_map_node = &output.node(
-        graph_utils::FindGraphNodeWithName(tmp_node->input(0), output));
-  } else {
-    vectorized_map_node = tmp_node;
-  }
-  // Check that the function is actually vectorized.
-  // The vectorization of the identity function is itself.
-  string function_name =
-      vectorized_map_node->attr().at(kAttrNameF).func().name();
+const FunctionDef* GetFunction(const GraphDef& graph,
+                               const string& function_name) {
   int found =
-      graph_utils::FindGraphFunctionWithName(function_name, output.library());
-  ASSERT_NE(found, -1);
-  const auto& function = output.library().function(found);
-  EXPECT_EQ(function.node_def(0).op(), "Identity");
+      graph_utils::FindGraphFunctionWithName(function_name, graph.library());
+  if (found == -1) {
+    return nullptr;
+  }
+  return &graph.library().function(found);
 }
 
 // Checks that a graph has undergone the map_vectorization transformation
 // successfully, whereby the new graph has the shape:
 //
-// input_node --> new batch -> new map --------+
-//     |                                       |
-//     |                                       v
-//     +-------> old map_and_batch ---> choose_fastest
+// input_node -------------> choose_fastest --> ...
+//                |f0                    |f1
+//                |                      |
+//                |                      +---> new batch --> new map
+//                |
+//                +--> old map --> old batch
 //
-void CheckVectorizedFused(const GraphDef& output,
-                          const string& map_input_name) {
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output).size(),
-            1);
-  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output).size(), 1);
-  ASSERT_EQ(
-      graph_utils::FindAllGraphNodesWithOp(kExperimentalMapAndBatchOp, output)
-          .size(),
-      1);
+void CheckVectorized(const GraphDef& output,
+                     gtl::ArraySlice<string> expected_vectorized_branch,
+                     gtl::ArraySlice<string> expected_original_branch,
+                     const string& input_name) {
   ASSERT_EQ(
       graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
   const NodeDef& choose_fastest_node =
       output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
+  ASSERT_EQ(choose_fastest_node.input(0), input_name);
+  const auto& functions_list = choose_fastest_node.attr().at("branches").list();
 
   // Branch 0: vectorized
-  CheckBranch(output, choose_fastest_node.input(0),
-              {kBatchV2Op, kParallelMapOp}, map_input_name);
+  const FunctionDef* branch_0 =
+      GetFunction(output, functions_list.func(0).name());
+  ASSERT_NE(branch_0, nullptr);
+  CheckBranch(*branch_0, expected_vectorized_branch);
+
   // Branch 1: original
-  CheckBranch(output, choose_fastest_node.input(1),
-              {kExperimentalMapAndBatchOp}, map_input_name);
+  const FunctionDef* branch_1 =
+      GetFunction(output, functions_list.func(1).name());
+  ASSERT_NE(branch_1, nullptr);
+  CheckBranch(*branch_1, expected_original_branch);
 
-  const NodeDef& vectorized_map_node = output.node(
-      graph_utils::FindGraphNodeWithName(choose_fastest_node.input(0), output));
-  // Check that the function is actually vectorized.
-  // The vectorization of the identity function is itself.
+  const NodeDef& vectorized_map_node =
+      branch_0->node_def(function_utils::FindFunctionNodeWithOp(
+          expected_vectorized_branch[1], *branch_0));
   string function_name =
       vectorized_map_node.attr().at(kAttrNameF).func().name();
-  int found =
-      graph_utils::FindGraphFunctionWithName(function_name, output.library());
-  ASSERT_NE(found, -1);
-  const auto& function = output.library().function(found);
-  EXPECT_EQ(function.node_def(0).op(), "Identity");
+
+  const FunctionDef* function = GetFunction(output, function_name);
+  ASSERT_NE(function, nullptr);
+  EXPECT_EQ(function->node_def(0).op(), "Identity");
 }
class MapThenBatchTest class MapThenBatchTest
@@ -298,9 +254,26 @@ TEST_P(MapThenBatchTest, IsVectorized) {
   MapVectorization optimizer;
   GraphDef output;
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorized(output, num_parallel_calls > 0 ? kParallelMapOp : kMapOp,
-                  use_batch_v2 ? kBatchV2Op : kBatchOp, range_dataset->name(),
-                  /*fused=*/false, prefetch);
+
+  std::vector<string> expected_original_branch;
+  expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
+                                                             : kMapOp);
+  if (prefetch) {
+    expected_original_branch.push_back(kPrefetchOp);
+  }
+  expected_original_branch.push_back(use_batch_v2 > 0 ? kBatchV2Op : kBatchOp);
+
+  std::vector<string> expected_vectorized_branch;
+  expected_vectorized_branch.push_back(use_batch_v2 > 0 ? kBatchV2Op
+                                                        : kBatchOp);
+  expected_vectorized_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
+                                                               : kMapOp);
+  if (prefetch) {
+    expected_vectorized_branch.push_back(kPrefetchOp);
+  }
+
+  CheckVectorized(output, expected_vectorized_branch, expected_original_branch,
+                  range_dataset->name());
 }
 
 INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
@@ -346,15 +319,9 @@ TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
   MapVectorization optimizer;
   GraphDef output;
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorizedFused(output, "range");
-}
-
-void EvaluateNodes(const GraphDef& graph,
-                   const std::vector<string>& output_tensor_names,
-                   std::vector<Tensor>* output_tensors) {
-  std::unique_ptr<Session> session(NewSession(SessionOptions()));
-  TF_CHECK_OK(session->Create(graph));
-  TF_CHECK_OK(session->Run({}, output_tensor_names, {}, output_tensors));
+
+  CheckVectorized(output, {kBatchV2Op, kParallelMapOp},
+                  {kExperimentalMapAndBatchOp}, range_node->name());
 }
 
 class ChainedMapAndBatchTest
@@ -403,17 +370,30 @@ TEST_P(ChainedMapAndBatchTest, IsVectorized) {
   const NodeDef& range_node =
       output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
 
   const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
-  CheckBranch(output, choose_fastest_0.input(0), {kBatchV2Op, kParallelMapOp},
-              range_node.name());
-  CheckBranch(output, choose_fastest_0.input(1),
-              fuse_0 ? fused_sequence : unfused_sequence, range_node.name());
+  ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
 
   const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
-  CheckBranch(output, choose_fastest_1.input(0), {kBatchV2Op, kParallelMapOp},
-              choose_fastest_0.name());
-  CheckBranch(output, choose_fastest_1.input(1),
-              fuse_1 ? fused_sequence : unfused_sequence,
-              choose_fastest_0.name());
+  ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());
+
+  auto check_branches = [&output](const NodeDef& choose_fastest_node,
+                                  gtl::ArraySlice<string> original_ops) {
+    const auto& functions_list =
+        choose_fastest_node.attr().at("branches").list();
+
+    // Branch 0: vectorized
+    const FunctionDef* branch_0 =
+        GetFunction(output, functions_list.func(0).name());
+    ASSERT_NE(branch_0, nullptr);
+    CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});
+
+    // Branch 1: original
+    const FunctionDef* branch_1 =
+        GetFunction(output, functions_list.func(1).name());
+    ASSERT_NE(branch_1, nullptr);
+    CheckBranch(*branch_1, original_ops);
+  };
+
+  check_branches(choose_fastest_0, fuse_0 ? fused_sequence : unfused_sequence);
+  check_branches(choose_fastest_1, fuse_1 ? fused_sequence : unfused_sequence);
 }
 
 INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
@@ -516,7 +496,10 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
   MapVectorization optimizer;
   GraphDef output;
   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  CheckVectorized(output, map_node->op(), batch_node->op(), input_node->name());
+  CheckVectorized(
+      output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
+      /*expected_original_branch=*/{map_node->op(), batch_node->op()},
+      input_node->name());
 }
 
 // TODO(rachelim): Add test that has a polymorphic function.

File: map_vectorization_test.py (Python tf.data optimization test)

@@ -359,8 +359,8 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     unoptimized = _make_dataset([map_node_name, "Batch"])
     # Note that because of the `ChooseDataset` fork, we can't use `assert_next`
     # to verify the optimization result.
-    optimized = _make_dataset(
-        [] if expect_optimized else [map_node_name, "Batch"])
+    optimized = _make_dataset(["ChooseFastestBranch"]
+                              if expect_optimized else [map_node_name, "Batch"])
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.map_vectorization = True
@@ -422,7 +422,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
       return dataset
 
     unoptimized = _make_dataset(["MapAndBatch"])
-    optimized = _make_dataset([])
+    optimized = _make_dataset(["ChooseFastestBranch"])
     options = dataset_ops.Options()
     options.experimental_optimization.map_vectorization = True
     optimized = optimized.with_options(options)
@@ -473,7 +473,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
      return dataset
 
    unoptimized = make_dataset(unoptimized_seq)
-    optimized = make_dataset([])
+    optimized = make_dataset(["ChooseFastestBranch", "ChooseFastestBranch"])
    options = dataset_ops.Options()
    options.experimental_optimization.map_vectorization = True
    optimized = optimized.with_options(options)