[tf.data] Add an optimization that vectorizes map functions and swaps the order of Map->Batch dataset transformations to Batch->Map
PiperOrigin-RevId: 209674669
This commit is contained in:
parent
7989b2bc99
commit
62fcb03449
@ -230,12 +230,15 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":stats_dataset_test_base",
|
||||
":test_utils",
|
||||
"//tensorflow/contrib/data/python/ops:optimization",
|
||||
"//tensorflow/contrib/data/python/ops:stats_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -548,3 +551,13 @@ py_test(
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_utils",
|
||||
srcs = ["test_utils.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
@ -31,47 +31,57 @@ from tensorflow.python.platform import test
|
||||
|
||||
class MapDefunTest(test.TestCase):
|
||||
|
||||
def testMapDefun_Simple(self):
|
||||
def testMapDefunSimple(self):
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def simple_fn(x):
|
||||
return x * 2 + 3
|
||||
|
||||
with self.test_session():
|
||||
nums = [[1, 2], [3, 4], [5, 6]]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
|
||||
expected = elems * 2 + 3
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
nums = [[1, 2], [3, 4], [5, 6]]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
|
||||
expected = elems * 2 + 3
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
def testMapDefun_MismatchedTypes(self):
|
||||
def testMapDefunMismatchedTypes(self):
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def fn(x):
|
||||
return math_ops.cast(x, dtypes.float64)
|
||||
|
||||
with self.test_session():
|
||||
nums = [1, 2, 3, 4, 5, 6]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(r)
|
||||
nums = [1, 2, 3, 4, 5, 6]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(r)
|
||||
|
||||
def testMapDefun_MultipleOutputs(self):
|
||||
def testMapDefunReduceDim(self):
|
||||
# Tests where the output has a different rank from the input
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def fn(x):
|
||||
return array_ops.gather(x, 0)
|
||||
|
||||
nums = [[1, 2], [3, 4], [5, 6]]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
|
||||
expected = constant_op.constant([1, 3, 5])
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
def testMapDefunMultipleOutputs(self):
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def fn(x):
|
||||
return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
|
||||
|
||||
with self.test_session():
|
||||
nums = [[1, 2], [3, 4], [5, 6]]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64],
|
||||
[(2,), (2,)])
|
||||
expected = [elems, elems * 2 + 3]
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
nums = [[1, 2], [3, 4], [5, 6]]
|
||||
elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
|
||||
r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
|
||||
(2,)])
|
||||
expected = [elems, elems * 2 + 3]
|
||||
self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
|
||||
|
||||
def testMapDefun_ShapeInference(self):
|
||||
def testMapDefunShapeInference(self):
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def fn(x):
|
||||
@ -82,7 +92,7 @@ class MapDefunTest(test.TestCase):
|
||||
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
|
||||
self.assertEqual(result.get_shape(), (3, 2))
|
||||
|
||||
def testMapDefun_PartialShapeInference(self):
|
||||
def testMapDefunPartialShapeInference(self):
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def fn(x):
|
||||
@ -92,7 +102,7 @@ class MapDefunTest(test.TestCase):
|
||||
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
|
||||
self.assertEqual(result[0].get_shape().as_list(), [None, 2])
|
||||
|
||||
def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
|
||||
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
|
||||
|
||||
@function.Defun(dtypes.int32, dtypes.int32)
|
||||
def fn(x, y):
|
||||
@ -108,7 +118,7 @@ class MapDefunTest(test.TestCase):
|
||||
"All inputs must have the same dimension 0."):
|
||||
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
|
||||
|
||||
def testMapDefun_RaisesDefunError(self):
|
||||
def testMapDefunRaisesDefunError(self):
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def fn(x):
|
||||
@ -117,9 +127,8 @@ class MapDefunTest(test.TestCase):
|
||||
|
||||
elems = constant_op.constant([0, 0, 0, 37, 0])
|
||||
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
|
||||
with self.test_session():
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(result)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -20,12 +20,16 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
|
||||
from tensorflow.contrib.data.python.kernel_tests import test_utils
|
||||
from tensorflow.contrib.data.python.ops import optimization
|
||||
from tensorflow.contrib.data.python.ops import stats_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -277,5 +281,124 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
"record_latency_PrefetchDataset/_6", 1)
|
||||
|
||||
|
||||
class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _get_test_datasets(self,
|
||||
base_dataset,
|
||||
map_fn,
|
||||
num_parallel_calls=None,
|
||||
expect_optimized=True):
|
||||
"""Given base dataset and map fn, creates test datasets.
|
||||
|
||||
Returns a tuple of (unoptimized, dataset, optimized dataset). The
|
||||
unoptimized dataset has the assertion that Batch follows Map. The optimized
|
||||
dataset has the assertion that Map follows Batch, and has the
|
||||
"map_vectorization" optimization applied.
|
||||
|
||||
Args:
|
||||
base_dataset: Input dataset to map->batch
|
||||
map_fn: Map function to use
|
||||
num_parallel_calls: (Optional.) num_parallel_calls argument for map
|
||||
expect_optimized: (Optional.) Whether we expect the optimization to take
|
||||
place, in which case we will assert that Batch is followed by Map,
|
||||
otherwise Map followed by Batch. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tuple of (unoptimized dataset, optimized dataset).
|
||||
"""
|
||||
map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
|
||||
batch_size = 100
|
||||
|
||||
def _make_dataset(node_names):
|
||||
return base_dataset.apply(optimization.assert_next(node_names)).map(
|
||||
map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
|
||||
|
||||
unoptimized = _make_dataset([map_node_name, "Batch"])
|
||||
optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
|
||||
[map_node_name, "Batch"]).apply(
|
||||
optimization.optimize(["map_vectorization"]))
|
||||
|
||||
return unoptimized, optimized
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Basic", lambda x: (x, x + 1), None),
|
||||
("Parallel", lambda x: (x, x + 1), 12),
|
||||
("Gather", lambda x: array_ops.gather(x, 0), 12),
|
||||
)
|
||||
def testOptimization(self, map_fn, num_parallel_calls):
|
||||
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
|
||||
[3, 4]]).repeat(5)
|
||||
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
|
||||
num_parallel_calls)
|
||||
self._assert_datasets_equal(unoptimized, optimized)
|
||||
|
||||
def testOptimizationBadMapFn(self):
|
||||
# Test map functions that give an error
|
||||
def map_fn(x):
|
||||
# x has leading dimension 5, this will raise an error
|
||||
return array_ops.gather(x, 10)
|
||||
|
||||
base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
|
||||
5, drop_remainder=True)
|
||||
_, optimized = self._get_test_datasets(base_dataset, map_fn)
|
||||
nxt = optimized.make_one_shot_iterator().get_next()
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
r"indices = 10 is not in \[0, 5\)"):
|
||||
self.evaluate(nxt)
|
||||
|
||||
def testOptimizationWithCapturedInputs(self):
|
||||
# Tests that vectorization works with captured inputs
|
||||
def map_fn(x):
|
||||
return x + y
|
||||
|
||||
y = constant_op.constant(1, shape=(2,))
|
||||
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
|
||||
[3, 4]]).repeat(5)
|
||||
# TODO(rachelim): when this optimization works, turn on expect_optimized
|
||||
unoptimized, optimized = self._get_test_datasets(
|
||||
base_dataset, map_fn, expect_optimized=False)
|
||||
self._assert_datasets_equal(optimized, unoptimized)
|
||||
|
||||
def testOptimizationIgnoreStateful(self):
|
||||
|
||||
def map_fn(x):
|
||||
with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
|
||||
return array_ops.identity(x)
|
||||
|
||||
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
|
||||
[3, 4]]).repeat(5)
|
||||
_, optimized = self._get_test_datasets(
|
||||
base_dataset, map_fn, expect_optimized=False)
|
||||
nxt = optimized.make_one_shot_iterator().get_next()
|
||||
|
||||
# NOTE: Right now, it raises an error because we can't save datasets that
|
||||
# are stateful, and we rely on this saving mechanism to optimize datasets,
|
||||
# so stateful functions can't be optimized.
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
|
||||
self.evaluate(nxt)
|
||||
|
||||
def testOptimizationIgnoreRagged(self):
|
||||
# Make sure we ignore inputs that might not be uniformly sized
|
||||
def map_fn(x):
|
||||
return array_ops.gather(x, 0)
|
||||
|
||||
# output_shape = (?,)
|
||||
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
|
||||
unoptimized, optimized = self._get_test_datasets(
|
||||
base_dataset, map_fn, expect_optimized=False)
|
||||
self._assert_datasets_equal(unoptimized, optimized)
|
||||
|
||||
def testOptimizationIgnoreRaggedMap(self):
|
||||
# Don't optimize when the output of the map fn shapes are unknown.
|
||||
def map_fn(x):
|
||||
return array_ops.tile(x, x)
|
||||
|
||||
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
|
||||
unoptimized, optimized = self._get_test_datasets(
|
||||
base_dataset, map_fn, expect_optimized=False)
|
||||
self._assert_datasets_raise_same_error(unoptimized, optimized,
|
||||
errors.InvalidArgumentError)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
60
tensorflow/contrib/data/python/kernel_tests/test_utils.py
Normal file
60
tensorflow/contrib/data/python/kernel_tests/test_utils.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test utilities for tf.data functionality."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class DatasetTestBase(test.TestCase):
|
||||
"""Base class for dataset tests."""
|
||||
|
||||
def _assert_datasets_equal(self, dataset1, dataset2):
|
||||
# TODO(rachelim): support sparse tensor outputs
|
||||
next1 = dataset1.make_one_shot_iterator().get_next()
|
||||
next2 = dataset2.make_one_shot_iterator().get_next()
|
||||
with self.test_session() as sess:
|
||||
while True:
|
||||
try:
|
||||
op1 = sess.run(next1)
|
||||
except errors.OutOfRangeError:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next2)
|
||||
break
|
||||
op2 = sess.run(next2)
|
||||
|
||||
op1 = nest.flatten(op1)
|
||||
op2 = nest.flatten(op2)
|
||||
assert len(op1) == len(op2)
|
||||
for i in range(len(op1)):
|
||||
self.assertAllEqual(op1[i], op2[i])
|
||||
|
||||
def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
|
||||
next1 = dataset1.make_one_shot_iterator().get_next()
|
||||
next2 = dataset2.make_one_shot_iterator().get_next()
|
||||
with self.test_session() as sess:
|
||||
try:
|
||||
sess.run(next1)
|
||||
raise ValueError(
|
||||
"Expected dataset to raise an error of type %s, but it did not." %
|
||||
repr(exc_class))
|
||||
except exc_class as e:
|
||||
# Check that the first segment of the error messages are the same.
|
||||
with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
|
||||
sess.run(next2)
|
@ -124,6 +124,43 @@ cc_library(
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_vectorization",
|
||||
srcs = ["map_vectorization.cc"],
|
||||
hdrs = [
|
||||
"map_vectorization.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:lib_internal",
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "map_vectorization_test",
|
||||
srcs = ["map_vectorization_test.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
":map_vectorization",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work.
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_and_batch_fusion",
|
||||
srcs = ["map_and_batch_fusion.cc"],
|
||||
@ -311,6 +348,7 @@ cc_library(
|
||||
":map_and_batch_fusion",
|
||||
":map_and_filter_fusion",
|
||||
":map_fusion",
|
||||
":map_vectorization",
|
||||
":noop_elimination",
|
||||
":shuffle_and_repeat_fusion",
|
||||
],
|
||||
|
@ -108,6 +108,26 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
|
||||
return graph->AddNode(std::move(node));
|
||||
}
|
||||
|
||||
NodeDef* AddNode(StringPiece name, StringPiece op,
|
||||
const std::vector<string>& inputs,
|
||||
const std::vector<std::pair<string, AttrValue>>& attributes,
|
||||
FunctionDef* fd) {
|
||||
NodeDef* node = fd->add_node_def();
|
||||
if (!name.empty()) {
|
||||
node->set_name(name.ToString());
|
||||
} else {
|
||||
SetUniqueFunctionNodeName(op, fd, node);
|
||||
}
|
||||
node->set_op(op.ToString());
|
||||
for (const string& input : inputs) {
|
||||
node->add_input(input);
|
||||
}
|
||||
for (auto attr : attributes) {
|
||||
(*node->mutable_attr())[attr.first] = attr.second;
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
template <>
|
||||
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
|
||||
return AddScalarConstNodeHelper(
|
||||
@ -181,7 +201,7 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
|
||||
}
|
||||
|
||||
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
|
||||
return FindNodeWithOp(op, graph) != -1;
|
||||
return FindGraphNodeWithOp(op, graph) != -1;
|
||||
}
|
||||
|
||||
bool ContainsGraphFunctionWithName(StringPiece name,
|
||||
@ -205,7 +225,7 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
|
||||
return indices.empty() ? -1 : indices.front();
|
||||
}
|
||||
|
||||
int FindNodeWithOp(StringPiece op, const GraphDef& graph) {
|
||||
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
|
||||
std::vector<int> indices = GetElementIndicesWithPredicate(
|
||||
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
|
||||
return indices.empty() ? -1 : indices.front();
|
||||
@ -242,6 +262,12 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
|
||||
return indices.empty() ? -1 : indices.front();
|
||||
}
|
||||
|
||||
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
|
||||
if (node.input_size() == 0) return nullptr;
|
||||
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
|
||||
return graph.GetRegularFanin(input_port).node;
|
||||
}
|
||||
|
||||
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
|
||||
NodeDef* node) {
|
||||
string name = prefix.ToString();
|
||||
|
@ -37,6 +37,12 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
|
||||
const std::vector<std::pair<string, AttrValue>>& attributes,
|
||||
MutableGraphView* graph);
|
||||
|
||||
// Adds a node to a FunctionDef.
|
||||
NodeDef* AddNode(StringPiece name, StringPiece op,
|
||||
const std::vector<string>& inputs,
|
||||
const std::vector<std::pair<string, AttrValue>>& attributes,
|
||||
FunctionDef* fd);
|
||||
|
||||
// Adds a Const node with the given value to the graph.
|
||||
template <typename T>
|
||||
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
|
||||
@ -99,7 +105,10 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
|
||||
|
||||
// Returns the index of the first node with the given op or -1 if no such node
|
||||
// exists.
|
||||
int FindNodeWithOp(StringPiece op, const GraphDef& graph);
|
||||
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
|
||||
|
||||
// Gets the 0th input to a node in the graph.
|
||||
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
|
||||
|
||||
// Returns the list of indices of all nodes with the given op or empty list if
|
||||
// no such node exists.
|
||||
|
@ -176,25 +176,25 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
|
||||
FindGraphFunctionWithName(new_function->signature().name(), library), -1);
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, FindNodeWithOp) {
|
||||
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
AddNode("B", "OpB", {"A"}, {}, &graph);
|
||||
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
||||
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
|
||||
|
||||
graph.DeleteNodes({"B"});
|
||||
EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
AddNode("B", "OpB", {"A"}, {}, &graph);
|
||||
@ -251,6 +251,54 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
|
||||
other_function->signature().name());
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, AddNodeToFunctionDef) {
|
||||
FunctionDef func;
|
||||
const char* op_name = "xxx";
|
||||
AddNode(op_name, op_name, {}, {}, &func);
|
||||
|
||||
const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
|
||||
EXPECT_EQ(node1.op(), op_name);
|
||||
EXPECT_EQ(node1.input_size(), 0);
|
||||
EXPECT_EQ(node1.attr_size(), 0);
|
||||
|
||||
const std::vector<string> inputs({"input1", "input2"});
|
||||
AddNode("", op_name, inputs, {}, &func);
|
||||
const NodeDef& node2 =
|
||||
func.node_def(FindFunctionNodeWithName("xxx/_2", func));
|
||||
EXPECT_EQ(node2.op(), op_name);
|
||||
EXPECT_EQ(node2.attr_size(), 0);
|
||||
EXPECT_EQ(node2.input_size(), inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
EXPECT_EQ(node2.input(i), inputs[i]);
|
||||
}
|
||||
|
||||
AttrValue a1, a2;
|
||||
a1.set_type(DT_INT32);
|
||||
a2.set_type(DT_INT64);
|
||||
const std::vector<std::pair<string, AttrValue>> attrs(
|
||||
{{"attr1", a1}, {"attr2", a2}});
|
||||
AddNode("", op_name, {}, attrs, &func);
|
||||
const NodeDef& node3 =
|
||||
func.node_def(FindFunctionNodeWithName("xxx/_3", func));
|
||||
EXPECT_EQ(node3.op(), op_name);
|
||||
EXPECT_EQ(node3.input_size(), 0);
|
||||
EXPECT_EQ(node3.attr_size(), attrs.size());
|
||||
for (size_t i = 0; i < attrs.size(); ++i) {
|
||||
EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, GetInputNode) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
|
||||
NodeDef* node2 = AddNode("", "A", {node1->name()}, {}, &graph);
|
||||
|
||||
EXPECT_EQ(GetInputNode(*node2, graph), node1);
|
||||
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace graph_utils
|
||||
} // namespace grappler
|
||||
|
@ -85,8 +85,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
NodeDef map_and_batch_node =
|
||||
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
NodeDef map_and_batch_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
EXPECT_EQ(map_and_batch_node.input_size(), 5);
|
||||
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
|
||||
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
|
||||
@ -170,8 +170,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
NodeDef map_and_batch_node =
|
||||
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
NodeDef map_and_batch_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
EXPECT_EQ(map_and_batch_node.input_size(), 5);
|
||||
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
|
||||
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
|
||||
@ -253,8 +253,8 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
NodeDef map_and_batch_node =
|
||||
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
NodeDef map_and_batch_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
|
||||
EXPECT_EQ(map_and_batch_node.input_size(), 5);
|
||||
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
|
||||
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
|
||||
|
@ -101,18 +101,18 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) {
|
||||
graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output));
|
||||
ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));
|
||||
|
||||
int map_id = graph_utils::FindNodeWithOp("MapDataset", output);
|
||||
int map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
|
||||
auto& map_node = output.node(map_id);
|
||||
ASSERT_EQ(map_node.input_size(), 1);
|
||||
EXPECT_EQ(map_node.input(0), "range");
|
||||
|
||||
int filter_by_component_id =
|
||||
graph_utils::FindNodeWithOp("FilterByLastComponentDataset", output);
|
||||
graph_utils::FindGraphNodeWithOp("FilterByLastComponentDataset", output);
|
||||
auto& filter_by_component = output.node(filter_by_component_id);
|
||||
ASSERT_EQ(filter_by_component.input_size(), 1);
|
||||
EXPECT_EQ(filter_by_component.input(0), map_node.name());
|
||||
|
||||
int cache_id = graph_utils::FindNodeWithOp("CacheDataset", output);
|
||||
int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
|
||||
auto& cache_node = output.node(cache_id);
|
||||
ASSERT_EQ(cache_node.input_size(), 2);
|
||||
EXPECT_EQ(cache_node.input(0), filter_by_component.name());
|
||||
|
257
tensorflow/core/grappler/optimizers/data/map_vectorization.cc
Normal file
257
tensorflow/core/grappler/optimizers/data/map_vectorization.cc
Normal file
@ -0,0 +1,257 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
|
||||
(*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
|
||||
}
|
||||
|
||||
FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
|
||||
const FunctionDef& orig_func,
|
||||
FunctionDefLibrary* library) {
|
||||
// If we decide to use a different method of vectorization, we can just
|
||||
// swap out this part.
|
||||
FunctionDef* vectorized_func = library->add_function();
|
||||
// Function inputs and outputs are the same as original, just
|
||||
// with different shapes.
|
||||
*vectorized_func->mutable_signature() = orig_func.signature();
|
||||
graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
|
||||
vectorized_func);
|
||||
|
||||
// Add MapDefun node
|
||||
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
|
||||
map_defun_node->set_op("MapDefun");
|
||||
graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
|
||||
map_defun_node);
|
||||
|
||||
// Set attrs and inputs
|
||||
for (const string& k : {"f", "output_types", "output_shapes"}) {
|
||||
// Function, output types and (unbatched) shapes are the same as the
|
||||
// original map node.
|
||||
CopyAttribute(k, map_node, map_defun_node);
|
||||
}
|
||||
|
||||
// Get types of input arguments from original map function
|
||||
AttrValue t_args;
|
||||
for (const auto& input : vectorized_func->signature().input_arg()) {
|
||||
t_args.mutable_list()->add_type(input.type());
|
||||
map_defun_node->add_input(input.name());
|
||||
}
|
||||
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
|
||||
|
||||
// Set return values to match output names
|
||||
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
|
||||
for (size_t i = 0; i < vectorized_func->signature().output_arg_size(); ++i) {
|
||||
const auto& output_arg = vectorized_func->signature().output_arg(i);
|
||||
(*vectorized_func->mutable_ret())[output_arg.name()] =
|
||||
strings::StrCat(output_prefix, i);
|
||||
}
|
||||
|
||||
return vectorized_func;
|
||||
}
|
||||
|
||||
bool IsOutputShapesFullyDefined(const NodeDef& node) {
|
||||
auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
|
||||
if (shapes_attr == nullptr) return false;
|
||||
const auto& shapes = shapes_attr->list().shape();
|
||||
|
||||
for (const TensorShapeProto& shape : shapes) {
|
||||
for (const auto& dim : shape.dim()) {
|
||||
if (dim.size() == -1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsStatefulFn(const FunctionLibraryDefinition& library,
|
||||
const FunctionDef& function_def) {
|
||||
for (const NodeDef& node_def : function_def.node_def()) {
|
||||
const OpDef* op_def;
|
||||
Status s = library.LookUpOpDef(node_def.op(), &op_def);
|
||||
if (!s.ok() || op_def->is_stateful()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HasCapturedInputs(const NodeDef& map_node) {
|
||||
return map_node.attr().at("Targuments").list().type_size() > 0;
|
||||
}
|
||||
|
||||
NodeDef make_new_batch_node(const NodeDef& old_batch_node,
|
||||
const NodeDef& input_node,
|
||||
const FunctionDef& vectorized_func,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef batch_node;
|
||||
batch_node.set_op(old_batch_node.op());
|
||||
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
|
||||
&batch_node);
|
||||
|
||||
// Set the `input_dataset` input argument
|
||||
batch_node.add_input(input_node.name());
|
||||
// Set the `batch_size` input_argument
|
||||
batch_node.add_input(old_batch_node.input(1));
|
||||
if (batch_node.op() == "BatchDatasetV2") {
|
||||
// Set the `drop_remainder` input argument
|
||||
batch_node.add_input(old_batch_node.input(2));
|
||||
}
|
||||
|
||||
// Set attrs
|
||||
AttrValue output_types;
|
||||
for (const auto& input : vectorized_func.signature().input_arg()) {
|
||||
output_types.mutable_list()->add_type(input.type());
|
||||
}
|
||||
(*batch_node.mutable_attr())["output_types"] = output_types;
|
||||
|
||||
auto& output_shapes_attr = (*batch_node.mutable_attr())["output_shapes"];
|
||||
const auto& input_shapes =
|
||||
input_node.attr().at("output_shapes").list().shape();
|
||||
int64 batch_size =
|
||||
old_batch_node.attr().at("output_shapes").list().shape()[0].dim(0).size();
|
||||
for (size_t i = 0; i < input_shapes.size(); ++i) {
|
||||
TensorShapeProto* shape = output_shapes_attr.mutable_list()->add_shape();
|
||||
TensorShapeProto_Dim* dim = shape->add_dim();
|
||||
dim->set_size(batch_size);
|
||||
shape->MergeFrom(input_shapes.Get(i));
|
||||
}
|
||||
return batch_node;
|
||||
}
|
||||
|
||||
NodeDef make_new_map_node(const NodeDef& old_map_node,
|
||||
const NodeDef& old_batch_node,
|
||||
const NodeDef& new_batch_node,
|
||||
const FunctionDef& vectorized_func,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef map_node;
|
||||
map_node.set_op(old_map_node.op());
|
||||
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
|
||||
&map_node);
|
||||
|
||||
// Set the `input_dataset` input argument
|
||||
map_node.add_input(new_batch_node.name());
|
||||
for (int i = 1; i < old_map_node.input_size(); i++) {
|
||||
// Set the `other_arguments` and `num_parallel_calls` input arguments
|
||||
map_node.add_input(old_map_node.input(i));
|
||||
}
|
||||
|
||||
// Set attrs
|
||||
CopyAttribute("Targuments", old_map_node, &map_node);
|
||||
auto& func_attr = (*map_node.mutable_attr())["f"];
|
||||
func_attr.mutable_func()->set_name(vectorized_func.signature().name());
|
||||
|
||||
for (auto key : {"output_shapes", "output_types"}) {
|
||||
CopyAttribute(key, old_batch_node, &map_node);
|
||||
}
|
||||
return map_node;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) {
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
std::set<string> nodes_to_delete;
|
||||
|
||||
for (const NodeDef& node : item.graph.node()) {
|
||||
// Find Map->Batch nodes.
|
||||
// TODO(rachelim): Optimize MapAndBatchDataset[V2] as well.
|
||||
if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const NodeDef& batch_node(node);
|
||||
NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
|
||||
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Use a more descriptive variable name now that we know the node type.
|
||||
NodeDef* map_node = node2;
|
||||
// Input to the map node
|
||||
NodeDef* input_node = graph_utils::GetInputNode(*map_node, graph);
|
||||
CHECK_NOTNULL(input_node);
|
||||
|
||||
FunctionDefLibrary* library = output->mutable_library();
|
||||
|
||||
FunctionLibraryDefinition function_library(OpRegistry::Global(), *library);
|
||||
const FunctionDef* orig_func =
|
||||
function_library.Find(map_node->attr().at("f").func().name());
|
||||
|
||||
// Check that this is a valid optimization.
|
||||
if (!IsOutputShapesFullyDefined(*input_node) ||
|
||||
!IsOutputShapesFullyDefined(*map_node) ||
|
||||
IsStatefulFn(function_library, *orig_func) ||
|
||||
HasCapturedInputs(*map_node)) {
|
||||
// 1. If any of the inputs have an unknown shape, don't optimize, since
|
||||
// inputs might not be batchable.
|
||||
// 2. If any of the map func outputs have an unknown shape, don't
|
||||
// optimize, so that batching errors surface as before.
|
||||
// 3. If the function is stateful, don't vectorize it.
|
||||
// 4. TODO(rachelim): Make this work for MapDataset with captured inputs
|
||||
// by tiling inputs or modifying the signature of MapDefun.
|
||||
continue;
|
||||
}
|
||||
|
||||
FunctionDef* vectorized_func =
|
||||
AddVectorizedFunction(*map_node, *orig_func, library);
|
||||
CHECK_NOTNULL(vectorized_func);
|
||||
|
||||
auto* new_batch_node = graph.AddNode(
|
||||
make_new_batch_node(batch_node, *input_node, *vectorized_func, &graph));
|
||||
|
||||
auto* new_map_node = graph.AddNode(make_new_map_node(
|
||||
*map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
|
||||
graph.ReplaceInput(batch_node, *new_map_node);
|
||||
|
||||
// Mark the `Map` and `Batch` nodes for removal.
|
||||
nodes_to_delete.insert(map_node->name());
|
||||
nodes_to_delete.insert(batch_node.name());
|
||||
}
|
||||
graph.DeleteNodes(nodes_to_delete);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MapVectorization::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output,
|
||||
double result) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(MapVectorization, "map_vectorization");
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
46
tensorflow/core/grappler/optimizers/data/map_vectorization.h
Normal file
46
tensorflow/core/grappler/optimizers/data/map_vectorization.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class MapVectorization : public CustomGraphOptimizer {
|
||||
public:
|
||||
MapVectorization() = default;
|
||||
~MapVectorization() override = default;
|
||||
|
||||
string name() const override { return "map_vectorization"; };
|
||||
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
|
@ -0,0 +1,201 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
using test::function::GDef;
|
||||
using test::function::NDef;
|
||||
|
||||
void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
|
||||
TensorShapeProto* t) {
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
auto* d = t->add_dim();
|
||||
d->set_size(dims[i]);
|
||||
}
|
||||
}
|
||||
|
||||
AttrValue MakeShapeListAttr(
|
||||
const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
|
||||
AttrValue shapes_attr;
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
MakeTensorShapeProtoHelper(shapes[i],
|
||||
shapes_attr.mutable_list()->add_shape());
|
||||
}
|
||||
|
||||
return shapes_attr;
|
||||
}
|
||||
|
||||
NodeDef MakeMapNodeHelper(
|
||||
StringPiece name, StringPiece input_node_name, StringPiece function_name,
|
||||
StringPiece map_op_name,
|
||||
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
|
||||
const gtl::ArraySlice<DataType>& output_types) {
|
||||
return test::function::NDef(
|
||||
name, map_op_name, {input_node_name.ToString()},
|
||||
{{"f", FunctionDefHelper::FunctionRef(function_name.ToString())},
|
||||
{"Targuments", {}},
|
||||
{"output_shapes", MakeShapeListAttr(output_shapes)},
|
||||
{"output_types", output_types}});
|
||||
}
|
||||
|
||||
NodeDef MakeMapNode(
|
||||
StringPiece name, StringPiece input_node_name, StringPiece function_name,
|
||||
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
|
||||
const gtl::ArraySlice<DataType>& output_types) {
|
||||
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
|
||||
output_shapes, output_types);
|
||||
}
|
||||
|
||||
NodeDef MakeBatchNode(
|
||||
StringPiece name, StringPiece input_node_name,
|
||||
StringPiece input_batch_size_name,
|
||||
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
|
||||
const gtl::ArraySlice<DataType>& output_types) {
|
||||
return NDef(name, "BatchDataset",
|
||||
{input_node_name.ToString(), input_batch_size_name.ToString()},
|
||||
{{"output_types", output_types},
|
||||
{"output_shapes", MakeShapeListAttr(output_shapes)}});
|
||||
}
|
||||
|
||||
NodeDef MakeBatchV2Node(
|
||||
StringPiece name, StringPiece input_node_name,
|
||||
StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
|
||||
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
|
||||
const gtl::ArraySlice<DataType>& output_types) {
|
||||
return NDef(name, "BatchDatasetV2",
|
||||
{input_node_name.ToString(), input_batch_size_name.ToString(),
|
||||
input_drop_remainder_name.ToString()},
|
||||
{{"output_types", output_types},
|
||||
{"output_shapes", MakeShapeListAttr(output_shapes)}});
|
||||
}
|
||||
|
||||
NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
|
||||
return NDef(name, "RangeDataset", inputs,
|
||||
{{"output_shapes", MakeShapeListAttr({{}})},
|
||||
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
|
||||
}
|
||||
|
||||
TEST(MapVectorizationTest, VectorizeMapWithBatch) {
|
||||
GrapplerItem item;
|
||||
item.graph = GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
MakeRangeNode("range", {"start", "stop", "step"}),
|
||||
MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
|
||||
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
MapVectorization optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
|
||||
1);
|
||||
EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
|
||||
1);
|
||||
const NodeDef& map_node =
|
||||
output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
|
||||
const NodeDef& batch_node =
|
||||
output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
|
||||
EXPECT_EQ(map_node.input(0), batch_node.name());
|
||||
EXPECT_EQ(batch_node.input(0), "range");
|
||||
}
|
||||
|
||||
TEST(MapVectorizationTest, VectorizeMapWithBatchV2) {
|
||||
GrapplerItem item;
|
||||
item.graph = GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("drop_remainder", "Const", {},
|
||||
{{"value", false}, {"dtype", DT_BOOL}}),
|
||||
MakeRangeNode("range", {"start", "stop", "step"}),
|
||||
MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
|
||||
MakeBatchV2Node("batch", "map", "batch_size", "drop_remainder", {{-1}},
|
||||
{DT_INT32})},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
MapVectorization optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
|
||||
1);
|
||||
EXPECT_EQ(
|
||||
graph_utils::FindAllGraphNodesWithOp("BatchDatasetV2", output).size(), 1);
|
||||
const NodeDef& map_node =
|
||||
output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
|
||||
const NodeDef& batch_node =
|
||||
output.node(graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output));
|
||||
EXPECT_EQ(map_node.input(0), batch_node.name());
|
||||
EXPECT_EQ(batch_node.input(0), "range");
|
||||
}
|
||||
|
||||
TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShape) {
|
||||
GrapplerItem item;
|
||||
item.graph = GDef(
|
||||
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("input", "InputDataset", {},
|
||||
{{"output_types", gtl::ArraySlice<DataType>({DT_INT32})}}),
|
||||
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
|
||||
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
MapVectorization optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
}
|
||||
|
||||
TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
|
||||
GrapplerItem item;
|
||||
item.graph = GDef(
|
||||
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("input", "InputDataset", {},
|
||||
{{"output_shapes", MakeShapeListAttr({{}})}}),
|
||||
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
|
||||
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
MapVectorization optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -78,7 +78,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
EXPECT_TRUE(
|
||||
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
|
||||
NodeDef shuffle_and_repeat_node = output.node(
|
||||
graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
|
||||
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDataset", output));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||
|
@ -74,7 +74,11 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, batch_size == ctx->input(i).dim_size(0),
|
||||
errors::InvalidArgument("All inputs must have the same dimension 0."),
|
||||
errors::InvalidArgument(
|
||||
"All inputs must have the same dimension 0. Input ", i,
|
||||
" has leading dimension ", ctx->input(i).dim_size(0),
|
||||
", while all previous inputs have leading dimension ", batch_size,
|
||||
"."),
|
||||
done);
|
||||
}
|
||||
|
||||
|
@ -71,6 +71,7 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
|
||||
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
|
||||
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
|
||||
"//tensorflow/contrib/data/python/kernel_tests:test_utils",
|
||||
"//tensorflow/contrib/data/python/ops:contrib_op_loader",
|
||||
"//tensorflow/contrib/eager/python/examples:examples_pip",
|
||||
"//tensorflow/contrib/eager/python:evaluator",
|
||||
|
Loading…
Reference in New Issue
Block a user