[tf.data] Remove noop map functions in noop_elimination.
PiperOrigin-RevId: 298935664 Change-Id: I83c644cba22048890a77d767e0fbf26f5a46dbb7
This commit is contained in:
parent
cf06dfad84
commit
50a19c3549
@ -655,6 +655,7 @@ cc_library(
|
||||
"noop_elimination.h",
|
||||
],
|
||||
deps = [
|
||||
":function_utils",
|
||||
":graph_utils",
|
||||
":optimizer_base",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -31,6 +32,8 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kIdentity[] = "Identity";
|
||||
|
||||
bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) {
|
||||
if (take_node.op() != "TakeDataset") return false;
|
||||
|
||||
@ -64,9 +67,71 @@ bool IsPrefetchZero(const NodeDef& prefetch_node,
|
||||
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
|
||||
}
|
||||
|
||||
bool IsOutputIdentityOfInput(const FunctionDef& fdef, const string& output_arg,
|
||||
const string& input_arg) {
|
||||
if (!fdef.ret().contains(output_arg)) {
|
||||
LOG(WARNING)
|
||||
<< "Malformed FunctionDef: ret dict does not contain output arg key.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& ret_val = fdef.ret().at(output_arg);
|
||||
auto input = function_utils::FunctionDefTensorDesc(ret_val);
|
||||
|
||||
// Walk from output to input. If any node along the path is not an
|
||||
// Identity node, return false.
|
||||
while (function_utils::ContainsFunctionNodeWithName(input.node_name, fdef)) {
|
||||
int idx = function_utils::FindFunctionNodeWithName(input.node_name, fdef);
|
||||
|
||||
const NodeDef& node = fdef.node_def(idx);
|
||||
if (node.op() != kIdentity) {
|
||||
return false;
|
||||
}
|
||||
|
||||
input = function_utils::FunctionDefTensorDesc(node.input(0));
|
||||
}
|
||||
|
||||
// If we get here, input is not a node. Check that it matches the correct
|
||||
// input arg name.
|
||||
return input.node_name == input_arg;
|
||||
}
|
||||
|
||||
bool IsMapIdentity(const NodeDef& map_node, const MutableGraphView& graph) {
|
||||
if (map_node.op() != "MapDataset" && map_node.op() != "ParallelMapDataset" &&
|
||||
map_node.op() != "ParallelMapDatasetV2") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We are looking only for map(lambda *x: x) nodes.
|
||||
|
||||
// Don't eliminate map nodes with captured arguments.
|
||||
if (map_node.attr().at("Targuments").list().type_size() != 0) return false;
|
||||
|
||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||
graph.graph()->library());
|
||||
const FunctionDef* fdef =
|
||||
function_library.Find(map_node.attr().at("f").func().name());
|
||||
|
||||
// Don't eliminate map nodes with stateful functions.
|
||||
if (function_utils::IsFunctionStateful(function_library, *fdef)) return false;
|
||||
|
||||
const auto& sig = fdef->signature();
|
||||
if (sig.input_arg_size() != sig.output_arg_size()) return false;
|
||||
|
||||
// For each output, check that it maps to input i
|
||||
for (int i = 0; i < sig.input_arg_size(); ++i) {
|
||||
if (!IsOutputIdentityOfInput(*fdef, sig.output_arg(i).name(),
|
||||
sig.input_arg(i).name())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) {
|
||||
return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
|
||||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
|
||||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph) ||
|
||||
IsMapIdentity(node, graph);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -391,7 +391,7 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
# Tests that Rebatch is a passthrough op.
|
||||
dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
|
||||
dataset = dataset.apply(
|
||||
testing.assert_next(["Shard", "FlatMap", "BatchV2", "Map", "Rebatch"]))
|
||||
testing.assert_next(["Shard", "FlatMap", "BatchV2", "Rebatch"]))
|
||||
dataset = dataset.flat_map(core_readers.TFRecordDataset)
|
||||
dataset = dataset.batch(5)
|
||||
dataset = distribute._RebatchDataset(dataset, num_replicas=1)
|
||||
|
@ -17,37 +17,109 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _test_combinations():
|
||||
|
||||
def make_range():
|
||||
return dataset_ops.Dataset.range(10)
|
||||
|
||||
def fn_with_side_effect(arg):
|
||||
logging_ops.print_v2(arg)
|
||||
return arg
|
||||
|
||||
# Test case for map function with capture args
|
||||
def apply_map_with_capture(ds):
|
||||
const = constant_op.constant(-1, dtype=dtypes.int64)
|
||||
return ds.map(lambda x: (x, const))
|
||||
|
||||
# Test case for map functions with multiple components
|
||||
def apply_map_with_multiple_components(ds):
|
||||
ds = ds.map(lambda x: (x, x), num_parallel_calls=2) # Not eliminated
|
||||
return ds.map(lambda x, y: (x, y)) # Eliminated
|
||||
|
||||
parallel_map_name = "ParallelMapV2" if compat.forward_compatible(
|
||||
2020, 3, 6) else "ParallelMap"
|
||||
|
||||
cases = [
|
||||
("Skip0", lambda ds: ds.skip(0), None),
|
||||
("SkipN", lambda ds: ds.skip(5), "FiniteSkip"),
|
||||
("Repeat1", lambda ds: ds.repeat(1), None),
|
||||
("RepeatN", lambda ds: ds.repeat(10), "FiniteRepeat"),
|
||||
("Prefetch0", lambda ds: ds.prefetch(0), None),
|
||||
("PrefetchN", lambda ds: ds.prefetch(1), "Prefetch"),
|
||||
("Take-1", lambda ds: ds.take(-1), None),
|
||||
("TakeN", lambda ds: ds.take(2), "FiniteTake"),
|
||||
("MapIdentity", lambda ds: ds.map(lambda x: x), None),
|
||||
("MapNonIdentity", lambda ds: ds.map(lambda x: x * 2), "Map"),
|
||||
("MapWithSideEffect", lambda ds: ds.map(fn_with_side_effect), "Map"),
|
||||
("MapWithCapture", apply_map_with_capture, "Map"),
|
||||
("MapWithMultipleComponents", apply_map_with_multiple_components,
|
||||
parallel_map_name),
|
||||
("MapRestructure", lambda ds: ds.map(lambda x: {"value": x}), ""),
|
||||
("PMapIdentity", lambda ds: ds.map(lambda x: x, num_parallel_calls=2),
|
||||
None),
|
||||
("PMapNonIdentity",
|
||||
lambda ds: ds.map(lambda x: x * 2, num_parallel_calls=2),
|
||||
parallel_map_name),
|
||||
]
|
||||
|
||||
def reduce_fn(result, case):
|
||||
name, transformation, expected = case
|
||||
return result + combinations.combine(
|
||||
init_dataset_fn=make_range,
|
||||
transformation=combinations.NamedObject(name, transformation),
|
||||
expected_name=expected)
|
||||
|
||||
test_combinations = functools.reduce(reduce_fn, cases, [])
|
||||
|
||||
return test_combinations
|
||||
|
||||
|
||||
class NoopEliminationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoopElimination(self):
|
||||
a = constant_op.constant(1, dtype=dtypes.int64)
|
||||
b = constant_op.constant(2, dtype=dtypes.int64)
|
||||
some_tensor = math_ops.mul(a, b)
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testNoopElimination(self, init_dataset_fn, transformation, expected_name):
|
||||
"""Runs a noop elimination test case.
|
||||
|
||||
dataset = dataset_ops.Dataset.range(5)
|
||||
dataset = dataset.apply(
|
||||
testing.assert_next(
|
||||
["FiniteRepeat", "FiniteSkip", "Prefetch", "MemoryCacheImpl"]))
|
||||
dataset = dataset.repeat(some_tensor).skip(5).take(-1).skip(0).repeat(
|
||||
1).prefetch(0).prefetch(1).cache()
|
||||
Args:
|
||||
init_dataset_fn: Function to create the initial dataset
|
||||
transformation: Transformation to apply
|
||||
expected_name: Name of the transformation if it is not eliminated
|
||||
"""
|
||||
dataset = init_dataset_fn()
|
||||
|
||||
if expected_name:
|
||||
dataset = dataset.apply(
|
||||
testing.assert_next([expected_name, "FiniteTake"]))
|
||||
else:
|
||||
dataset = dataset.apply(testing.assert_next(["FiniteTake"]))
|
||||
|
||||
dataset = dataset.apply(transformation)
|
||||
dataset = dataset.take(1)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
options.experimental_optimization.noop_elimination = True
|
||||
dataset = dataset.with_options(options)
|
||||
self.assertDatasetProduces(dataset, expected_output=range(5))
|
||||
|
||||
# Run the first iteration for the side effect of checking the assertion.
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -260,7 +260,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testOptions(self):
|
||||
dataset = dataset_ops.Dataset.range(5)
|
||||
dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
|
||||
dataset = dataset.map(lambda x: x).batch(5)
|
||||
dataset = dataset.map(lambda x: x * 2).batch(5)
|
||||
self.evaluate(dataset.reduce(0, lambda state, value: state))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user