[tf.data] Add an optimization that vectorizes map functions and swaps the order of Map->Batch dataset transformations to Batch->Map
PiperOrigin-RevId: 209674669
This commit is contained in: parent 7989b2bc99 · commit 62fcb03449
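
At a high level, this change adds a "map_vectorization" Grappler optimization for tf.data pipelines: when a MapDataset feeds directly into a BatchDataset, the pass swaps the two and points the new map at a vectorized version of the original function. A minimal user-level sketch of the intent, using a hypothetical element-wise map function (this illustrates the transformation, not the optimizer's actual code path):

import tensorflow as tf

def fn(x):
  return x * 2 + 3  # hypothetical element-wise map function

ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])

# As written by the user: fn runs once per element, then results are batched.
before = ds.map(fn).batch(2)

# After the rewrite (conceptually): elements are batched first and the
# function runs on whole batches, amortizing per-element overhead. For an
# element-wise fn both pipelines produce the same values.
after = ds.batch(2).map(fn)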
@@ -230,12 +230,15 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":stats_dataset_test_base",
+        ":test_utils",
         "//tensorflow/contrib/data/python/ops:optimization",
         "//tensorflow/contrib/data/python/ops:stats_ops",
+        "//tensorflow/python:check_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
+        "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/data/ops:dataset_ops",
         "@absl_py//absl/testing:parameterized",
@@ -548,3 +551,13 @@ py_test(
         "//tensorflow/python/data/ops:readers",
     ],
 )
+
+py_library(
+    name = "test_utils",
+    srcs = ["test_utils.py"],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/util:nest",
+    ],
+)
@@ -31,47 +31,57 @@ from tensorflow.python.platform import test


 class MapDefunTest(test.TestCase):

-  def testMapDefun_Simple(self):
+  def testMapDefunSimple(self):

     @function.Defun(dtypes.int32)
     def simple_fn(x):
       return x * 2 + 3

-    with self.test_session():
-      nums = [[1, 2], [3, 4], [5, 6]]
-      elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
-      r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
-      expected = elems * 2 + 3
-      self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+    nums = [[1, 2], [3, 4], [5, 6]]
+    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+    r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
+    expected = elems * 2 + 3
+    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))

-  def testMapDefun_MismatchedTypes(self):
+  def testMapDefunMismatchedTypes(self):

     @function.Defun(dtypes.int32)
     def fn(x):
       return math_ops.cast(x, dtypes.float64)

-    with self.test_session():
-      nums = [1, 2, 3, 4, 5, 6]
-      elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
-      r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
-      with self.assertRaises(errors.InvalidArgumentError):
-        self.evaluate(r)
+    nums = [1, 2, 3, 4, 5, 6]
+    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+    r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(r)

-  def testMapDefun_MultipleOutputs(self):
+  def testMapDefunReduceDim(self):
+    # Tests where the output has a different rank from the input
+
+    @function.Defun(dtypes.int32)
+    def fn(x):
+      return array_ops.gather(x, 0)
+
+    nums = [[1, 2], [3, 4], [5, 6]]
+    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+    r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+    expected = constant_op.constant([1, 3, 5])
+    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+  def testMapDefunMultipleOutputs(self):

     @function.Defun(dtypes.int32)
     def fn(x):
       return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

-    with self.test_session():
-      nums = [[1, 2], [3, 4], [5, 6]]
-      elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
-      r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
-                                                                            (2,)])
-      expected = [elems, elems * 2 + 3]
-      self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+    nums = [[1, 2], [3, 4], [5, 6]]
+    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+    r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64],
+                            [(2,), (2,)])
+    expected = [elems, elems * 2 + 3]
+    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))

-  def testMapDefun_ShapeInference(self):
+  def testMapDefunShapeInference(self):

     @function.Defun(dtypes.int32)
     def fn(x):
@@ -82,7 +92,7 @@ class MapDefunTest(test.TestCase):
     result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
     self.assertEqual(result.get_shape(), (3, 2))

-  def testMapDefun_PartialShapeInference(self):
+  def testMapDefunPartialShapeInference(self):

     @function.Defun(dtypes.int32)
     def fn(x):
@@ -92,7 +102,7 @@ class MapDefunTest(test.TestCase):
     result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
     self.assertEqual(result[0].get_shape().as_list(), [None, 2])

-  def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
+  def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):

     @function.Defun(dtypes.int32, dtypes.int32)
     def fn(x, y):
@@ -108,7 +118,7 @@ class MapDefunTest(test.TestCase):
         "All inputs must have the same dimension 0."):
       sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})

-  def testMapDefun_RaisesDefunError(self):
+  def testMapDefunRaisesDefunError(self):

     @function.Defun(dtypes.int32)
     def fn(x):
@@ -117,9 +127,8 @@

     elems = constant_op.constant([0, 0, 0, 37, 0])
     result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
-    with self.test_session():
-      with self.assertRaises(errors.InvalidArgumentError):
-        self.evaluate(result)
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(result)


 if __name__ == "__main__":
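
For readers unfamiliar with the op under test: map_defun applies a Defun-wrapped function independently to each slice along dimension 0 of its inputs, much like tf.map_fn. A rough sketch of that equivalence, using tf.map_fn purely for illustration rather than the MapDefun op itself:

import tensorflow as tf

def simple_fn(x):
  return x * 2 + 3

elems = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)

# map_defun.map_defun(simple_fn, [elems], [tf.int32], [(2,)])[0] is expected
# to yield the same values as mapping the function over dimension 0:
result = tf.map_fn(simple_fn, elems)  # [[5, 7], [9, 11], [13, 15]]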
@@ -20,12 +20,16 @@ from __future__ import print_function
 from absl.testing import parameterized

 from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
+from tensorflow.contrib.data.python.kernel_tests import test_utils
 from tensorflow.contrib.data.python.ops import optimization
 from tensorflow.contrib.data.python.ops import stats_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test

@@ -277,5 +281,124 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
         "record_latency_PrefetchDataset/_6", 1)


+class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+
+  def _get_test_datasets(self,
+                         base_dataset,
+                         map_fn,
+                         num_parallel_calls=None,
+                         expect_optimized=True):
+    """Given base dataset and map fn, creates test datasets.
+
+    Returns a tuple of (unoptimized, dataset, optimized dataset). The
+    unoptimized dataset has the assertion that Batch follows Map. The optimized
+    dataset has the assertion that Map follows Batch, and has the
+    "map_vectorization" optimization applied.
+
+    Args:
+      base_dataset: Input dataset to map->batch
+      map_fn: Map function to use
+      num_parallel_calls: (Optional.) num_parallel_calls argument for map
+      expect_optimized: (Optional.) Whether we expect the optimization to take
+        place, in which case we will assert that Batch is followed by Map,
+        otherwise Map followed by Batch. Defaults to True.
+
+    Returns:
+      Tuple of (unoptimized dataset, optimized dataset).
+    """
+    map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
+    batch_size = 100
+
+    def _make_dataset(node_names):
+      return base_dataset.apply(optimization.assert_next(node_names)).map(
+          map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
+
+    unoptimized = _make_dataset([map_node_name, "Batch"])
+    optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
+                              [map_node_name, "Batch"]).apply(
+                                  optimization.optimize(["map_vectorization"]))
+
+    return unoptimized, optimized
+
+  @parameterized.named_parameters(
+      ("Basic", lambda x: (x, x + 1), None),
+      ("Parallel", lambda x: (x, x + 1), 12),
+      ("Gather", lambda x: array_ops.gather(x, 0), 12),
+  )
+  def testOptimization(self, map_fn, num_parallel_calls):
+    base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+                                                           [3, 4]]).repeat(5)
+    unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+                                                     num_parallel_calls)
+    self._assert_datasets_equal(unoptimized, optimized)
+
+  def testOptimizationBadMapFn(self):
+    # Test map functions that give an error
+    def map_fn(x):
+      # x has leading dimension 5, this will raise an error
+      return array_ops.gather(x, 10)
+
+    base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
+        5, drop_remainder=True)
+    _, optimized = self._get_test_datasets(base_dataset, map_fn)
+    nxt = optimized.make_one_shot_iterator().get_next()
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 r"indices = 10 is not in \[0, 5\)"):
+      self.evaluate(nxt)
+
+  def testOptimizationWithCapturedInputs(self):
+    # Tests that vectorization works with captured inputs
+    def map_fn(x):
+      return x + y
+
+    y = constant_op.constant(1, shape=(2,))
+    base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+                                                           [3, 4]]).repeat(5)
+    # TODO(rachelim): when this optimization works, turn on expect_optimized
+    unoptimized, optimized = self._get_test_datasets(
+        base_dataset, map_fn, expect_optimized=False)
+    self._assert_datasets_equal(optimized, unoptimized)
+
+  def testOptimizationIgnoreStateful(self):
+
+    def map_fn(x):
+      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+        return array_ops.identity(x)
+
+    base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+                                                           [3, 4]]).repeat(5)
+    _, optimized = self._get_test_datasets(
+        base_dataset, map_fn, expect_optimized=False)
+    nxt = optimized.make_one_shot_iterator().get_next()
+
+    # NOTE: Right now, it raises an error because we can't save datasets that
+    # are stateful, and we rely on this saving mechanism to optimize datasets,
+    # so stateful functions can't be optimized.
+    with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
+      self.evaluate(nxt)
+
+  def testOptimizationIgnoreRagged(self):
+    # Make sure we ignore inputs that might not be uniformly sized
+    def map_fn(x):
+      return array_ops.gather(x, 0)
+
+    # output_shape = (?,)
+    base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
+    unoptimized, optimized = self._get_test_datasets(
+        base_dataset, map_fn, expect_optimized=False)
+    self._assert_datasets_equal(unoptimized, optimized)
+
+  def testOptimizationIgnoreRaggedMap(self):
+    # Don't optimize when the output of the map fn shapes are unknown.
+    def map_fn(x):
+      return array_ops.tile(x, x)
+
+    base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
+    unoptimized, optimized = self._get_test_datasets(
+        base_dataset, map_fn, expect_optimized=False)
+    self._assert_datasets_raise_same_error(unoptimized, optimized,
+                                           errors.InvalidArgumentError)
+
+
 if __name__ == "__main__":
   test.main()
tensorflow/contrib/data/python/kernel_tests/test_utils.py (new file)
@@ -0,0 +1,60 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities for tf.data functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.platform import test


class DatasetTestBase(test.TestCase):
  """Base class for dataset tests."""

  def _assert_datasets_equal(self, dataset1, dataset2):
    # TODO(rachelim): support sparse tensor outputs
    next1 = dataset1.make_one_shot_iterator().get_next()
    next2 = dataset2.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      while True:
        try:
          op1 = sess.run(next1)
        except errors.OutOfRangeError:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next2)
          break
        op2 = sess.run(next2)

        op1 = nest.flatten(op1)
        op2 = nest.flatten(op2)
        assert len(op1) == len(op2)
        for i in range(len(op1)):
          self.assertAllEqual(op1[i], op2[i])

  def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
    next1 = dataset1.make_one_shot_iterator().get_next()
    next2 = dataset2.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      try:
        sess.run(next1)
        raise ValueError(
            "Expected dataset to raise an error of type %s, but it did not." %
            repr(exc_class))
      except exc_class as e:
        # Check that the first segment of the error messages are the same.
        with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
          sess.run(next2)
@@ -124,6 +124,43 @@ cc_library(
     ] + tf_protos_all(),
 )

+cc_library(
+    name = "map_vectorization",
+    srcs = ["map_vectorization.cc"],
+    hdrs = [
+        "map_vectorization.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:mutable_graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+        "//tensorflow/core:lib_internal",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "map_vectorization_test",
+    srcs = ["map_vectorization_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        ":map_vectorization",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/kernels:cast_op",  # Must be linked for the testlib functions to work.
+    ],
+)
+
 cc_library(
     name = "map_and_batch_fusion",
     srcs = ["map_and_batch_fusion.cc"],
@@ -311,6 +348,7 @@ cc_library(
         ":map_and_batch_fusion",
         ":map_and_filter_fusion",
         ":map_fusion",
+        ":map_vectorization",
         ":noop_elimination",
         ":shuffle_and_repeat_fusion",
     ],
@@ -108,6 +108,26 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
   return graph->AddNode(std::move(node));
 }

+NodeDef* AddNode(StringPiece name, StringPiece op,
+                 const std::vector<string>& inputs,
+                 const std::vector<std::pair<string, AttrValue>>& attributes,
+                 FunctionDef* fd) {
+  NodeDef* node = fd->add_node_def();
+  if (!name.empty()) {
+    node->set_name(name.ToString());
+  } else {
+    SetUniqueFunctionNodeName(op, fd, node);
+  }
+  node->set_op(op.ToString());
+  for (const string& input : inputs) {
+    node->add_input(input);
+  }
+  for (auto attr : attributes) {
+    (*node->mutable_attr())[attr.first] = attr.second;
+  }
+  return node;
+}
+
 template <>
 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
   return AddScalarConstNodeHelper(
@@ -181,7 +201,7 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
 }

 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
-  return FindNodeWithOp(op, graph) != -1;
+  return FindGraphNodeWithOp(op, graph) != -1;
 }

 bool ContainsGraphFunctionWithName(StringPiece name,
@@ -205,7 +225,7 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
   return indices.empty() ? -1 : indices.front();
 }

-int FindNodeWithOp(StringPiece op, const GraphDef& graph) {
+int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
   std::vector<int> indices = GetElementIndicesWithPredicate(
       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
   return indices.empty() ? -1 : indices.front();
@@ -242,6 +262,12 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
   return indices.empty() ? -1 : indices.front();
 }

+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
+  if (node.input_size() == 0) return nullptr;
+  GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
+  return graph.GetRegularFanin(input_port).node;
+}
+
 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
                             NodeDef* node) {
   string name = prefix.ToString();
@@ -37,6 +37,12 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
                  const std::vector<std::pair<string, AttrValue>>& attributes,
                  MutableGraphView* graph);

+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+                 const std::vector<string>& inputs,
+                 const std::vector<std::pair<string, AttrValue>>& attributes,
+                 FunctionDef* fd);
+
 // Adds a Const node with the given value to the graph.
 template <typename T>
 NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -99,7 +105,10 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);

 // Returns the index of the first node with the given op or -1 if no such node
 // exists.
-int FindNodeWithOp(StringPiece op, const GraphDef& graph);
+int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
+
+// Gets the 0th input to a node in the graph.
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);

 // Returns the list of indices of all nodes with the given op or empty list if
 // no such node exists.
@@ -176,25 +176,25 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
       FindGraphFunctionWithName(new_function->signature().name(), library), -1);
 }

-TEST(GraphUtilsTest, FindNodeWithOp) {
+TEST(GraphUtilsTest, FindGraphNodeWithOp) {
   GraphDef graph_def;
   MutableGraphView graph(&graph_def);
-  EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);

   AddNode("A", "OpA", {}, {}, &graph);
   AddNode("B", "OpB", {"A"}, {}, &graph);
   AddNode("A2", "OpA", {"B"}, {}, &graph);
-  EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0);
+  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);

   graph.DeleteNodes({"B"});
-  EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1);
+  EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
   EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
 }

 TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
   GraphDef graph_def;
   MutableGraphView graph(&graph_def);
-  EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);

   AddNode("A", "OpA", {}, {}, &graph);
   AddNode("B", "OpB", {"A"}, {}, &graph);
@@ -251,6 +251,54 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
       other_function->signature().name());
 }

+TEST(GraphUtilsTest, AddNodeToFunctionDef) {
+  FunctionDef func;
+  const char* op_name = "xxx";
+  AddNode(op_name, op_name, {}, {}, &func);
+
+  const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+  EXPECT_EQ(node1.op(), op_name);
+  EXPECT_EQ(node1.input_size(), 0);
+  EXPECT_EQ(node1.attr_size(), 0);
+
+  const std::vector<string> inputs({"input1", "input2"});
+  AddNode("", op_name, inputs, {}, &func);
+  const NodeDef& node2 =
+      func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+  EXPECT_EQ(node2.op(), op_name);
+  EXPECT_EQ(node2.attr_size(), 0);
+  EXPECT_EQ(node2.input_size(), inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    EXPECT_EQ(node2.input(i), inputs[i]);
+  }
+
+  AttrValue a1, a2;
+  a1.set_type(DT_INT32);
+  a2.set_type(DT_INT64);
+  const std::vector<std::pair<string, AttrValue>> attrs(
+      {{"attr1", a1}, {"attr2", a2}});
+  AddNode("", op_name, {}, attrs, &func);
+  const NodeDef& node3 =
+      func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+  EXPECT_EQ(node3.op(), op_name);
+  EXPECT_EQ(node3.input_size(), 0);
+  EXPECT_EQ(node3.attr_size(), attrs.size());
+  for (size_t i = 0; i < attrs.size(); ++i) {
+    EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+  }
+}
+
+TEST(GraphUtilsTest, GetInputNode) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+
+  NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
+  NodeDef* node2 = AddNode("", "A", {node1->name()}, {}, &graph);
+
+  EXPECT_EQ(GetInputNode(*node2, graph), node1);
+  EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
+}
+
 }  // namespace
 }  // namespace graph_utils
 }  // namespace grappler
@@ -85,8 +85,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
   EXPECT_FALSE(
       graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
-  NodeDef map_and_batch_node =
-      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node = output.node(
+      graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
   EXPECT_EQ(map_and_batch_node.input_size(), 5);
   EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
   EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -170,8 +170,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
   EXPECT_FALSE(
       graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
-  NodeDef map_and_batch_node =
-      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node = output.node(
+      graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
   EXPECT_EQ(map_and_batch_node.input_size(), 5);
   EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
   EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -253,8 +253,8 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
   EXPECT_FALSE(
       graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
-  NodeDef map_and_batch_node =
-      output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+  NodeDef map_and_batch_node = output.node(
+      graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
   EXPECT_EQ(map_and_batch_node.input_size(), 5);
   EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
   EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -101,18 +101,18 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) {
       graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output));
   ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));

-  int map_id = graph_utils::FindNodeWithOp("MapDataset", output);
+  int map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
   auto& map_node = output.node(map_id);
   ASSERT_EQ(map_node.input_size(), 1);
   EXPECT_EQ(map_node.input(0), "range");

   int filter_by_component_id =
-      graph_utils::FindNodeWithOp("FilterByLastComponentDataset", output);
+      graph_utils::FindGraphNodeWithOp("FilterByLastComponentDataset", output);
   auto& filter_by_component = output.node(filter_by_component_id);
   ASSERT_EQ(filter_by_component.input_size(), 1);
   EXPECT_EQ(filter_by_component.input(0), map_node.name());

-  int cache_id = graph_utils::FindNodeWithOp("CacheDataset", output);
+  int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
   auto& cache_node = output.node(cache_id);
   ASSERT_EQ(cache_node.input_size(), 2);
   EXPECT_EQ(cache_node.input(0), filter_by_component.name());
tensorflow/core/grappler/optimizers/data/map_vectorization.cc (new file)
@@ -0,0 +1,257 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace grappler {
namespace {

void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
  (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
}

FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
                                   const FunctionDef& orig_func,
                                   FunctionDefLibrary* library) {
  // If we decide to use a different method of vectorization, we can just
  // swap out this part.
  FunctionDef* vectorized_func = library->add_function();
  // Function inputs and outputs are the same as original, just
  // with different shapes.
  *vectorized_func->mutable_signature() = orig_func.signature();
  graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
                                          vectorized_func);

  // Add MapDefun node
  NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
  map_defun_node->set_op("MapDefun");
  graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
                                         map_defun_node);

  // Set attrs and inputs
  for (const string& k : {"f", "output_types", "output_shapes"}) {
    // Function, output types and (unbatched) shapes are the same as the
    // original map node.
    CopyAttribute(k, map_node, map_defun_node);
  }

  // Get types of input arguments from original map function
  AttrValue t_args;
  for (const auto& input : vectorized_func->signature().input_arg()) {
    t_args.mutable_list()->add_type(input.type());
    map_defun_node->add_input(input.name());
  }
  (*map_defun_node->mutable_attr())["Targuments"] = t_args;

  // Set return values to match output names
  string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
  for (size_t i = 0; i < vectorized_func->signature().output_arg_size(); ++i) {
    const auto& output_arg = vectorized_func->signature().output_arg(i);
    (*vectorized_func->mutable_ret())[output_arg.name()] =
        strings::StrCat(output_prefix, i);
  }

  return vectorized_func;
}

bool IsOutputShapesFullyDefined(const NodeDef& node) {
  auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
  if (shapes_attr == nullptr) return false;
  const auto& shapes = shapes_attr->list().shape();

  for (const TensorShapeProto& shape : shapes) {
    for (const auto& dim : shape.dim()) {
      if (dim.size() == -1) {
        return false;
      }
    }
  }
  return true;
}

bool IsStatefulFn(const FunctionLibraryDefinition& library,
                  const FunctionDef& function_def) {
  for (const NodeDef& node_def : function_def.node_def()) {
    const OpDef* op_def;
    Status s = library.LookUpOpDef(node_def.op(), &op_def);
    if (!s.ok() || op_def->is_stateful()) {
      return true;
    }
  }
  return false;
}

bool HasCapturedInputs(const NodeDef& map_node) {
  return map_node.attr().at("Targuments").list().type_size() > 0;
}

NodeDef make_new_batch_node(const NodeDef& old_batch_node,
                            const NodeDef& input_node,
                            const FunctionDef& vectorized_func,
                            MutableGraphView* graph) {
  NodeDef batch_node;
  batch_node.set_op(old_batch_node.op());
  graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
                                      &batch_node);

  // Set the `input_dataset` input argument
  batch_node.add_input(input_node.name());
  // Set the `batch_size` input_argument
  batch_node.add_input(old_batch_node.input(1));
  if (batch_node.op() == "BatchDatasetV2") {
    // Set the `drop_remainder` input argument
    batch_node.add_input(old_batch_node.input(2));
  }

  // Set attrs
  AttrValue output_types;
  for (const auto& input : vectorized_func.signature().input_arg()) {
    output_types.mutable_list()->add_type(input.type());
  }
  (*batch_node.mutable_attr())["output_types"] = output_types;

  auto& output_shapes_attr = (*batch_node.mutable_attr())["output_shapes"];
  const auto& input_shapes =
      input_node.attr().at("output_shapes").list().shape();
  int64 batch_size =
      old_batch_node.attr().at("output_shapes").list().shape()[0].dim(0).size();
  for (size_t i = 0; i < input_shapes.size(); ++i) {
    TensorShapeProto* shape = output_shapes_attr.mutable_list()->add_shape();
    TensorShapeProto_Dim* dim = shape->add_dim();
    dim->set_size(batch_size);
    shape->MergeFrom(input_shapes.Get(i));
  }
  return batch_node;
}

NodeDef make_new_map_node(const NodeDef& old_map_node,
                          const NodeDef& old_batch_node,
                          const NodeDef& new_batch_node,
                          const FunctionDef& vectorized_func,
                          MutableGraphView* graph) {
  NodeDef map_node;
  map_node.set_op(old_map_node.op());
  graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
                                      &map_node);

  // Set the `input_dataset` input argument
  map_node.add_input(new_batch_node.name());
  for (int i = 1; i < old_map_node.input_size(); i++) {
    // Set the `other_arguments` and `num_parallel_calls` input arguments
    map_node.add_input(old_map_node.input(i));
  }

  // Set attrs
  CopyAttribute("Targuments", old_map_node, &map_node);
  auto& func_attr = (*map_node.mutable_attr())["f"];
  func_attr.mutable_func()->set_name(vectorized_func.signature().name());

  for (auto key : {"output_shapes", "output_types"}) {
    CopyAttribute(key, old_batch_node, &map_node);
  }
  return map_node;
}

}  // namespace

Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* output) {
  *output = item.graph;
  MutableGraphView graph(output);
  std::set<string> nodes_to_delete;

  for (const NodeDef& node : item.graph.node()) {
    // Find Map->Batch nodes.
    // TODO(rachelim): Optimize MapAndBatchDataset[V2] as well.
    if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
      continue;
    }

    const NodeDef& batch_node(node);
    NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
    if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
      continue;
    }

    // Use a more descriptive variable name now that we know the node type.
    NodeDef* map_node = node2;
    // Input to the map node
    NodeDef* input_node = graph_utils::GetInputNode(*map_node, graph);
    CHECK_NOTNULL(input_node);

    FunctionDefLibrary* library = output->mutable_library();

    FunctionLibraryDefinition function_library(OpRegistry::Global(), *library);
    const FunctionDef* orig_func =
        function_library.Find(map_node->attr().at("f").func().name());

    // Check that this is a valid optimization.
    if (!IsOutputShapesFullyDefined(*input_node) ||
        !IsOutputShapesFullyDefined(*map_node) ||
        IsStatefulFn(function_library, *orig_func) ||
        HasCapturedInputs(*map_node)) {
      // 1. If any of the inputs have an unknown shape, don't optimize, since
      // inputs might not be batchable.
      // 2. If any of the map func outputs have an unknown shape, don't
      // optimize, so that batching errors surface as before.
      // 3. If the function is stateful, don't vectorize it.
      // 4. TODO(rachelim): Make this work for MapDataset with captured inputs
      // by tiling inputs or modifying the signature of MapDefun.
      continue;
    }

    FunctionDef* vectorized_func =
        AddVectorizedFunction(*map_node, *orig_func, library);
    CHECK_NOTNULL(vectorized_func);

    auto* new_batch_node = graph.AddNode(
        make_new_batch_node(batch_node, *input_node, *vectorized_func, &graph));

    auto* new_map_node = graph.AddNode(make_new_map_node(
        *map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
    graph.ReplaceInput(batch_node, *new_map_node);

    // Mark the `Map` and `Batch` nodes for removal.
    nodes_to_delete.insert(map_node->name());
    nodes_to_delete.insert(batch_node.name());
  }
  graph.DeleteNodes(nodes_to_delete);
  return Status::OK();
}

void MapVectorization::Feedback(Cluster* cluster, const GrapplerItem& item,
                                const GraphDef& optimize_output,
                                double result) {
  // no-op
}

REGISTER_GRAPH_OPTIMIZER_AS(MapVectorization, "map_vectorization");

}  // end namespace grappler
}  // end namespace tensorflow
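
To summarize what the pass above constructs: the rewritten graph batches first and then maps a "vectorized_function" that wraps the original map function in a MapDefun op, so each output element is still computed exactly as before. A rough Python analogue of the rewritten pipeline, with tf.map_fn standing in for the MapDefun wrapper (illustrative only; the optimizer operates on the GraphDef, not on these Python objects):

import tensorflow as tf

def original_fn(x):
  return x * x + 1  # stateless, shape-preserving map function

ds = tf.data.Dataset.range(10)

# User-written pipeline: MapDataset -> BatchDataset.
written = ds.map(original_fn).batch(4)

# Rewritten pipeline: BatchDataset -> MapDataset, where the new map function
# applies original_fn across dimension 0 of each batch (the MapDefun role).
rewritten = ds.batch(4).map(lambda batch: tf.map_fn(original_fn, batch))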
tensorflow/core/grappler/optimizers/data/map_vectorization.h (new file)
@@ -0,0 +1,46 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_

#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"

namespace tensorflow {
namespace grappler {

class MapVectorization : public CustomGraphOptimizer {
 public:
  MapVectorization() = default;
  ~MapVectorization() override = default;

  string name() const override { return "map_vectorization"; };

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return Status::OK();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* output) override;

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
@@ -0,0 +1,201 @@ (new file)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

using test::function::GDef;
using test::function::NDef;

void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
                                TensorShapeProto* t) {
  for (size_t i = 0; i < dims.size(); ++i) {
    auto* d = t->add_dim();
    d->set_size(dims[i]);
  }
}

AttrValue MakeShapeListAttr(
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
  AttrValue shapes_attr;
  for (size_t i = 0; i < shapes.size(); ++i) {
    MakeTensorShapeProtoHelper(shapes[i],
                               shapes_attr.mutable_list()->add_shape());
  }

  return shapes_attr;
}

NodeDef MakeMapNodeHelper(
    StringPiece name, StringPiece input_node_name, StringPiece function_name,
    StringPiece map_op_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  return test::function::NDef(
      name, map_op_name, {input_node_name.ToString()},
      {{"f", FunctionDefHelper::FunctionRef(function_name.ToString())},
       {"Targuments", {}},
       {"output_shapes", MakeShapeListAttr(output_shapes)},
       {"output_types", output_types}});
}

NodeDef MakeMapNode(
    StringPiece name, StringPiece input_node_name, StringPiece function_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
                           output_shapes, output_types);
}

NodeDef MakeBatchNode(
    StringPiece name, StringPiece input_node_name,
    StringPiece input_batch_size_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  return NDef(name, "BatchDataset",
              {input_node_name.ToString(), input_batch_size_name.ToString()},
              {{"output_types", output_types},
               {"output_shapes", MakeShapeListAttr(output_shapes)}});
}

NodeDef MakeBatchV2Node(
    StringPiece name, StringPiece input_node_name,
    StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
    const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
    const gtl::ArraySlice<DataType>& output_types) {
  return NDef(name, "BatchDatasetV2",
              {input_node_name.ToString(), input_batch_size_name.ToString(),
               input_drop_remainder_name.ToString()},
              {{"output_types", output_types},
               {"output_shapes", MakeShapeListAttr(output_shapes)}});
}

NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
  return NDef(name, "RangeDataset", inputs,
              {{"output_shapes", MakeShapeListAttr({{}})},
               {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}

TEST(MapVectorizationTest, VectorizeMapWithBatch) {
  GrapplerItem item;
  item.graph = GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       MakeRangeNode("range", {"start", "stop", "step"}),
       MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });
  MapVectorization optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
            1);
  EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
            1);
  const NodeDef& map_node =
      output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
  const NodeDef& batch_node =
      output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
  EXPECT_EQ(map_node.input(0), batch_node.name());
  EXPECT_EQ(batch_node.input(0), "range");
}

TEST(MapVectorizationTest, VectorizeMapWithBatchV2) {
  GrapplerItem item;
  item.graph = GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("drop_remainder", "Const", {},
            {{"value", false}, {"dtype", DT_BOOL}}),
       MakeRangeNode("range", {"start", "stop", "step"}),
       MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchV2Node("batch", "map", "batch_size", "drop_remainder", {{-1}},
                       {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });
  MapVectorization optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
            1);
  EXPECT_EQ(
      graph_utils::FindAllGraphNodesWithOp("BatchDatasetV2", output).size(), 1);
  const NodeDef& map_node =
      output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
  const NodeDef& batch_node =
      output.node(graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output));
  EXPECT_EQ(map_node.input(0), batch_node.name());
  EXPECT_EQ(batch_node.input(0), "range");
}

TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShape) {
  GrapplerItem item;
  item.graph = GDef(
      {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("input", "InputDataset", {},
            {{"output_types", gtl::ArraySlice<DataType>({DT_INT32})}}),
       MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });
  MapVectorization optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}

TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
  GrapplerItem item;
  item.graph = GDef(
      {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("input", "InputDataset", {},
            {{"output_shapes", MakeShapeListAttr({{}})}}),
       MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
       MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });
  MapVectorization optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
@@ -78,7 +78,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
   EXPECT_TRUE(
       graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
   NodeDef shuffle_and_repeat_node = output.node(
-      graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
+      graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDataset", output));
   EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
   EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
   EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
@@ -74,7 +74,11 @@ class MapDefunOp : public AsyncOpKernel {
       arg_shapes->at(i).RemoveDim(0);  // Remove the first batch dimension
       OP_REQUIRES_ASYNC(
           ctx, batch_size == ctx->input(i).dim_size(0),
-          errors::InvalidArgument("All inputs must have the same dimension 0."),
+          errors::InvalidArgument(
+              "All inputs must have the same dimension 0. Input ", i,
+              " has leading dimension ", ctx->input(i).dim_size(0),
+              ", while all previous inputs have leading dimension ", batch_size,
+              "."),
           done);
     }
@@ -71,6 +71,7 @@ COMMON_PIP_DEPS = [
     "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
     "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
     "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+    "//tensorflow/contrib/data/python/kernel_tests:test_utils",
     "//tensorflow/contrib/data/python/ops:contrib_op_loader",
     "//tensorflow/contrib/eager/python/examples:examples_pip",
     "//tensorflow/contrib/eager/python:evaluator",