[tf.data] Remove noop map functions in noop_elimination.

PiperOrigin-RevId: 298935664
Change-Id: I83c644cba22048890a77d767e0fbf26f5a46dbb7
This commit is contained in:
Rachel Lim 2020-03-04 14:04:48 -08:00 committed by TensorFlower Gardener
parent cf06dfad84
commit 50a19c3549
5 changed files with 154 additions and 16 deletions

View File

@ -655,6 +655,7 @@ cc_library(
"noop_elimination.h",
],
deps = [
":function_utils",
":graph_utils",
":optimizer_base",
"@com_google_absl//absl/container:flat_hash_set",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"
@ -31,6 +32,8 @@ namespace tensorflow {
namespace grappler {
namespace {
constexpr char kIdentity[] = "Identity";
bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
@ -64,9 +67,71 @@ bool IsPrefetchZero(const NodeDef& prefetch_node,
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
}
bool IsOutputIdentityOfInput(const FunctionDef& fdef, const string& output_arg,
const string& input_arg) {
if (!fdef.ret().contains(output_arg)) {
LOG(WARNING)
<< "Malformed FunctionDef: ret dict does not contain output arg key.";
return false;
}
const auto& ret_val = fdef.ret().at(output_arg);
auto input = function_utils::FunctionDefTensorDesc(ret_val);
// Walk from output to input. If any node along the path is not an
// Identity node, return false.
while (function_utils::ContainsFunctionNodeWithName(input.node_name, fdef)) {
int idx = function_utils::FindFunctionNodeWithName(input.node_name, fdef);
const NodeDef& node = fdef.node_def(idx);
if (node.op() != kIdentity) {
return false;
}
input = function_utils::FunctionDefTensorDesc(node.input(0));
}
// If we get here, input is not a node. Check that it matches the correct
// input arg name.
return input.node_name == input_arg;
}
bool IsMapIdentity(const NodeDef& map_node, const MutableGraphView& graph) {
if (map_node.op() != "MapDataset" && map_node.op() != "ParallelMapDataset" &&
map_node.op() != "ParallelMapDatasetV2") {
return false;
}
// We are looking only for map(lambda *x: x) nodes.
// Don't eliminate map nodes with captured arguments.
if (map_node.attr().at("Targuments").list().type_size() != 0) return false;
FunctionLibraryDefinition function_library(OpRegistry::Global(),
graph.graph()->library());
const FunctionDef* fdef =
function_library.Find(map_node.attr().at("f").func().name());
// Don't eliminate map nodes with stateful functions.
if (function_utils::IsFunctionStateful(function_library, *fdef)) return false;
const auto& sig = fdef->signature();
if (sig.input_arg_size() != sig.output_arg_size()) return false;
// For each output, check that it maps to input i
for (int i = 0; i < sig.input_arg_size(); ++i) {
if (!IsOutputIdentityOfInput(*fdef, sig.output_arg(i).name(),
sig.input_arg(i).name())) {
return false;
}
}
return true;
}
bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) {
return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph) ||
IsMapIdentity(node, graph);
}
} // namespace

View File

@ -391,7 +391,7 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
# Tests that Rebatch is a passthrough op.
dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
dataset = dataset.apply(
testing.assert_next(["Shard", "FlatMap", "BatchV2", "Map", "Rebatch"]))
testing.assert_next(["Shard", "FlatMap", "BatchV2", "Rebatch"]))
dataset = dataset.flat_map(core_readers.TFRecordDataset)
dataset = dataset.batch(5)
dataset = distribute._RebatchDataset(dataset, num_replicas=1)

View File

@ -17,37 +17,109 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.platform import test
def _test_combinations():
def make_range():
return dataset_ops.Dataset.range(10)
def fn_with_side_effect(arg):
logging_ops.print_v2(arg)
return arg
# Test case for map function with capture args
def apply_map_with_capture(ds):
const = constant_op.constant(-1, dtype=dtypes.int64)
return ds.map(lambda x: (x, const))
# Test case for map functions with multiple components
def apply_map_with_multiple_components(ds):
ds = ds.map(lambda x: (x, x), num_parallel_calls=2) # Not eliminated
return ds.map(lambda x, y: (x, y)) # Eliminated
parallel_map_name = "ParallelMapV2" if compat.forward_compatible(
2020, 3, 6) else "ParallelMap"
cases = [
("Skip0", lambda ds: ds.skip(0), None),
("SkipN", lambda ds: ds.skip(5), "FiniteSkip"),
("Repeat1", lambda ds: ds.repeat(1), None),
("RepeatN", lambda ds: ds.repeat(10), "FiniteRepeat"),
("Prefetch0", lambda ds: ds.prefetch(0), None),
("PrefetchN", lambda ds: ds.prefetch(1), "Prefetch"),
("Take-1", lambda ds: ds.take(-1), None),
("TakeN", lambda ds: ds.take(2), "FiniteTake"),
("MapIdentity", lambda ds: ds.map(lambda x: x), None),
("MapNonIdentity", lambda ds: ds.map(lambda x: x * 2), "Map"),
("MapWithSideEffect", lambda ds: ds.map(fn_with_side_effect), "Map"),
("MapWithCapture", apply_map_with_capture, "Map"),
("MapWithMultipleComponents", apply_map_with_multiple_components,
parallel_map_name),
("MapRestructure", lambda ds: ds.map(lambda x: {"value": x}), ""),
("PMapIdentity", lambda ds: ds.map(lambda x: x, num_parallel_calls=2),
None),
("PMapNonIdentity",
lambda ds: ds.map(lambda x: x * 2, num_parallel_calls=2),
parallel_map_name),
]
def reduce_fn(result, case):
name, transformation, expected = case
return result + combinations.combine(
init_dataset_fn=make_range,
transformation=combinations.NamedObject(name, transformation),
expected_name=expected)
test_combinations = functools.reduce(reduce_fn, cases, [])
return test_combinations
class NoopEliminationTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
b = constant_op.constant(2, dtype=dtypes.int64)
some_tensor = math_ops.mul(a, b)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testNoopElimination(self, init_dataset_fn, transformation, expected_name):
"""Runs a noop elimination test case.
dataset = dataset_ops.Dataset.range(5)
dataset = dataset.apply(
testing.assert_next(
["FiniteRepeat", "FiniteSkip", "Prefetch", "MemoryCacheImpl"]))
dataset = dataset.repeat(some_tensor).skip(5).take(-1).skip(0).repeat(
1).prefetch(0).prefetch(1).cache()
Args:
init_dataset_fn: Function to create the initial dataset
transformation: Transformation to apply
expected_name: Name of the transformation if it is not eliminated
"""
dataset = init_dataset_fn()
if expected_name:
dataset = dataset.apply(
testing.assert_next([expected_name, "FiniteTake"]))
else:
dataset = dataset.apply(testing.assert_next(["FiniteTake"]))
dataset = dataset.apply(transformation)
dataset = dataset.take(1)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.noop_elimination = True
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=range(5))
# Run the first iteration for the side effect of checking the assertion.
get_next = self.getNext(dataset)
self.evaluate(get_next())
if __name__ == "__main__":

View File

@ -260,7 +260,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
def testOptions(self):
dataset = dataset_ops.Dataset.range(5)
dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
dataset = dataset.map(lambda x: x).batch(5)
dataset = dataset.map(lambda x: x * 2).batch(5)
self.evaluate(dataset.reduce(0, lambda state, value: state))