[tf.data] Add an optimization that vectorizes map functions and swaps the order of Map->Batch dataset transformations to Batch->Map

PiperOrigin-RevId: 209674669
This commit is contained in:
Rachel Lim 2018-08-21 15:52:10 -07:00 committed by TensorFlower Gardener
parent 7989b2bc99
commit 62fcb03449
16 changed files with 883 additions and 48 deletions

View File

@ -230,12 +230,15 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":stats_dataset_test_base",
":test_utils",
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
@ -548,3 +551,13 @@ py_test(
"//tensorflow/python/data/ops:readers",
],
)
# Shared helpers for tf.data kernel tests (sources: test_utils.py).
py_library(
name = "test_utils",
srcs = ["test_utils.py"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/util:nest",
],
)

View File

@ -31,47 +31,57 @@ from tensorflow.python.platform import test
class MapDefunTest(test.TestCase):
def testMapDefun_Simple(self):
def testMapDefunSimple(self):
@function.Defun(dtypes.int32)
def simple_fn(x):
return x * 2 + 3
with self.test_session():
nums = [[1, 2], [3, 4], [5, 6]]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
expected = elems * 2 + 3
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
nums = [[1, 2], [3, 4], [5, 6]]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
expected = elems * 2 + 3
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
def testMapDefun_MismatchedTypes(self):
def testMapDefunMismatchedTypes(self):
@function.Defun(dtypes.int32)
def fn(x):
return math_ops.cast(x, dtypes.float64)
with self.test_session():
nums = [1, 2, 3, 4, 5, 6]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(r)
nums = [1, 2, 3, 4, 5, 6]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(r)
def testMapDefun_MultipleOutputs(self):
def testMapDefunReduceDim(self):
# Tests where the output has a different rank from the input: each rank-1
# row of `elems` is mapped to its scalar first element, so the declared
# per-element output shape is `()`.
@function.Defun(dtypes.int32)
def fn(x):
return array_ops.gather(x, 0)
nums = [[1, 2], [3, 4], [5, 6]]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
# The first element of each of the three rows.
expected = constant_op.constant([1, 3, 5])
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
def testMapDefunMultipleOutputs(self):
@function.Defun(dtypes.int32)
def fn(x):
return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
with self.test_session():
nums = [[1, 2], [3, 4], [5, 6]]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64],
[(2,), (2,)])
expected = [elems, elems * 2 + 3]
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
nums = [[1, 2], [3, 4], [5, 6]]
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
(2,)])
expected = [elems, elems * 2 + 3]
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
def testMapDefun_ShapeInference(self):
def testMapDefunShapeInference(self):
@function.Defun(dtypes.int32)
def fn(x):
@ -82,7 +92,7 @@ class MapDefunTest(test.TestCase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
self.assertEqual(result.get_shape(), (3, 2))
def testMapDefun_PartialShapeInference(self):
def testMapDefunPartialShapeInference(self):
@function.Defun(dtypes.int32)
def fn(x):
@ -92,7 +102,7 @@ class MapDefunTest(test.TestCase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
self.assertEqual(result[0].get_shape().as_list(), [None, 2])
def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
@function.Defun(dtypes.int32, dtypes.int32)
def fn(x, y):
@ -108,7 +118,7 @@ class MapDefunTest(test.TestCase):
"All inputs must have the same dimension 0."):
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
def testMapDefun_RaisesDefunError(self):
def testMapDefunRaisesDefunError(self):
@function.Defun(dtypes.int32)
def fn(x):
@ -117,9 +127,8 @@ class MapDefunTest(test.TestCase):
elems = constant_op.constant([0, 0, 0, 37, 0])
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
with self.test_session():
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(result)
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(result)
if __name__ == "__main__":

View File

@ -20,12 +20,16 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -277,5 +281,124 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
"record_latency_PrefetchDataset/_6", 1)
class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
"""Tests for the "map_vectorization" tf.data graph optimization."""
def _get_test_datasets(self,
base_dataset,
map_fn,
num_parallel_calls=None,
expect_optimized=True):
"""Given base dataset and map fn, creates test datasets.
Returns a tuple of (unoptimized dataset, optimized dataset). The
unoptimized dataset has the assertion that Batch follows Map. The optimized
dataset has the assertion that Map follows Batch (or, when
`expect_optimized` is False, that Batch still follows Map), and has the
"map_vectorization" optimization applied.
Args:
base_dataset: Input dataset to map->batch
map_fn: Map function to use
num_parallel_calls: (Optional.) num_parallel_calls argument for map
expect_optimized: (Optional.) Whether we expect the optimization to take
place, in which case we will assert that Batch is followed by Map,
otherwise Map followed by Batch. Defaults to True.
Returns:
Tuple of (unoptimized dataset, optimized dataset).
"""
# The node name asserted on depends on whether `map` runs in parallel.
map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
batch_size = 100
def _make_dataset(node_names):
# `assert_next` checks the names of the transformations that follow it.
return base_dataset.apply(optimization.assert_next(node_names)).map(
map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
unoptimized = _make_dataset([map_node_name, "Batch"])
optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
[map_node_name, "Batch"]).apply(
optimization.optimize(["map_vectorization"]))
return unoptimized, optimized
@parameterized.named_parameters(
("Basic", lambda x: (x, x + 1), None),
("Parallel", lambda x: (x, x + 1), 12),
("Gather", lambda x: array_ops.gather(x, 0), 12),
)
def testOptimization(self, map_fn, num_parallel_calls):
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
num_parallel_calls)
self._assert_datasets_equal(unoptimized, optimized)
def testOptimizationBadMapFn(self):
# Test map functions that give an error
def map_fn(x):
# x has leading dimension 5, this will raise an error
return array_ops.gather(x, 10)
base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
5, drop_remainder=True)
_, optimized = self._get_test_datasets(base_dataset, map_fn)
nxt = optimized.make_one_shot_iterator().get_next()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
r"indices = 10 is not in \[0, 5\)"):
self.evaluate(nxt)
def testOptimizationWithCapturedInputs(self):
# Tests that vectorization works with captured inputs
def map_fn(x):
return x + y
y = constant_op.constant(1, shape=(2,))
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
# TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
self._assert_datasets_equal(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
# Map functions containing stateful ops (here, an assert) are not optimized.
def map_fn(x):
with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
return array_ops.identity(x)
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
_, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
nxt = optimized.make_one_shot_iterator().get_next()
# NOTE: Right now, it raises an error because we can't save datasets that
# are stateful, and we rely on this saving mechanism to optimize datasets,
# so stateful functions can't be optimized.
with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
self.evaluate(nxt)
def testOptimizationIgnoreRagged(self):
# Make sure we ignore inputs that might not be uniformly sized
def map_fn(x):
return array_ops.gather(x, 0)
# output_shape = (?,)
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
self._assert_datasets_equal(unoptimized, optimized)
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
def map_fn(x):
return array_ops.tile(x, x)
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
# Both pipelines are expected to fail, and with the same leading message.
self._assert_datasets_raise_same_error(unoptimized, optimized,
errors.InvalidArgumentError)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,60 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities for tf.data functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
class DatasetTestBase(test.TestCase):
  """Base class for dataset tests."""

  def _assert_datasets_equal(self, dataset1, dataset2):
    """Asserts that both datasets produce the same sequence of elements.

    Elements are pulled from the two datasets in lockstep. Each pair is
    flattened with `nest.flatten` and compared component by component. Both
    datasets must also be exhausted at the same step.

    Args:
      dataset1: First dataset to compare.
      dataset2: Second dataset to compare.
    """
    # TODO(rachelim): support sparse tensor outputs
    next1 = dataset1.make_one_shot_iterator().get_next()
    next2 = dataset2.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      while True:
        try:
          op1 = sess.run(next1)
        except errors.OutOfRangeError:
          # dataset1 is exhausted, so dataset2 must be exhausted as well.
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next2)
          break
        op2 = sess.run(next2)

        op1 = nest.flatten(op1)
        op2 = nest.flatten(op2)
        # Use a unittest assertion rather than a bare `assert`: the latter is
        # stripped under `python -O` and reports nothing useful on failure.
        self.assertEqual(len(op1), len(op2))
        for component1, component2 in zip(op1, op2):
          self.assertAllEqual(component1, component2)

  def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
    """Asserts that both datasets fail with the same kind of error.

    Fetching the first element of each dataset must raise `exc_class`, and the
    first segment (up to the first ". ") of dataset1's error message must
    appear in dataset2's error message.

    Args:
      dataset1: Dataset whose error is used as the reference.
      dataset2: Dataset whose error is compared against the reference.
      exc_class: Exception class both datasets are expected to raise.
    """
    next1 = dataset1.make_one_shot_iterator().get_next()
    next2 = dataset2.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      try:
        sess.run(next1)
      except exc_class as e:
        # Not every exception type exposes `.message`; fall back to str(e).
        message = getattr(e, "message", None)
        if message is None:
          message = str(e)
        expected_prefix = message.split(". ")[0]
      else:
        # Reported from the `else` clause so that the failure can never be
        # swallowed by the `except exc_class` handler above.
        self.fail(
            "Expected dataset to raise an error of type %s, but it did not." %
            repr(exc_class))
      # Check that the first segment of the error messages are the same. The
      # prefix is literal text, so escape it before using it as a regex.
      with self.assertRaisesRegexp(exc_class, re.escape(expected_prefix)):
        sess.run(next2)

View File

@ -124,6 +124,43 @@ cc_library(
] + tf_protos_all(),
)
# Grappler optimizer that vectorizes map functions and rewrites
# Map->Batch dataset transformations into Batch->Map.
cc_library(
name = "map_vectorization",
srcs = ["map_vectorization.cc"],
hdrs = [
"map_vectorization.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
# Unit test for the ":map_vectorization" optimizer library.
tf_cc_test(
name = "map_vectorization_test",
srcs = ["map_vectorization_test.cc"],
visibility = ["//visibility:public"],
deps = [
":graph_utils",
":map_vectorization",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work.
],
)
cc_library(
name = "map_and_batch_fusion",
srcs = ["map_and_batch_fusion.cc"],
@ -311,6 +348,7 @@ cc_library(
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
":map_vectorization",
":noop_elimination",
":shuffle_and_repeat_fusion",
],

View File

@ -108,6 +108,26 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
return graph->AddNode(std::move(node));
}
// Appends a new node to the function definition `fd` and returns a pointer to
// it (the node is owned by `fd`).
//
// If `name` is empty, a unique name derived from `op` is generated with
// SetUniqueFunctionNodeName; otherwise `name` is used verbatim. `inputs` are
// added in order and every (key, value) pair in `attributes` is set on the
// node.
NodeDef* AddNode(StringPiece name, StringPiece op,
                 const std::vector<string>& inputs,
                 const std::vector<std::pair<string, AttrValue>>& attributes,
                 FunctionDef* fd) {
  NodeDef* node = fd->add_node_def();
  if (!name.empty()) {
    node->set_name(name.ToString());
  } else {
    SetUniqueFunctionNodeName(op, fd, node);
  }
  node->set_op(op.ToString());
  for (const string& input : inputs) {
    node->add_input(input);
  }
  // Iterate by const reference: copying each pair would deep-copy an AttrValue
  // proto per attribute for no benefit.
  for (const auto& attr : attributes) {
    (*node->mutable_attr())[attr.first] = attr.second;
  }
  return node;
}
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@ -181,7 +201,7 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
}
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
return FindNodeWithOp(op, graph) != -1;
return FindGraphNodeWithOp(op, graph) != -1;
}
bool ContainsGraphFunctionWithName(StringPiece name,
@ -205,7 +225,7 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return indices.empty() ? -1 : indices.front();
}
int FindNodeWithOp(StringPiece op, const GraphDef& graph) {
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
std::vector<int> indices = GetElementIndicesWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
return indices.empty() ? -1 : indices.front();
@ -242,6 +262,12 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
return indices.empty() ? -1 : indices.front();
}
// Returns the node feeding input 0 of `node`, as resolved through `graph`, or
// nullptr when `node` has no inputs at all.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  const bool has_inputs = node.input_size() != 0;
  if (!has_inputs) {
    return nullptr;
  }
  const GraphView::InputPort first_input = graph.GetInputPort(node.name(), 0);
  return graph.GetRegularFanin(first_input).node;
}
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
NodeDef* node) {
string name = prefix.ToString();

View File

@ -37,6 +37,12 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
// Adds a node to a FunctionDef and returns a pointer to the added node.
// If `name` is empty, a unique node name is generated from `op`; otherwise
// `name` is used as given. `inputs` and `attributes` are copied onto the node.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
FunctionDef* fd);
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@ -99,7 +105,10 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindNodeWithOp(StringPiece op, const GraphDef& graph);
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
// Gets the 0th input to a node in the graph. Returns nullptr if the node has
// no inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
// Returns the list of indices of all nodes with the given op or empty list if
// no such node exists.

View File

@ -176,25 +176,25 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FindGraphFunctionWithName(new_function->signature().name(), library), -1);
}
TEST(GraphUtilsTest, FindNodeWithOp) {
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph);
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
graph.DeleteNodes({"B"});
EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
}
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
@ -251,6 +251,54 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
// Exercises the FunctionDef overload of AddNode: an explicitly named node, an
// auto-named node with inputs, and an auto-named node with attributes.
TEST(GraphUtilsTest, AddNodeToFunctionDef) {
  FunctionDef func;
  const char* op_name = "xxx";

  // Explicit name, no inputs, no attributes.
  AddNode(op_name, op_name, {}, {}, &func);
  const int named_index = FindFunctionNodeWithName("xxx", func);
  const NodeDef& node1 = func.node_def(named_index);
  EXPECT_EQ(node1.op(), op_name);
  EXPECT_EQ(node1.input_size(), 0);
  EXPECT_EQ(node1.attr_size(), 0);

  // Empty name: the node name is generated; inputs are preserved in order.
  const std::vector<string> inputs({"input1", "input2"});
  AddNode("", op_name, inputs, {}, &func);
  const int with_inputs_index = FindFunctionNodeWithName("xxx/_2", func);
  const NodeDef& node2 = func.node_def(with_inputs_index);
  EXPECT_EQ(node2.op(), op_name);
  EXPECT_EQ(node2.attr_size(), 0);
  EXPECT_EQ(node2.input_size(), inputs.size());
  for (size_t pos = 0; pos < inputs.size(); ++pos) {
    EXPECT_EQ(node2.input(pos), inputs[pos]);
  }

  // Empty name again, this time with attributes only.
  AttrValue a1, a2;
  a1.set_type(DT_INT32);
  a2.set_type(DT_INT64);
  const std::vector<std::pair<string, AttrValue>> attrs(
      {{"attr1", a1}, {"attr2", a2}});
  AddNode("", op_name, {}, attrs, &func);
  const int with_attrs_index = FindFunctionNodeWithName("xxx/_3", func);
  const NodeDef& node3 = func.node_def(with_attrs_index);
  EXPECT_EQ(node3.op(), op_name);
  EXPECT_EQ(node3.input_size(), 0);
  EXPECT_EQ(node3.attr_size(), attrs.size());
  for (const auto& expected_attr : attrs) {
    EXPECT_EQ(expected_attr.second.type(),
              node3.attr().at(expected_attr.first).type());
  }
}
// GetInputNode returns the producer of input 0, or nullptr for a node that
// has no inputs.
TEST(GraphUtilsTest, GetInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* producer = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer = AddNode("", "A", {producer->name()}, {}, &graph);

  EXPECT_EQ(GetInputNode(*consumer, graph), producer);
  EXPECT_EQ(GetInputNode(*producer, graph), nullptr);
}
} // namespace
} // namespace graph_utils
} // namespace grappler

View File

@ -85,8 +85,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node =
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node = output.node(
graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@ -170,8 +170,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node =
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node = output.node(
graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@ -253,8 +253,8 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node =
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node = output.node(
graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));

View File

@ -101,18 +101,18 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) {
graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output));
ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));
int map_id = graph_utils::FindNodeWithOp("MapDataset", output);
int map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
auto& map_node = output.node(map_id);
ASSERT_EQ(map_node.input_size(), 1);
EXPECT_EQ(map_node.input(0), "range");
int filter_by_component_id =
graph_utils::FindNodeWithOp("FilterByLastComponentDataset", output);
graph_utils::FindGraphNodeWithOp("FilterByLastComponentDataset", output);
auto& filter_by_component = output.node(filter_by_component_id);
ASSERT_EQ(filter_by_component.input_size(), 1);
EXPECT_EQ(filter_by_component.input(0), map_node.name());
int cache_id = graph_utils::FindNodeWithOp("CacheDataset", output);
int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
auto& cache_node = output.node(cache_id);
ASSERT_EQ(cache_node.input_size(), 2);
EXPECT_EQ(cache_node.input(0), filter_by_component.name());

View File

@ -0,0 +1,257 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace grappler {
namespace {
// Copies the attribute `attr_name` from node `from` onto node `to`,
// overwriting any existing value. `from` must already have the attribute.
void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
  const auto& source_value = from.attr().at(attr_name);
  auto& target_attrs = *to->mutable_attr();
  target_attrs[attr_name] = source_value;
}
// Adds to `library` a new function that wraps the function of `map_node` in a
// single MapDefun node, and returns it. The new function has the same
// signature as `orig_func` (under a fresh unique name); its arguments are fed
// straight into the MapDefun node and its return values are that node's
// outputs.
//
// If a different vectorization strategy is adopted later, this is the only
// piece that needs to be swapped out.
FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
                                   const FunctionDef& orig_func,
                                   FunctionDefLibrary* library) {
  FunctionDef* vectorized_func = library->add_function();

  // Same inputs and outputs as the original; only the shapes differ.
  *vectorized_func->mutable_signature() = orig_func.signature();
  graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
                                          vectorized_func);

  // The body consists of a single MapDefun node.
  NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
  map_defun_node->set_op("MapDefun");
  graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
                                         map_defun_node);

  // The function, output types and (unbatched) output shapes are taken
  // unchanged from the original map node.
  CopyAttribute("f", map_node, map_defun_node);
  CopyAttribute("output_types", map_node, map_defun_node);
  CopyAttribute("output_shapes", map_node, map_defun_node);

  // Wire every function argument into MapDefun and record its type.
  const auto& signature = vectorized_func->signature();
  AttrValue t_args;
  for (const auto& arg : signature.input_arg()) {
    t_args.mutable_list()->add_type(arg.type());
    map_defun_node->add_input(arg.name());
  }
  (*map_defun_node->mutable_attr())["Targuments"] = t_args;

  // Map each declared output to the matching MapDefun output.
  const string output_prefix =
      strings::StrCat(map_defun_node->name(), ":output:");
  const int num_outputs = signature.output_arg_size();
  for (int idx = 0; idx < num_outputs; ++idx) {
    const string& ret_name = signature.output_arg(idx).name();
    (*vectorized_func->mutable_ret())[ret_name] =
        strings::StrCat(output_prefix, idx);
  }

  return vectorized_func;
}
// Returns true iff `node` has an "output_shapes" attribute in which every
// shape has a known rank and every dimension has a known size.
bool IsOutputShapesFullyDefined(const NodeDef& node) {
  auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
  if (shapes_attr == nullptr) return false;
  const auto& shapes = shapes_attr->list().shape();

  for (const TensorShapeProto& shape : shapes) {
    // A shape of unknown rank carries no dims at all, so the per-dimension
    // check below would wrongly accept it.
    if (shape.unknown_rank()) {
      return false;
    }
    for (const auto& dim : shape.dim()) {
      // -1 denotes an unknown dimension; treat any negative size as unknown.
      if (dim.size() < 0) {
        return false;
      }
    }
  }
  return true;
}
// Returns true if any node in `function_def` uses a stateful op. An op that
// cannot be looked up in `library` is conservatively treated as stateful.
bool IsStatefulFn(const FunctionLibraryDefinition& library,
                  const FunctionDef& function_def) {
  for (const NodeDef& body_node : function_def.node_def()) {
    const OpDef* op_def;
    const Status lookup_status = library.LookUpOpDef(body_node.op(), &op_def);
    if (!lookup_status.ok()) {
      return true;
    }
    if (op_def->is_stateful()) {
      return true;
    }
  }
  return false;
}
// Returns true if the "Targuments" attribute of `map_node` lists at least one
// type. The attribute must be present on the node.
bool HasCapturedInputs(const NodeDef& map_node) {
  const auto& targuments = map_node.attr().at("Targuments").list();
  return targuments.type_size() > 0;
}
// Builds (without adding it to the graph) a batch node of the same op as
// `old_batch_node` that batches the output of `input_node` directly.
//
// The batch size input (and the `drop_remainder` input for BatchDatasetV2) is
// taken from `old_batch_node`. Output types are the input argument types of
// `vectorized_func`; output shapes are the shapes of `input_node` with the
// batch dimension prepended. `input_node` must have an "output_shapes"
// attribute.
NodeDef make_new_batch_node(const NodeDef& old_batch_node,
                            const NodeDef& input_node,
                            const FunctionDef& vectorized_func,
                            MutableGraphView* graph) {
  NodeDef batch_node;
  batch_node.set_op(old_batch_node.op());
  graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
                                      &batch_node);

  // Set the `input_dataset` input argument
  batch_node.add_input(input_node.name());
  // Set the `batch_size` input_argument
  batch_node.add_input(old_batch_node.input(1));
  if (batch_node.op() == "BatchDatasetV2") {
    // Set the `drop_remainder` input argument
    batch_node.add_input(old_batch_node.input(2));
  }

  // Set attrs
  AttrValue output_types;
  for (const auto& input : vectorized_func.signature().input_arg()) {
    output_types.mutable_list()->add_type(input.type());
  }
  (*batch_node.mutable_attr())["output_types"] = output_types;

  // Reuse the (static) batch dimension recorded on the old batch node. If it
  // is not available there, fall back to -1 (unknown) instead of indexing out
  // of range.
  int64 batch_size = -1;
  const auto* old_shapes_attr =
      gtl::FindOrNull(old_batch_node.attr(), "output_shapes");
  if (old_shapes_attr != nullptr && old_shapes_attr->list().shape_size() > 0 &&
      old_shapes_attr->list().shape(0).dim_size() > 0) {
    batch_size = old_shapes_attr->list().shape(0).dim(0).size();
  }

  auto& output_shapes_attr = (*batch_node.mutable_attr())["output_shapes"];
  const auto& input_shapes =
      input_node.attr().at("output_shapes").list().shape();
  for (const TensorShapeProto& input_shape : input_shapes) {
    TensorShapeProto* shape = output_shapes_attr.mutable_list()->add_shape();
    if (input_shape.unknown_rank()) {
      // Batching an element of unknown rank yields unknown rank; do not emit
      // a shape that both has dims and is flagged as unknown rank.
      shape->set_unknown_rank(true);
      continue;
    }
    TensorShapeProto_Dim* dim = shape->add_dim();
    dim->set_size(batch_size);
    // MergeFrom appends the input dims after the batch dim added above.
    shape->MergeFrom(input_shape);
  }
  return batch_node;
}
// Builds (without adding it to the graph) a map node of the same op as
// `old_map_node` that applies `vectorized_func` to the output of
// `new_batch_node`. All inputs of `old_map_node` other than the input dataset
// are carried over, and the output types/shapes are those of
// `old_batch_node`.
NodeDef make_new_map_node(const NodeDef& old_map_node,
                          const NodeDef& old_batch_node,
                          const NodeDef& new_batch_node,
                          const FunctionDef& vectorized_func,
                          MutableGraphView* graph) {
  NodeDef map_node;
  map_node.set_op(old_map_node.op());
  graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
                                      &map_node);

  // Input 0 is the `input_dataset`: the freshly created batch node.
  map_node.add_input(new_batch_node.name());
  // The remaining inputs (`other_arguments` and `num_parallel_calls`) are
  // forwarded unchanged.
  const int num_old_inputs = old_map_node.input_size();
  for (int idx = 1; idx < num_old_inputs; ++idx) {
    map_node.add_input(old_map_node.input(idx));
  }

  // Attributes.
  CopyAttribute("Targuments", old_map_node, &map_node);
  const string& vectorized_name = vectorized_func.signature().name();
  (*map_node.mutable_attr())["f"].mutable_func()->set_name(vectorized_name);
  CopyAttribute("output_shapes", old_batch_node, &map_node);
  CopyAttribute("output_types", old_batch_node, &map_node);
  return map_node;
}
} // namespace
// Rewrites every eligible `input -> Map -> Batch` chain in `item.graph` into
// `input -> Batch -> Map'`, where Map' applies a vectorized version of the
// original map function. The rewritten graph is written to `output`.
//
// Chains that cannot be rewritten safely are left untouched; this function
// never fails on them and always returns OK.
Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* output) {
  *output = item.graph;
  MutableGraphView graph(output);
  std::set<string> nodes_to_delete;

  for (const NodeDef& node : item.graph.node()) {
    // Find Map->Batch nodes.
    // TODO(rachelim): Optimize MapAndBatchDataset[V2] as well.
    if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
      continue;
    }

    const NodeDef& batch_node(node);
    // GetInputNode returns nullptr when there is no input to resolve, so the
    // result must be checked before it is dereferenced.
    NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
    if (node2 == nullptr) {
      continue;
    }
    if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
      continue;
    }

    // Use a more descriptive variable name now that we know the node type.
    NodeDef* map_node = node2;
    // NOTE(review): the map node is deleted below without checking whether it
    // has consumers other than `batch_node` -- confirm this cannot happen.

    // Input to the map node. A malformed graph should make the optimizer skip
    // the rewrite rather than abort the process.
    NodeDef* input_node = graph_utils::GetInputNode(*map_node, graph);
    if (input_node == nullptr) {
      continue;
    }

    const auto* func_attr = gtl::FindOrNull(map_node->attr(), "f");
    if (func_attr == nullptr) {
      continue;
    }
    FunctionDefLibrary* library = output->mutable_library();

    FunctionLibraryDefinition function_library(OpRegistry::Global(), *library);
    const FunctionDef* orig_func =
        function_library.Find(func_attr->func().name());
    if (orig_func == nullptr) {
      // The map function is not in the graph's library; nothing to vectorize.
      continue;
    }

    // Check that this is a valid optimization.
    if (!IsOutputShapesFullyDefined(*input_node) ||
        !IsOutputShapesFullyDefined(*map_node) ||
        IsStatefulFn(function_library, *orig_func) ||
        HasCapturedInputs(*map_node)) {
      // 1. If any of the inputs have an unknown shape, don't optimize, since
      // inputs might not be batchable.
      // 2. If any of the map func outputs have an unknown shape, don't
      // optimize, so that batching errors surface as before.
      // 3. If the function is stateful, don't vectorize it.
      // 4. TODO(rachelim): Make this work for MapDataset with captured inputs
      // by tiling inputs or modifying the signature of MapDefun.
      continue;
    }

    FunctionDef* vectorized_func =
        AddVectorizedFunction(*map_node, *orig_func, library);
    CHECK_NOTNULL(vectorized_func);

    auto* new_batch_node = graph.AddNode(
        make_new_batch_node(batch_node, *input_node, *vectorized_func, &graph));

    auto* new_map_node = graph.AddNode(make_new_map_node(
        *map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
    // Consumers of the old batch node now read from the new map node.
    graph.ReplaceInput(batch_node, *new_map_node);

    // Mark the `Map` and `Batch` nodes for removal.
    nodes_to_delete.insert(map_node->name());
    nodes_to_delete.insert(batch_node.name());
  }
  graph.DeleteNodes(nodes_to_delete);
  return Status::OK();
}
// Intentionally does nothing: this optimizer makes no use of the feedback
// arguments.
void MapVectorization::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output,
double result) {
// no-op
}
REGISTER_GRAPH_OPTIMIZER_AS(MapVectorization, "map_vectorization");
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,46 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
namespace grappler {
// Custom graph optimizer that reports itself as "map_vectorization".
//
// Per the change description that introduced it, this optimization vectorizes
// map functions and swaps Map->Batch dataset transformations to Batch->Map.
// The rewrite itself lives in Optimize(), which is defined out of line.
class MapVectorization : public CustomGraphOptimizer {
 public:
  MapVectorization() = default;
  ~MapVectorization() override = default;

  // Returns the identifier of this optimizer.
  string name() const override { return "map_vectorization"; }

  // Ignores `config` and always returns an OK status.
  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return Status::OK();
  }

  // Declared here; the implementation is provided in the .cc file.
  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* output) override;

  // Declared here; the implementation is provided in the .cc file.
  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_

View File

@ -0,0 +1,201 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
using test::function::GDef;
using test::function::NDef;
// Appends one dimension to `t` for each entry of `dims`, in order, using the
// entry as that dimension's size.
void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
                                TensorShapeProto* t) {
  for (const int dim_size : dims) {
    t->add_dim()->set_size(dim_size);
  }
}
// Builds an AttrValue whose list holds one shape per entry of `shapes`, in the
// same order. Each inner slice supplies the dimension sizes of one shape.
AttrValue MakeShapeListAttr(
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
  AttrValue result;
  for (const auto& dims : shapes) {
    // mutable_list() is deliberately called per iteration (not hoisted) so
    // that an empty `shapes` leaves `result` untouched.
    auto* shape_proto = result.mutable_list()->add_shape();
    MakeTensorShapeProtoHelper(dims, shape_proto);
  }
  return result;
}
// Builds a node named `name` with op `map_op_name` that has a single input
// `input_node_name` and the following attrs:
//   "f"             - a function reference to `function_name`
//   "Targuments"    - empty
//   "output_shapes" - list built from `output_shapes`
//   "output_types"  - `output_types`
NodeDef MakeMapNodeHelper(
    StringPiece name, StringPiece input_node_name, StringPiece function_name,
    StringPiece map_op_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  const auto input = input_node_name.ToString();
  const auto func_name = function_name.ToString();
  return NDef(name, map_op_name, {input},
              {{"f", FunctionDefHelper::FunctionRef(func_name)},
               {"Targuments", {}},
               {"output_shapes", MakeShapeListAttr(output_shapes)},
               {"output_types", output_types}});
}
// Convenience wrapper around MakeMapNodeHelper that fixes the op name to
// "MapDataset" and forwards every other argument unchanged.
NodeDef MakeMapNode(
    StringPiece name, StringPiece input_node_name, StringPiece function_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  constexpr char kMapDatasetOp[] = "MapDataset";
  return MakeMapNodeHelper(name, input_node_name, function_name, kMapDatasetOp,
                           output_shapes, output_types);
}
// Builds a "BatchDataset" node named `name` whose inputs are, in order,
// `input_node_name` and `input_batch_size_name`, with "output_types" and
// "output_shapes" attrs taken from the corresponding arguments.
NodeDef MakeBatchNode(
    StringPiece name, StringPiece input_node_name,
    StringPiece input_batch_size_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  const auto dataset_input = input_node_name.ToString();
  const auto batch_size_input = input_batch_size_name.ToString();
  return NDef(name, "BatchDataset", {dataset_input, batch_size_input},
              {{"output_types", output_types},
               {"output_shapes", MakeShapeListAttr(output_shapes)}});
}
// Builds a "BatchDatasetV2" node named `name` whose inputs are, in order,
// `input_node_name`, `input_batch_size_name` and `input_drop_remainder_name`,
// with "output_types" and "output_shapes" attrs taken from the corresponding
// arguments.
NodeDef MakeBatchV2Node(
    StringPiece name, StringPiece input_node_name,
    StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  const auto dataset_input = input_node_name.ToString();
  const auto batch_size_input = input_batch_size_name.ToString();
  const auto drop_remainder_input = input_drop_remainder_name.ToString();
  return NDef(name, "BatchDatasetV2",
              {dataset_input, batch_size_input, drop_remainder_input},
              {{"output_types", output_types},
               {"output_shapes", MakeShapeListAttr(output_shapes)}});
}
// Builds a "RangeDataset" node named `name` with the given `inputs`. The node
// is annotated with a single rank-0 output shape and a single DT_INT64 output
// type.
NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
  const AttrValue scalar_shape_list = MakeShapeListAttr({{}});
  // The DataType slice is constructed inline on purpose: it only views the
  // initializer list, which must stay alive for the duration of the NDef call.
  return NDef(name, "RangeDataset", inputs,
              {{"output_shapes", scalar_shape_list},
               {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
// Builds range -> map(XTimesTwo) -> batch, runs the optimizer, and checks that
// the resulting graph has exactly one MapDataset and one BatchDataset, with the
// map now consuming the batch and the batch consuming "range".
TEST(MapVectorizationTest, VectorizeMapWithBatch) {
  GrapplerItem item;
  item.graph = GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       MakeRangeNode("range", {"start", "stop", "step"}),
       MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });
  MapVectorization optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  // These counts are fatal assertions: the lookups below index into `output`
  // using the position of the matching node, so the test must stop here if
  // there is not exactly one node of each kind.
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
            1);
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
            1);
  const NodeDef& map_node =
      output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
  const NodeDef& batch_node =
      output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));

  // After the rewrite the order is range -> batch -> map.
  EXPECT_EQ(map_node.input(0), batch_node.name());
  EXPECT_EQ(batch_node.input(0), "range");
}
// Same scenario as the BatchDataset case, but the batch node is a
// BatchDatasetV2 with an extra "drop_remainder" input. Checks that the
// resulting graph has exactly one MapDataset and one BatchDatasetV2, with the
// map consuming the batch and the batch consuming "range".
TEST(MapVectorizationTest, VectorizeMapWithBatchV2) {
  GrapplerItem item;
  item.graph = GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("drop_remainder", "Const", {},
            {{"value", false}, {"dtype", DT_BOOL}}),
       MakeRangeNode("range", {"start", "stop", "step"}),
       MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchV2Node("batch", "map", "batch_size", "drop_remainder", {{-1}},
                       {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });
  MapVectorization optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  // These counts are fatal assertions: the lookups below index into `output`
  // using the position of the matching node, so the test must stop here if
  // there is not exactly one node of each kind.
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
            1);
  ASSERT_EQ(
      graph_utils::FindAllGraphNodesWithOp("BatchDatasetV2", output).size(), 1);
  const NodeDef& map_node =
      output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
  const NodeDef& batch_node =
      output.node(graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output));

  // After the rewrite the order is range -> batch -> map.
  EXPECT_EQ(map_node.input(0), batch_node.name());
  EXPECT_EQ(batch_node.input(0), "range");
}
// Runs the optimizer on a graph whose input dataset node declares
// "output_types" but no "output_shapes" attr. The only requirement checked is
// that Optimize() returns an OK status.
TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShape) {
  // Input dataset node without an "output_shapes" attr.
  const NodeDef input_without_shapes =
      NDef("input", "InputDataset", {},
           {{"output_types", gtl::ArraySlice<DataType>({DT_INT32})}});

  GrapplerItem grappler_item;
  grappler_item.graph = GDef(
      {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       input_without_shapes,
       MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  GraphDef optimized_graph;
  MapVectorization map_vectorization;
  TF_ASSERT_OK(
      map_vectorization.Optimize(nullptr, grappler_item, &optimized_graph));
}
// Runs the optimizer on a graph whose input dataset node declares
// "output_shapes" but no "output_types" attr. The only requirement checked is
// that Optimize() returns an OK status.
TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
  // Input dataset node without an "output_types" attr.
  const NodeDef input_without_types =
      NDef("input", "InputDataset", {},
           {{"output_shapes", MakeShapeListAttr({{}})}});

  GrapplerItem grappler_item;
  grappler_item.graph = GDef(
      {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       input_without_types,
       MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  GraphDef optimized_graph;
  MapVectorization map_vectorization;
  TF_ASSERT_OK(
      map_vectorization.Optimize(nullptr, grappler_item, &optimized_graph));
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -78,7 +78,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
EXPECT_TRUE(
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
NodeDef shuffle_and_repeat_node = output.node(
graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDataset", output));
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));

View File

@ -74,7 +74,11 @@ class MapDefunOp : public AsyncOpKernel {
arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
OP_REQUIRES_ASYNC(
ctx, batch_size == ctx->input(i).dim_size(0),
errors::InvalidArgument("All inputs must have the same dimension 0."),
errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
", while all previous inputs have leading dimension ", batch_size,
"."),
done);
}

View File

@ -71,6 +71,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
"//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",