From 7e1deb16c6f10a477003b5f2caadbe3e6fb48af4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jun 2019 03:33:27 -0700 Subject: [PATCH] Improve variable detection in tf.custom_gradient. PiperOrigin-RevId: 252998106 --- tensorflow/contrib/graph_editor/BUILD | 1 + tensorflow/contrib/graph_editor/select.py | 55 +-- tensorflow/python/BUILD | 24 +- tensorflow/python/eager/BUILD | 2 +- tensorflow/python/eager/lift_to_graph.py | 107 +----- tensorflow/python/ops/custom_gradient.py | 35 +- tensorflow/python/ops/gradients_test.py | 36 ++ tensorflow/python/ops/op_selector.py | 395 ++++++++++++++++++++++ tensorflow/python/ops/op_selector_test.py | 166 +++++++++ 9 files changed, 675 insertions(+), 146 deletions(-) create mode 100644 tensorflow/python/ops/op_selector.py create mode 100644 tensorflow/python/ops/op_selector_test.py diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 40f749adadc..f4bed99e2dc 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -26,6 +26,7 @@ py_library( deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:op_selector", "//tensorflow/python:platform", "//tensorflow/python:util", "@six_archive//:six", diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index d700e6e1a75..3f7bc91ef05 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -24,7 +24,11 @@ from six import iteritems from six import string_types from tensorflow.contrib.graph_editor import util +from tensorflow.python.ops import op_selector from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.util import deprecation + + __all__ = [ "can_be_regex", @@ -452,6 +456,10 @@ def get_forward_walk_ops(seed_ops, return result +@deprecation.deprecated( + "2019-06-06", + "Please use 
tensorflow.python.ops.op_selector.get_backward_walk_ops.", + warn_once=True) def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None, @@ -479,46 +487,13 @@ def get_backward_walk_ops(seed_ops, TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of `tf.Operation`. """ - if not util.is_iterable(seed_ops): - seed_ops = [seed_ops] - if not seed_ops: - return [] - if isinstance(seed_ops[0], tf_ops.Tensor): - ts = util.make_list_of_t(seed_ops, allow_graph=False) - seed_ops = util.get_generating_ops(ts) - else: - seed_ops = util.make_list_of_op(seed_ops, allow_graph=False) - - stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts)) - seed_ops = frozenset(util.make_list_of_op(seed_ops)) - if within_ops: - within_ops = util.make_list_of_op(within_ops, allow_graph=False) - within_ops = frozenset(within_ops) - seed_ops &= within_ops - - def is_within(op): - return (within_ops is None or op in within_ops) and ( - within_ops_fn is None or within_ops_fn(op)) - - result = list(seed_ops) - wave = set(seed_ops) - while wave: - new_wave = set() - for op in wave: - for new_t in op.inputs: - if new_t in stop_at_ts: - continue - if new_t.op not in result and is_within(new_t.op): - new_wave.add(new_t.op) - if control_inputs: - for new_op in op.control_inputs: - if new_op not in result and is_within(new_op): - new_wave.add(new_op) - util.concatenate_unique(result, new_wave) - wave = new_wave - if not inclusive: - result = [op for op in result if op not in seed_ops] - return result + return op_selector.get_backward_walk_ops( + seed_ops, + inclusive=inclusive, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts, + control_inputs=control_inputs) def get_walks_intersection_ops(forward_seed_ops, diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6145b06a9c1..a56a3fd488e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -36,13 +36,10 @@ load("//tensorflow:tensorflow.bzl", "py_tests") 
load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") -load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_grappler") @@ -3074,6 +3071,13 @@ py_library( ], ) +py_library( + name = "op_selector", + srcs = ["ops/op_selector.py"], + srcs_version = "PY2AND3", + deps = [":framework_ops"], +) + py_library( name = "math_ops", srcs = ["ops/math_ops.py"], @@ -3862,6 +3866,20 @@ cuda_py_test( xla_enable_strict_auto_jit = True, ) +py_test( + name = "op_selector_test", + srcs = ["ops/op_selector_test.py"], + python_version = "PY2", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":constant_op", + ":framework_ops", + ":math_ops", + ":op_selector", + ], +) + cuda_py_test( name = "gradient_checker_v2_test", size = "medium", diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 5d4dc46a5d0..678128c23f7 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -585,8 +585,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - ":context", "//tensorflow/python:framework_ops", + "//tensorflow/python:op_selector", "@six_archive//:six", ], ) diff --git a/tensorflow/python/eager/lift_to_graph.py 
b/tensorflow/python/eager/lift_to_graph.py index 86f178b14cc..34294c954d3 100644 --- a/tensorflow/python/eager/lift_to_graph.py +++ b/tensorflow/python/eager/lift_to_graph.py @@ -25,11 +25,11 @@ import six from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops -def _graph_inputs(op): - return [x.op for x in op.inputs] + list(op.control_inputs) +UnliftableError = op_selector.UnliftableError def _as_operation(op_or_tensor): @@ -38,106 +38,10 @@ def _as_operation(op_or_tensor): return op_or_tensor -class UnliftableError(Exception): - """Raised if a Tensor cannot be lifted from the graph.""" - - # Prevent autograph from rewriting this error. - ag_pass_through = True - - def _constant_inputs(op_or_tensor): return all(_as_operation(i).type == u"Const" and not _as_operation(i).control_inputs - for i in _graph_inputs(_as_operation(op_or_tensor))) - - -def _path_from(from_op, tensor, sources): - """Find one path from `from_op` to `tensor`, ignoring `sources`. - - Args: - from_op: A `tf.Operation`. - tensor: A `tf.Operation` or `tf.Tensor`. - sources: A list of `tf.Tensor`. - - Returns: - A python string containing the path, or "??" if none is found. - """ - visited_ops = set([x.op for x in sources]) - ops_to_visit = [_as_operation(tensor)] - some_op_output = {} - while ops_to_visit: - op = ops_to_visit.pop() - if op in visited_ops: - continue - visited_ops.add(op) - if op == from_op: - path_op = op - path = [path_op] - final_op = _as_operation(tensor) - while path_op != final_op: - path_op = some_op_output[path_op] - path.append(path_op) - return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)]) - else: - for inp in _graph_inputs(op): - if inp not in visited_ops and inp not in sources: - some_op_output[inp] = op - ops_to_visit.append(inp) - return "??" 
- - -def _map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops, - op_outputs, add_sources): - """Walk a Graph and capture the subgraph between init_tensor and sources. - - Note: This function mutates visited_ops and op_outputs. - - Arguments: - init_tensor: A Tensor or Operation where the subgraph terminates. - sources: A set of Tensors where subgraph extraction should stop. - disallowed_placeholders: An optional set of ops which may not appear in the - lifted graph. Defaults to all placeholders. - visited_ops: A set of operations which were visited in a prior pass. - op_outputs: A defaultdict containing the outputs of an op which are to be - copied into the new subgraph. - add_sources: A boolean indicating whether placeholders which are not in - sources should be allowed. - - Returns: - The set of placeholders upon which init_tensor depends and are not in - sources. - - Raises: - UnliftableError: if init_tensor depends on a placeholder which is not in - sources and add_sources is False. 
- """ - ops_to_visit = [_as_operation(init_tensor)] - extra_sources = set() - while ops_to_visit: - op = ops_to_visit.pop() - if op in visited_ops: - continue - visited_ops.add(op) - - should_raise = False - if disallowed_placeholders is not None and op in disallowed_placeholders: - should_raise = True - elif op.type == "Placeholder": - if disallowed_placeholders is None and not add_sources: - should_raise = True - extra_sources.update(op.outputs) - - if should_raise: - raise UnliftableError( - "Unable to lift tensor %s because it depends transitively on " - "placeholder %s via at least one path, e.g.: %s" - % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources))) - for inp in _graph_inputs(op): - op_outputs[inp].add(op) - if inp not in visited_ops and inp not in (sources or extra_sources): - ops_to_visit.append(inp) - - return extra_sources + for i in op_selector.graph_inputs(_as_operation(op_or_tensor))) # Represents an input to `copied_op` which must be updated once @@ -323,7 +227,7 @@ def lift_to_graph(init_tensors, graph, sources=None, # First we extract the subgraph between init_tensors and sources. for init_tensor in init_tensors: - sources.update(_map_subgraph( + sources.update(op_selector.map_subgraph( init_tensor=init_tensor, sources=sources, disallowed_placeholders=disallowed_placeholders, @@ -345,7 +249,7 @@ def lift_to_graph(init_tensors, graph, sources=None, continue marked_ops.add(op) ops_to_copy.append(op) - for inp in _graph_inputs(op): + for inp in op_selector.graph_inputs(op): # Don't lift the TPUReplicateMetadata nodes out of the function, because # it has no registered kernels. 
if inp.name == "TPUReplicateMetadata": @@ -422,3 +326,4 @@ def lift_to_graph(init_tensors, graph, sources=None, # pylint: enable=protected-access return op_map + diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index d1eda831c66..6ef69e8f1c4 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -34,6 +35,12 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +VAR_OP_TYPES = [ + "VariableV2", + "VarHandleOp", +] + + def copy_handle_data(source_t, target_t): """Copies HandleData for variant and resource type tensors if available. @@ -163,6 +170,29 @@ def custom_gradient(f): return tf_decorator.make_decorator(f, decorated) +def get_variable_by_name(var_name): + candidate_vars = ops.get_collection( + ops.GraphKeys.GLOBAL_VARIABLES, scope=var_name) + assert len(candidate_vars) == 1 + return candidate_vars[0] + + +def get_dependent_variables(input_ops, output_ops): + """Finds variables involved in the subgraph b/w input_ops and output_ops.""" + + # avoids the edge-case when input_ops == output_ops. 
+ output_ops = nest.map_structure(gen_array_ops.identity, output_ops) + + inbetween_ops = op_selector.get_backward_walk_ops( + seed_ops=output_ops, + stop_at_ts=input_ops, + inclusive=False) + var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES) + var_names = (op.name for op in var_ops) + tf_vars = [get_variable_by_name(var_name) for var_name in var_names] + return tf_vars + + def _graph_mode_decorator(f, *args, **kwargs): """Implement custom gradient decorator for graph mode.""" # TODO(rsepassi): Add support for kwargs @@ -191,7 +221,10 @@ def _graph_mode_decorator(f, *args, **kwargs): "with `use_resource=False`.") # The variables that grad_fn needs to return gradients for are the set of # variables used that are *not* part of the inputs. - variables = list(set(tape.watched_variables()) - set(args)) + tf1_variables = get_dependent_variables(input_ops=args, output_ops=result) + eager_variables = list(set(tape.watched_variables()) - set(args)) + variables = list(set(tf1_variables + eager_variables)) + grad_argspec = tf_inspect.getfullargspec(grad_fn) variables_in_signature = ("variables" in grad_argspec.args or grad_argspec.varkw) diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 0a01dda94dd..d9de0cd6384 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -954,6 +954,42 @@ class CustomGradientTest(test_util.TensorFlowTestCase): dw = sess.run(math_ops.reduce_sum(grads[1])) self.assertEqual(12., dw) + def testCustomGradientWithVariablesNoFalsePositives(self): + + @custom_gradient.custom_gradient + def F(x): + out = core_layers.dense(x, 3, use_bias=False) + + def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name + self.assertEqual(1, len(variables)) + grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad) + return grads[0], [array_ops.ones((3, 3))] + + return out, Grad + + with ops.Graph().as_default(): + with 
variable_scope.variable_scope("f", use_resource=True) as vs: + a = array_ops.ones((2, 4)) + + # Variables in these layers shouldn't be picked up by the decorator. + b = core_layers.dense(a, 3, use_bias=False) + c = core_layers.dense(b, 3, use_bias=False) + x = core_layers.dense(b, 3, use_bias=False) + c + + # Only the variables used in F. + y = F(x) + + all_vars = vs.global_variables() + assert len(all_vars) == 4 + grads = gradients.gradients(y, [x] + all_vars) + _, var_grads = grads[0], grads[1:] + for g in grads: + self.assertIsNotNone(g) + with session.Session() as sess: + self.evaluate(variables.global_variables_initializer()) + dw = sess.run(math_ops.reduce_sum(var_grads[-1])) + self.assertEqual(9., dw) + def testCustomGradientWithVariablesEager(self): with context.eager_mode(): layer = core_layers.Dense(4, use_bias=False) diff --git a/tensorflow/python/ops/op_selector.py b/tensorflow/python/ops/op_selector.py new file mode 100644 index 00000000000..2e3a07ded45 --- /dev/null +++ b/tensorflow/python/ops/op_selector.py @@ -0,0 +1,395 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tools for selecting ops in a graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops + + +def is_iterable(obj): + """Return true if the object is iterable.""" + if isinstance(obj, ops.Tensor): + return False + try: + _ = iter(obj) + except Exception: # pylint: disable=broad-except + return False + return True + + +def concatenate_unique(la, lb): + """Add all the elements of `lb` to `la` if they are not there already. + + The elements added to `la` maintain ordering with respect to `lb`. + + Args: + la: List of Python objects. + lb: List of Python objects. + Returns: + `la`: The list `la` with missing elements from `lb`. + """ + la_set = set(la) + for l in lb: + if l not in la_set: + la.append(l) + la_set.add(l) + return la + + +def get_tensors(graph): + """get all the tensors which are input or output of an op in the graph. + + Args: + graph: a `tf.Graph`. + Returns: + A list of `tf.Tensor`. + Raises: + TypeError: if graph is not a `tf.Graph`. + """ + if not isinstance(graph, ops.Graph): + raise TypeError("Expected a graph, got: {}".format(type(graph))) + ts = [] + for op in graph.get_operations(): + ts += op.outputs + return ts + + +def get_unique_graph(tops, check_types=None, none_if_empty=False): + """Return the unique graph used by the all the elements in tops. + + Args: + tops: list of elements to check (usually a list of tf.Operation and/or + tf.Tensor). Or a tf.Graph. + check_types: check that the element in tops are of given type(s). If None, + the types (tf.Operation, tf.Tensor) are used. + none_if_empty: don't raise an error if tops is an empty list, just return + None. + Returns: + The unique graph used by all the tops. + Raises: + TypeError: if tops is not a iterable of tf.Operation. + ValueError: if the graph is not unique. 
+ """ + if isinstance(tops, ops.Graph): + return tops + if not is_iterable(tops): + raise TypeError("{} is not iterable".format(type(tops))) + if check_types is None: + check_types = (ops.Operation, ops.Tensor) + elif not is_iterable(check_types): + check_types = (check_types,) + g = None + for op in tops: + if not isinstance(op, check_types): + raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str( + t) for t in check_types]), type(op))) + if g is None: + g = op.graph + elif g is not op.graph: + raise ValueError("Operation {} does not belong to given graph".format(op)) + if g is None and not none_if_empty: + raise ValueError("Can't find the unique graph of an empty list") + return g + + +def check_graphs(*args): + """Check that all the element in args belong to the same graph. + + Args: + *args: a list of object with a obj.graph property. + Raises: + ValueError: if all the elements do not belong to the same graph. + """ + graph = None + for i, sgv in enumerate(args): + if graph is None and sgv.graph is not None: + graph = sgv.graph + elif sgv.graph is not None and sgv.graph is not graph: + raise ValueError("Argument[{}]: Wrong graph!".format(i)) + + +def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): + """Convert ts to a list of `tf.Tensor`. + + Args: + ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor. + check_graph: if `True` check if all the tensors belong to the same graph. + allow_graph: if `False` a `tf.Graph` cannot be converted. + ignore_ops: if `True`, silently ignore `tf.Operation`. + Returns: + A newly created list of `tf.Tensor`. + Raises: + TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or, + if `check_graph` is `True`, if all the ops do not belong to the same graph. 
+ """ + if isinstance(ts, ops.Graph): + if allow_graph: + return get_tensors(ts) + else: + raise TypeError("allow_graph is False: cannot convert a tf.Graph.") + else: + if not is_iterable(ts): + ts = [ts] + if not ts: + return [] + if check_graph: + check_types = None if ignore_ops else ops.Tensor + get_unique_graph(ts, check_types=check_types) + return [t for t in ts if isinstance(t, ops.Tensor)] + + +def get_generating_ops(ts): + """Return all the generating ops of the tensors in `ts`. + + Args: + ts: a list of `tf.Tensor` + Returns: + A list of all the generating `tf.Operation` of the tensors in `ts`. + Raises: + TypeError: if `ts` cannot be converted to a list of `tf.Tensor`. + """ + ts = make_list_of_t(ts, allow_graph=False) + return [t.op for t in ts] + + +def get_consuming_ops(ts): + """Return all the consuming ops of the tensors in ts. + + Args: + ts: a list of `tf.Tensor` + Returns: + A list of all the consuming `tf.Operation` of the tensors in `ts`. + Raises: + TypeError: if ts cannot be converted to a list of `tf.Tensor`. + """ + ts = make_list_of_t(ts, allow_graph=False) + tops = [] + for t in ts: + for op in t.consumers(): + if op not in tops: + tops.append(op) + return tops + + +def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False): + """Convert ops to a list of `tf.Operation`. + + Args: + tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single + operation. + check_graph: if `True` check if all the operations belong to the same graph. + allow_graph: if `False` a `tf.Graph` cannot be converted. + ignore_ts: if True, silently ignore `tf.Tensor`. + Returns: + A newly created list of `tf.Operation`. + Raises: + TypeError: if tops cannot be converted to a list of `tf.Operation` or, + if `check_graph` is `True`, if all the ops do not belong to the + same graph. 
+ """ + if isinstance(tops, ops.Graph): + if allow_graph: + return tops.get_operations() + else: + raise TypeError("allow_graph is False: cannot convert a tf.Graph.") + else: + if not is_iterable(tops): + tops = [tops] + if not tops: + return [] + if check_graph: + check_types = None if ignore_ts else ops.Operation + get_unique_graph(tops, check_types=check_types) + return [op for op in tops if isinstance(op, ops.Operation)] + + +def get_backward_walk_ops(seed_ops, + inclusive=True, + within_ops=None, + within_ops_fn=None, + stop_at_ts=(), + control_inputs=False): + """Do a backward graph walk and return all the visited ops. + + Args: + seed_ops: an iterable of operations from which the backward graph + walk starts. If a list of tensors is given instead, the seed_ops are set + to be the generators of those tensors. + inclusive: if True the given seed_ops are also part of the resulting list. + within_ops: an iterable of `tf.Operation` within which the search is + restricted. If `within_ops` is `None`, the search is performed within + the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. + stop_at_ts: an iterable of tensors at which the graph walk stops. + control_inputs: if True, control inputs will be used while moving backward. + Returns: + A Python list of all the `tf.Operation` behind `seed_ops`. + Raises: + TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of + `tf.Operation`. 
+ """ + if not is_iterable(seed_ops): + seed_ops = [seed_ops] + if not seed_ops: + return [] + if isinstance(seed_ops[0], ops.Tensor): + ts = make_list_of_t(seed_ops, allow_graph=False) + seed_ops = get_generating_ops(ts) + else: + seed_ops = make_list_of_op(seed_ops, allow_graph=False) + + stop_at_ts = frozenset(make_list_of_t(stop_at_ts)) + seed_ops = frozenset(make_list_of_op(seed_ops)) + if within_ops: + within_ops = make_list_of_op(within_ops, allow_graph=False) + within_ops = frozenset(within_ops) + seed_ops &= within_ops + + def is_within(op): + return (within_ops is None or op in within_ops) and ( + within_ops_fn is None or within_ops_fn(op)) + + result = list(seed_ops) + wave = set(seed_ops) + while wave: + new_wave = set() + for op in wave: + for new_t in op.inputs: + if new_t in stop_at_ts: + continue + if new_t.op not in result and is_within(new_t.op): + new_wave.add(new_t.op) + if control_inputs: + for new_op in op.control_inputs: + if new_op not in result and is_within(new_op): + new_wave.add(new_op) + concatenate_unique(result, new_wave) + wave = new_wave + if not inclusive: + result = [op for op in result if op not in seed_ops] + return result + + +class UnliftableError(Exception): + """Raised if a Tensor cannot be lifted from the graph.""" + + # Prevent autograph from rewriting this error. + ag_pass_through = True + + +def _as_operation(op_or_tensor): + if isinstance(op_or_tensor, ops.Tensor): + return op_or_tensor.op + return op_or_tensor + + +def graph_inputs(op): + return [x.op for x in op.inputs] + list(op.control_inputs) + + +def _path_from(from_op, tensor, sources): + """Find one path from `from_op` to `tensor`, ignoring `sources`. + + Args: + from_op: A `tf.Operation`. + tensor: A `tf.Operation` or `tf.Tensor`. + sources: A list of `tf.Tensor`. + + Returns: + A python string containing the path, or "??" if none is found. 
+ """ + visited_ops = set([x.op for x in sources]) + ops_to_visit = [_as_operation(tensor)] + some_op_output = {} + while ops_to_visit: + op = ops_to_visit.pop() + if op in visited_ops: + continue + visited_ops.add(op) + if op == from_op: + path_op = op + path = [path_op] + final_op = _as_operation(tensor) + while path_op != final_op: + path_op = some_op_output[path_op] + path.append(path_op) + return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)]) + else: + for inp in graph_inputs(op): + if inp not in visited_ops and inp not in sources: + some_op_output[inp] = op + ops_to_visit.append(inp) + return "??" + + +# TODO(jmenick) - there is considerable duplication of functionality between +# this function and get_backward_walk_ops(). Need to deduplicate. +def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops, + op_outputs, add_sources): + """Walk a Graph and capture the subgraph between init_tensor and sources. + + Note: This function mutates visited_ops and op_outputs. + + Arguments: + init_tensor: A Tensor or Operation where the subgraph terminates. + sources: A set of Tensors where subgraph extraction should stop. + disallowed_placeholders: An optional set of ops which may not appear in the + lifted graph. Defaults to all placeholders. + visited_ops: A set of operations which were visited in a prior pass. + op_outputs: A defaultdict containing the outputs of an op which are to be + copied into the new subgraph. + add_sources: A boolean indicating whether placeholders which are not in + sources should be allowed. + + Returns: + The set of placeholders upon which init_tensor depends and are not in + sources. + + Raises: + UnliftableError: if init_tensor depends on a placeholder which is not in + sources and add_sources is False. 
+ """ + ops_to_visit = [_as_operation(init_tensor)] + extra_sources = set() + while ops_to_visit: + op = ops_to_visit.pop() + if op in visited_ops: + continue + visited_ops.add(op) + + should_raise = False + if disallowed_placeholders is not None and op in disallowed_placeholders: + should_raise = True + elif op.type == "Placeholder": + if disallowed_placeholders is None and not add_sources: + should_raise = True + extra_sources.update(op.outputs) + + if should_raise: + raise UnliftableError( + "Unable to lift tensor %s because it depends transitively on " + "placeholder %s via at least one path, e.g.: %s" + % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources))) + for inp in graph_inputs(op): + op_outputs[inp].add(op) + if inp not in visited_ops and inp not in (sources or extra_sources): + ops_to_visit.append(inp) + + return extra_sources diff --git a/tensorflow/python/ops/op_selector_test.py b/tensorflow/python/ops/op_selector_test.py new file mode 100644 index 00000000000..daeaeb4efab --- /dev/null +++ b/tensorflow/python/ops/op_selector_test.py @@ -0,0 +1,166 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for op_selector.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import op_selector +from tensorflow.python.platform import test + + +class SelectTest(test.TestCase): + + def setUp(self): + self.graph = ops_lib.Graph() + with self.graph.as_default(): + self.a = constant_op.constant([1., 1.], shape=[2], name="a") + with ops_lib.name_scope("foo"): + self.b = constant_op.constant([2., 2.], shape=[2], name="b") + self.c = math_ops.add(self.a, self.b, name="c") + self.d = constant_op.constant([3., 3.], shape=[2], name="d") + with ops_lib.name_scope("bar"): + self.e = math_ops.add(self.c, self.d, name="e") + self.f = math_ops.add(self.c, self.d, name="f") + self.g = math_ops.add(self.c, self.a, name="g") + with ops_lib.control_dependencies([self.c.op]): + self.h = math_ops.add(self.f, self.g, name="h") + + def test_is_iterable(self): + """Test for is_iterable.""" + self.assertTrue(op_selector.is_iterable([0, 1, 2])) + self.assertFalse(op_selector.is_iterable(3)) + + def test_unique_graph(self): + """Test for check_graphs and get_unique_graph.""" + g0 = ops_lib.Graph() + with g0.as_default(): + a0 = constant_op.constant(1) + b0 = constant_op.constant(2) + g1 = ops_lib.Graph() + with g1.as_default(): + a1 = constant_op.constant(1) + b1 = constant_op.constant(2) + # Same graph, should be fine. + self.assertIsNone(op_selector.check_graphs(a0, b0)) + # Two different graphs, should assert. + with self.assertRaises(ValueError): + op_selector.check_graphs(a0, b0, a1, b1) + # a0 and b0 belongs to the same graph, should be fine. + self.assertEqual(op_selector.get_unique_graph([a0, b0]), g0) + # Different graph, should raise an error. 
+ with self.assertRaises(ValueError): + op_selector.get_unique_graph([a0, b0, a1, b1]) + + def test_make_list_of_op(self): + """Test for make_list_of_op.""" + g0 = ops_lib.Graph() + with g0.as_default(): + a0 = constant_op.constant(1) + b0 = constant_op.constant(2) + # Should extract the ops from the graph. + self.assertEqual(len(op_selector.make_list_of_op(g0)), 2) + # Should extract the ops from the tuple. + self.assertEqual(len(op_selector.make_list_of_op((a0.op, b0.op))), 2) + + def test_make_list_of_t(self): + """Test for make_list_of_t.""" + g0 = ops_lib.Graph() + with g0.as_default(): + a0 = constant_op.constant(1) + b0 = constant_op.constant(2) + c0 = math_ops.add(a0, b0) # pylint: disable=unused-variable + # Should extract the tensors from the graph. + self.assertEqual(len(op_selector.make_list_of_t(g0)), 3) + # Should extract the tensors from the tuple. + self.assertEqual(len(op_selector.make_list_of_t((a0, b0))), 2) + # Should extract the tensors and ignore the ops. + self.assertEqual( + len(op_selector.make_list_of_t( + (a0, a0.op, b0), ignore_ops=True)), 2) + + def test_get_generating_consuming(self): + """Test for get_generating_ops and get_consuming_ops.""" + g0 = ops_lib.Graph() + with g0.as_default(): + a0 = constant_op.constant(1) + b0 = constant_op.constant(2) + c0 = math_ops.add(a0, b0) + self.assertEqual(len(op_selector.get_generating_ops([a0, b0])), 2) + self.assertEqual(len(op_selector.get_consuming_ops([a0, b0])), 1) + self.assertEqual(len(op_selector.get_generating_ops([c0])), 1) + self.assertEqual(op_selector.get_consuming_ops([c0]), []) + + def test_backward_walk_ops(self): + seed_ops = [self.h.op] + # Include all ops except for self.g.op + within_ops = [ + x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] + ] + # For the fn, exclude self.c.op. 
+ within_ops_fn = lambda op: op not in (self.c.op,) + stop_at_ts = (self.f,) + + with self.graph.as_default(): + # Backward walk only includes h since we stop at f and g is not within. + ops = op_selector.get_backward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.h.op])) + + # If we do inclusive=False, the result is empty. + ops = op_selector.get_backward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set()) + + # Removing stop_at_ts adds f.op, d.op. + ops = op_selector.get_backward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn) + self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op])) + + # Not using within_ops_fn adds back ops for a, b, c. + ops = op_selector.get_backward_walk_ops( + seed_ops, inclusive=True, within_ops=within_ops) + self.assertEqual( + set(ops), + set([ + self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op + ])) + + # Vanilla backward search via self.h.op includes everything except e.op. + ops = op_selector.get_backward_walk_ops(seed_ops, inclusive=True) + self.assertEqual( + set(ops), + set([ + self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op, + self.h.op + ])) + + +if __name__ == "__main__": + test.main()