Improve variable detection in tf.custom_gradient.
PiperOrigin-RevId: 252998106
This commit is contained in:
parent
6237fcb1a9
commit
7e1deb16c6
|
@ -26,6 +26,7 @@ py_library(
|
|||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:op_selector",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"@six_archive//:six",
|
||||
|
|
|
@ -24,7 +24,11 @@ from six import iteritems
|
|||
from six import string_types
|
||||
|
||||
from tensorflow.contrib.graph_editor import util
|
||||
from tensorflow.python.ops import op_selector
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
"can_be_regex",
|
||||
|
@ -452,6 +456,10 @@ def get_forward_walk_ops(seed_ops,
|
|||
return result
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
"2019-06-06",
|
||||
"Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.",
|
||||
warn_once=True)
|
||||
def get_backward_walk_ops(seed_ops,
|
||||
inclusive=True,
|
||||
within_ops=None,
|
||||
|
@ -479,46 +487,13 @@ def get_backward_walk_ops(seed_ops,
|
|||
TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
|
||||
`tf.Operation`.
|
||||
"""
|
||||
if not util.is_iterable(seed_ops):
|
||||
seed_ops = [seed_ops]
|
||||
if not seed_ops:
|
||||
return []
|
||||
if isinstance(seed_ops[0], tf_ops.Tensor):
|
||||
ts = util.make_list_of_t(seed_ops, allow_graph=False)
|
||||
seed_ops = util.get_generating_ops(ts)
|
||||
else:
|
||||
seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
|
||||
|
||||
stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
|
||||
seed_ops = frozenset(util.make_list_of_op(seed_ops))
|
||||
if within_ops:
|
||||
within_ops = util.make_list_of_op(within_ops, allow_graph=False)
|
||||
within_ops = frozenset(within_ops)
|
||||
seed_ops &= within_ops
|
||||
|
||||
def is_within(op):
|
||||
return (within_ops is None or op in within_ops) and (
|
||||
within_ops_fn is None or within_ops_fn(op))
|
||||
|
||||
result = list(seed_ops)
|
||||
wave = set(seed_ops)
|
||||
while wave:
|
||||
new_wave = set()
|
||||
for op in wave:
|
||||
for new_t in op.inputs:
|
||||
if new_t in stop_at_ts:
|
||||
continue
|
||||
if new_t.op not in result and is_within(new_t.op):
|
||||
new_wave.add(new_t.op)
|
||||
if control_inputs:
|
||||
for new_op in op.control_inputs:
|
||||
if new_op not in result and is_within(new_op):
|
||||
new_wave.add(new_op)
|
||||
util.concatenate_unique(result, new_wave)
|
||||
wave = new_wave
|
||||
if not inclusive:
|
||||
result = [op for op in result if op not in seed_ops]
|
||||
return result
|
||||
return op_selector.get_backward_walk_ops(
|
||||
seed_ops,
|
||||
inclusive=inclusive,
|
||||
within_ops=within_ops,
|
||||
within_ops_fn=within_ops_fn,
|
||||
stop_at_ts=stop_at_ts,
|
||||
control_inputs=control_inputs)
|
||||
|
||||
|
||||
def get_walks_intersection_ops(forward_seed_ops,
|
||||
|
|
|
@ -36,13 +36,10 @@ load("//tensorflow:tensorflow.bzl", "py_tests")
|
|||
load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_grappler")
|
||||
|
@ -3074,6 +3071,13 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "op_selector",
|
||||
srcs = ["ops/op_selector.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":framework_ops"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "math_ops",
|
||||
srcs = ["ops/math_ops.py"],
|
||||
|
@ -3862,6 +3866,20 @@ cuda_py_test(
|
|||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "op_selector_test",
|
||||
srcs = ["ops/op_selector_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client_testlib",
|
||||
":constant_op",
|
||||
":framework_ops",
|
||||
":math_ops",
|
||||
":op_selector",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "gradient_checker_v2_test",
|
||||
size = "medium",
|
||||
|
|
|
@ -585,8 +585,8 @@ py_library(
|
|||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":context",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:op_selector",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -25,11 +25,11 @@ import six
|
|||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import op_selector
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
|
||||
|
||||
def _graph_inputs(op):
|
||||
return [x.op for x in op.inputs] + list(op.control_inputs)
|
||||
UnliftableError = op_selector.UnliftableError
|
||||
|
||||
|
||||
def _as_operation(op_or_tensor):
|
||||
|
@ -38,106 +38,10 @@ def _as_operation(op_or_tensor):
|
|||
return op_or_tensor
|
||||
|
||||
|
||||
class UnliftableError(Exception):
|
||||
"""Raised if a Tensor cannot be lifted from the graph."""
|
||||
|
||||
# Prevent autograph from rewriting this error.
|
||||
ag_pass_through = True
|
||||
|
||||
|
||||
def _constant_inputs(op_or_tensor):
|
||||
return all(_as_operation(i).type == u"Const"
|
||||
and not _as_operation(i).control_inputs
|
||||
for i in _graph_inputs(_as_operation(op_or_tensor)))
|
||||
|
||||
|
||||
def _path_from(from_op, tensor, sources):
|
||||
"""Find one path from `from_op` to `tensor`, ignoring `sources`.
|
||||
|
||||
Args:
|
||||
from_op: A `tf.Operation`.
|
||||
tensor: A `tf.Operation` or `tf.Tensor`.
|
||||
sources: A list of `tf.Tensor`.
|
||||
|
||||
Returns:
|
||||
A python string containing the path, or "??" if none is found.
|
||||
"""
|
||||
visited_ops = set([x.op for x in sources])
|
||||
ops_to_visit = [_as_operation(tensor)]
|
||||
some_op_output = {}
|
||||
while ops_to_visit:
|
||||
op = ops_to_visit.pop()
|
||||
if op in visited_ops:
|
||||
continue
|
||||
visited_ops.add(op)
|
||||
if op == from_op:
|
||||
path_op = op
|
||||
path = [path_op]
|
||||
final_op = _as_operation(tensor)
|
||||
while path_op != final_op:
|
||||
path_op = some_op_output[path_op]
|
||||
path.append(path_op)
|
||||
return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)])
|
||||
else:
|
||||
for inp in _graph_inputs(op):
|
||||
if inp not in visited_ops and inp not in sources:
|
||||
some_op_output[inp] = op
|
||||
ops_to_visit.append(inp)
|
||||
return "??"
|
||||
|
||||
|
||||
def _map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
|
||||
op_outputs, add_sources):
|
||||
"""Walk a Graph and capture the subgraph between init_tensor and sources.
|
||||
|
||||
Note: This function mutates visited_ops and op_outputs.
|
||||
|
||||
Arguments:
|
||||
init_tensor: A Tensor or Operation where the subgraph terminates.
|
||||
sources: A set of Tensors where subgraph extraction should stop.
|
||||
disallowed_placeholders: An optional set of ops which may not appear in the
|
||||
lifted graph. Defaults to all placeholders.
|
||||
visited_ops: A set of operations which were visited in a prior pass.
|
||||
op_outputs: A defaultdict containing the outputs of an op which are to be
|
||||
copied into the new subgraph.
|
||||
add_sources: A boolean indicating whether placeholders which are not in
|
||||
sources should be allowed.
|
||||
|
||||
Returns:
|
||||
The set of placeholders upon which init_tensor depends and are not in
|
||||
sources.
|
||||
|
||||
Raises:
|
||||
UnliftableError: if init_tensor depends on a placeholder which is not in
|
||||
sources and add_sources is False.
|
||||
"""
|
||||
ops_to_visit = [_as_operation(init_tensor)]
|
||||
extra_sources = set()
|
||||
while ops_to_visit:
|
||||
op = ops_to_visit.pop()
|
||||
if op in visited_ops:
|
||||
continue
|
||||
visited_ops.add(op)
|
||||
|
||||
should_raise = False
|
||||
if disallowed_placeholders is not None and op in disallowed_placeholders:
|
||||
should_raise = True
|
||||
elif op.type == "Placeholder":
|
||||
if disallowed_placeholders is None and not add_sources:
|
||||
should_raise = True
|
||||
extra_sources.update(op.outputs)
|
||||
|
||||
if should_raise:
|
||||
raise UnliftableError(
|
||||
"Unable to lift tensor %s because it depends transitively on "
|
||||
"placeholder %s via at least one path, e.g.: %s"
|
||||
% (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
|
||||
for inp in _graph_inputs(op):
|
||||
op_outputs[inp].add(op)
|
||||
if inp not in visited_ops and inp not in (sources or extra_sources):
|
||||
ops_to_visit.append(inp)
|
||||
|
||||
return extra_sources
|
||||
for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))
|
||||
|
||||
|
||||
# Represents an input to `copied_op` which must be updated once
|
||||
|
@ -323,7 +227,7 @@ def lift_to_graph(init_tensors, graph, sources=None,
|
|||
|
||||
# First we extract the subgraph between init_tensors and sources.
|
||||
for init_tensor in init_tensors:
|
||||
sources.update(_map_subgraph(
|
||||
sources.update(op_selector.map_subgraph(
|
||||
init_tensor=init_tensor,
|
||||
sources=sources,
|
||||
disallowed_placeholders=disallowed_placeholders,
|
||||
|
@ -345,7 +249,7 @@ def lift_to_graph(init_tensors, graph, sources=None,
|
|||
continue
|
||||
marked_ops.add(op)
|
||||
ops_to_copy.append(op)
|
||||
for inp in _graph_inputs(op):
|
||||
for inp in op_selector.graph_inputs(op):
|
||||
# Don't lift the TPUReplicateMetadata nodes out of the function, because
|
||||
# it has no registered kernels.
|
||||
if inp.name == "TPUReplicateMetadata":
|
||||
|
@ -422,3 +326,4 @@ def lift_to_graph(init_tensors, graph, sources=None,
|
|||
# pylint: enable=protected-access
|
||||
|
||||
return op_map
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
|||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import op_selector
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
@ -34,6 +35,12 @@ from tensorflow.python.util import tf_inspect
|
|||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# Graph op type names that are treated as variable-creating ops when scanning
# a subgraph for variables.
VAR_OP_TYPES = ["VariableV2", "VarHandleOp"]
|
||||
|
||||
|
||||
def copy_handle_data(source_t, target_t):
|
||||
"""Copies HandleData for variant and resource type tensors if available.
|
||||
|
||||
|
@ -163,6 +170,29 @@ def custom_gradient(f):
|
|||
return tf_decorator.make_decorator(f, decorated)
|
||||
|
||||
|
||||
def get_variable_by_name(var_name):
  """Returns the global variable whose underlying op is named `var_name`.

  The `scope` argument of `ops.get_collection` is a regular expression that is
  matched as a prefix, so looking up "f/dense/kernel" through it would also
  select "f/dense/kernel_1". The collection is therefore filtered on an exact
  op-name comparison instead.

  Args:
    var_name: Name of the op (e.g. a `VarHandleOp` or `VariableV2` op) that
      backs the variable.

  Returns:
    The single entry of the `GLOBAL_VARIABLES` collection whose `op.name`
    equals `var_name`.

  Raises:
    ValueError: if zero or more than one global variable matches `var_name`.
  """
  candidate_vars = [
      v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      # Collection items without an `op` attribute simply never match.
      if getattr(getattr(v, "op", None), "name", None) == var_name
  ]
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if len(candidate_vars) != 1:
    raise ValueError(
        "Expected exactly one global variable named {!r}, found {}.".format(
            var_name, len(candidate_vars)))
  return candidate_vars[0]
|
||||
|
||||
|
||||
def get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph b/w input_ops and output_ops.

  Args:
    input_ops: Tensor or (possibly nested) structure of tensors at which the
      backward walk stops.
    output_ops: Tensor or (possibly nested) structure of tensors from which
      the backward walk starts.

  Returns:
    A list of the variables (looked up with `get_variable_by_name`) whose ops
    have a type listed in `VAR_OP_TYPES` and are reached when walking backward
    from `output_ops` without crossing `input_ops`.
  """

  # avoids the edge-case when input_ops == output_ops.
  output_ops = nest.map_structure(gen_array_ops.identity, output_ops)

  # The walk indexes `seed_ops[0]` and expects flat iterables of tensors, so
  # nested structures (e.g. a function returning a tuple of tuples) must be
  # flattened first.
  inbetween_ops = op_selector.get_backward_walk_ops(
      seed_ops=nest.flatten(output_ops),
      stop_at_ts=nest.flatten(input_ops),
      inclusive=False)
  var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
  var_names = (op.name for op in var_ops)
  tf_vars = [get_variable_by_name(var_name) for var_name in var_names]
  return tf_vars
|
||||
|
||||
|
||||
def _graph_mode_decorator(f, *args, **kwargs):
|
||||
"""Implement custom gradient decorator for graph mode."""
|
||||
# TODO(rsepassi): Add support for kwargs
|
||||
|
@ -191,7 +221,10 @@ def _graph_mode_decorator(f, *args, **kwargs):
|
|||
"with `use_resource=False`.")
|
||||
# The variables that grad_fn needs to return gradients for are the set of
|
||||
# variables used that are *not* part of the inputs.
|
||||
variables = list(set(tape.watched_variables()) - set(args))
|
||||
tf1_variables = get_dependent_variables(input_ops=args, output_ops=result)
|
||||
eager_variables = list(set(tape.watched_variables()) - set(args))
|
||||
variables = list(set(tf1_variables + eager_variables))
|
||||
|
||||
grad_argspec = tf_inspect.getfullargspec(grad_fn)
|
||||
variables_in_signature = ("variables" in grad_argspec.args or
|
||||
grad_argspec.varkw)
|
||||
|
|
|
@ -954,6 +954,42 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
|
|||
dw = sess.run(math_ops.reduce_sum(grads[1]))
|
||||
self.assertEqual(12., dw)
|
||||
|
||||
def testCustomGradientWithVariablesNoFalsePositives(self):
  """Only variables created inside the decorated function reach `Grad`."""

  @custom_gradient.custom_gradient
  def F(x):
    out = core_layers.dense(x, 3, use_bias=False)

    def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
      # Exactly one variable (the dense kernel created in F) is expected.
      self.assertEqual(1, len(variables))
      grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
      return grads[0], [array_ops.ones((3, 3))]

    return out, Grad

  with ops.Graph().as_default():
    with variable_scope.variable_scope("f", use_resource=True) as vs:
      a = array_ops.ones((2, 4))

      # Variables in these layers shouldn't be picked up by the decorator.
      b = core_layers.dense(a, 3, use_bias=False)
      c = core_layers.dense(b, 3, use_bias=False)
      x = core_layers.dense(b, 3, use_bias=False) + c

      # Only the variables used in F.
      y = F(x)

      all_vars = vs.global_variables()
      # Use a unittest assertion: a bare `assert` is stripped under
      # `python -O` and reports no values on failure.
      self.assertEqual(4, len(all_vars))
    grads = gradients.gradients(y, [x] + all_vars)
    var_grads = grads[1:]
    for g in grads:
      self.assertIsNotNone(g)
    with session.Session() as sess:
      self.evaluate(variables.global_variables_initializer())
      # The custom Grad returns ones((3, 3)) for the variable created in F,
      # which is the last variable in the scope.
      dw = sess.run(math_ops.reduce_sum(var_grads[-1]))
      self.assertEqual(9., dw)
|
||||
|
||||
def testCustomGradientWithVariablesEager(self):
|
||||
with context.eager_mode():
|
||||
layer = core_layers.Dense(4, use_bias=False)
|
||||
|
|
|
@ -0,0 +1,395 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tools for selecting ops in a graph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
def is_iterable(obj):
  """Tell whether `obj` can be iterated over.

  Args:
    obj: any Python object.
  Returns:
    `False` for a `tf.Tensor` or for any object on which `iter()` raises,
    `True` otherwise.
  """
  if isinstance(obj, ops.Tensor):
    return False
  try:
    iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  else:
    return True
|
||||
|
||||
|
||||
def concatenate_unique(la, lb):
  """Append to `la`, in place, every element of `lb` it does not yet hold.

  Elements are appended in the order in which `lb` yields them, and an element
  occurring several times in `lb` is appended at most once.

  Args:
    la: List of hashable Python objects; mutated in place.
    lb: Iterable of hashable Python objects.
  Returns:
    `la` itself, extended with the elements of `lb` that were missing.
  """
  seen = set(la)
  for item in lb:
    if item in seen:
      continue
    seen.add(item)
    la.append(item)
  return la
|
||||
|
||||
|
||||
def get_tensors(graph):
  """Collect the output tensors of every operation in `graph`.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`, gathered op by op in the order returned by
    `graph.get_operations()`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  tensors = []
  for operation in graph.get_operations():
    tensors.extend(operation.outputs)
  return tensors
|
||||
|
||||
|
||||
def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by the all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the element in tops are of given type(s). May be a
      single type or any iterable of types. If None, the types
      (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops. `tops` itself if it is a
    `tf.Graph`; `None` if `tops` is empty and `none_if_empty` is True.
  Raises:
    TypeError: if tops is not iterable, or if one of its elements is not an
      instance of `check_types`.
    ValueError: if the graph is not unique, or if `tops` is empty and
      `none_if_empty` is False.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  else:
    # isinstance() only accepts a type or a tuple of types; a list (or other
    # iterable) of types would otherwise raise an unrelated TypeError.
    check_types = tuple(check_types)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join([str(t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g is not op.graph:
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
|
||||
|
||||
|
||||
def check_graphs(*args):
  """Verify that every argument carrying a graph refers to the same one.

  Arguments whose `graph` attribute is `None` are skipped. The first non-None
  graph encountered becomes the reference every later one is compared with
  (by identity).

  Args:
    *args: objects exposing a `graph` attribute.
  Raises:
    ValueError: if all the elements do not belong to the same graph.
  """
  reference = None
  for index, item in enumerate(args):
    item_graph = item.graph
    if item_graph is None:
      continue
    if reference is None:
      reference = item_graph
    elif item_graph is not reference:
      raise ValueError("Argument[{}]: Wrong graph!".format(index))
|
||||
|
||||
|
||||
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Normalize `ts` into a fresh list of `tf.Tensor`.

  Args:
    ts: an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True`, validate the elements with `get_unique_graph`
      (same graph, expected types).
    allow_graph: if `False`, passing a `tf.Graph` is an error.
    ignore_ops: if `True`, `tf.Operation` elements pass validation and are
      silently dropped from the result.
  Returns:
    A newly created list of `tf.Tensor`. For a `tf.Graph` input, the tensors
    returned by `get_tensors`.
  Raises:
    TypeError: if `ts` is a `tf.Graph` while `allow_graph` is `False`, or if
      validation by `get_unique_graph` rejects an element type.
    ValueError: raised by `get_unique_graph` when `check_graph` is `True` and
      the elements do not all belong to the same graph.
  """
  if isinstance(ts, ops.Graph):
    if not allow_graph:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    return get_tensors(ts)

  if not is_iterable(ts):
    ts = [ts]
  if not ts:
    return []
  if check_graph:
    # With ignore_ops, the default (Operation, Tensor) types are accepted.
    accepted_types = None if ignore_ops else ops.Tensor
    get_unique_graph(ts, check_types=accepted_types)
  return [elem for elem in ts if isinstance(elem, ops.Tensor)]
|
||||
|
||||
|
||||
def get_generating_ops(ts):
  """List the operation producing each tensor of `ts`.

  The result has one entry per tensor, in order; an op producing several of
  the tensors therefore appears several times.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  tensors = make_list_of_t(ts, allow_graph=False)
  return [tensor.op for tensor in tensors]
|
||||
|
||||
|
||||
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each op is listed once, in the order in which it is first encountered while
  iterating over the consumers of each tensor of `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  # Membership is tracked in a set so that de-duplication stays linear rather
  # than rescanning the result list for every consumer.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        tops.append(op)
  return tops
|
||||
|
||||
|
||||
def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Normalize `tops` into a list of `tf.Operation`.

  Args:
    tops: an iterable of `tf.Operation`, a `tf.Graph` or a single operation.
    check_graph: if `True`, validate the elements with `get_unique_graph`
      (same graph, expected types).
    allow_graph: if `False`, passing a `tf.Graph` is an error.
    ignore_ts: if `True`, `tf.Tensor` elements pass validation and are
      silently dropped from the result.
  Returns:
    A list of `tf.Operation`. For a `tf.Graph` input, the value of
    `tops.get_operations()`.
  Raises:
    TypeError: if `tops` is a `tf.Graph` while `allow_graph` is `False`, or if
      validation by `get_unique_graph` rejects an element type.
    ValueError: raised by `get_unique_graph` when `check_graph` is `True` and
      the elements do not all belong to the same graph.
  """
  if isinstance(tops, ops.Graph):
    if not allow_graph:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    return tops.get_operations()

  if not is_iterable(tops):
    tops = [tops]
  if not tops:
    return []
  if check_graph:
    # With ignore_ts, the default (Operation, Tensor) types are accepted.
    accepted_types = None if ignore_ts else ops.Operation
    get_unique_graph(tops, check_types=accepted_types)
  return [elem for elem in tops if isinstance(elem, ops.Operation)]
|
||||
|
||||
|
||||
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
  Returns:
    A Python list, without duplicates, of all the `tf.Operation` behind
    `seed_ops`. Seed ops come first (unless `inclusive` is False), followed by
    the ops discovered wave by wave; the order inside a wave is unspecified.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return []
  # The kind of the first element decides whether seeds are tensors or ops.
  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = frozenset(make_list_of_t(stop_at_ts))
  seed_ops = frozenset(make_list_of_op(seed_ops))
  # NOTE(review): an empty (but not None) `within_ops` skips this branch, so
  # the seeds are kept while `is_within` rejects every other op -- confirm
  # that this is the intended behavior.
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # `visited` mirrors `result` as a set so that membership tests are O(1)
  # instead of a linear scan of the result list for every input.
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.inputs:
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op of `new_wave` is unvisited by construction, so it can be
    # appended as is.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
|
||||
|
||||
|
||||
class UnliftableError(Exception):
  """Error signalling that a Tensor could not be lifted out of its graph."""

  # Class-level marker: prevent autograph from rewriting this error.
  ag_pass_through = True
|
||||
|
||||
|
||||
def _as_operation(op_or_tensor):
  """Return the producing op of a `tf.Tensor`; pass anything else through."""
  is_tensor = isinstance(op_or_tensor, ops.Tensor)
  return op_or_tensor.op if is_tensor else op_or_tensor
|
||||
|
||||
|
||||
def graph_inputs(op):
  """List the ops `op` depends on: data-input producers, then control inputs.

  Args:
    op: an object exposing `inputs` (items with an `op` attribute) and
      `control_inputs`, e.g. a `tf.Operation`.
  Returns:
    A new list holding the `op` of each element of `op.inputs`, in order and
    with duplicates kept, followed by the elements of `op.control_inputs`.
  """
  producers = [tensor.op for tensor in op.inputs]
  producers.extend(op.control_inputs)
  return producers
|
||||
|
||||
|
||||
def _path_from(from_op, tensor, sources):
  """Describe one dependency path linking `tensor` back to `from_op`.

  A depth-first walk starts at the op of `tensor` and follows graph inputs
  (data and control) until `from_op` is met. Ops producing one of `sources`
  are treated as already visited.

  Args:
    from_op: A `tf.Operation`.
    tensor: A `tf.Operation` or `tf.Tensor`.
    sources: A list of `tf.Tensor`.

  Returns:
    A python string of "name (type)" entries joined by " <- ", starting at the
    op of `tensor` and ending at `from_op`, or "??" if no path is found.
  """
  target_op = _as_operation(tensor)
  seen = {source.op for source in sources}
  stack = [target_op]
  # Maps a visited input op to the consumer through which it was reached.
  consumer_of = {}
  while stack:
    current = stack.pop()
    if current in seen:
      continue
    seen.add(current)
    if current != from_op:
      for producer in graph_inputs(current):
        if not (producer in seen or producer in sources):
          consumer_of[producer] = current
          stack.append(producer)
      continue
    # Found `from_op`: climb back through the recorded consumers.
    chain = [current]
    while chain[-1] != target_op:
      chain.append(consumer_of[chain[-1]])
    return " <- ".join(
        "%s (%s)" % (step.name, step.type) for step in reversed(chain))
  return "??"
|
||||
|
||||
|
||||
# TODO(jmenick) - there is considerable duplication of functionality between
|
||||
# this function and get_backward_walk_ops(). Need to deduplicate.
|
||||
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Traverse the graph backward from `init_tensor`, recording the subgraph.

  Note: This function mutates visited_ops and op_outputs.

  Arguments:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict mapping each input op to the set of visited ops
      consuming it, i.e. the outputs to copy into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of output tensors of the placeholders upon which init_tensor
    depends and are not in sources.

  Raises:
    UnliftableError: if a visited op is in `disallowed_placeholders`, or if
      init_tensor depends on a placeholder which is not in sources while
      `disallowed_placeholders` is None and add_sources is False.
  """
  pending = [_as_operation(init_tensor)]
  extra_sources = set()
  while pending:
    op = pending.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    if disallowed_placeholders is not None and op in disallowed_placeholders:
      unliftable = True
    elif op.type == "Placeholder":
      unliftable = disallowed_placeholders is None and not add_sources
      extra_sources.update(op.outputs)
    else:
      unliftable = False

    if unliftable:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s"
          % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))

    # NOTE(review): `sources or extra_sources` selects `sources` whenever it
    # is non-empty and `extra_sources` otherwise (it is not a union), and the
    # items tested against it are ops -- confirm this is intended.
    stop_set = sources or extra_sources
    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      if inp not in visited_ops and inp not in stop_set:
        pending.append(inp)

  return extra_sources
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for op_selector.py."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import op_selector
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SelectTest(test.TestCase):
|
||||
|
||||
def setUp(self):
  """Build a small graph of constants and additions stored on the test case."""
  graph = ops_lib.Graph()
  self.graph = graph
  with graph.as_default():
    self.a = constant_op.constant([1., 1.], shape=[2], name="a")
    with ops_lib.name_scope("foo"):
      self.b = constant_op.constant([2., 2.], shape=[2], name="b")
      self.c = math_ops.add(self.a, self.b, name="c")
      self.d = constant_op.constant([3., 3.], shape=[2], name="d")
      with ops_lib.name_scope("bar"):
        self.e = math_ops.add(self.c, self.d, name="e")
        self.f = math_ops.add(self.c, self.d, name="f")
        self.g = math_ops.add(self.c, self.a, name="g")
        # `h` additionally carries a control dependency on the op of `c`.
        with ops_lib.control_dependencies([self.c.op]):
          self.h = math_ops.add(self.f, self.g, name="h")
|
||||
|
||||
def test_is_iterable(self):
  """Checks is_iterable on a list and on a plain integer."""
  a_list = [0, 1, 2]
  an_int = 3
  self.assertTrue(op_selector.is_iterable(a_list))
  self.assertFalse(op_selector.is_iterable(an_int))
|
||||
|
||||
def test_unique_graph(self):
|
||||
"""Test for check_graphs and get_unique_graph."""
|
||||
g0 = ops_lib.Graph()
|
||||
with g0.as_default():
|
||||
a0 = constant_op.constant(1)
|
||||
b0 = constant_op.constant(2)
|
||||
g1 = ops_lib.Graph()
|
||||
with g1.as_default():
|
||||
a1 = constant_op.constant(1)
|
||||
b1 = constant_op.constant(2)
|
||||
# Same graph, should be fine.
|
||||
self.assertIsNone(op_selector.check_graphs(a0, b0))
|
||||
# Two different graphs, should assert.
|
||||
with self.assertRaises(ValueError):
|
||||
op_selector.check_graphs(a0, b0, a1, b1)
|
||||
# a0 and b0 belongs to the same graph, should be fine.
|
||||
self.assertEqual(op_selector.get_unique_graph([a0, b0]), g0)
|
||||
# Different graph, should raise an error.
|
||||
with self.assertRaises(ValueError):
|
||||
op_selector.get_unique_graph([a0, b0, a1, b1])
|
||||
|
||||
def test_make_list_of_op(self):
|
||||
"""Test for make_list_of_op."""
|
||||
g0 = ops_lib.Graph()
|
||||
with g0.as_default():
|
||||
a0 = constant_op.constant(1)
|
||||
b0 = constant_op.constant(2)
|
||||
# Should extract the ops from the graph.
|
||||
self.assertEqual(len(op_selector.make_list_of_op(g0)), 2)
|
||||
# Should extract the ops from the tuple.
|
||||
self.assertEqual(len(op_selector.make_list_of_op((a0.op, b0.op))), 2)
|
||||
|
||||
def test_make_list_of_t(self):
|
||||
"""Test for make_list_of_t."""
|
||||
g0 = ops_lib.Graph()
|
||||
with g0.as_default():
|
||||
a0 = constant_op.constant(1)
|
||||
b0 = constant_op.constant(2)
|
||||
c0 = math_ops.add(a0, b0) # pylint: disable=unused-variable
|
||||
# Should extract the tensors from tre graph.
|
||||
self.assertEqual(len(op_selector.make_list_of_t(g0)), 3)
|
||||
# Should extract the tensors from the tuple
|
||||
self.assertEqual(len(op_selector.make_list_of_t((a0, b0))), 2)
|
||||
# Should extract the tensors and ignore the ops.
|
||||
self.assertEqual(
|
||||
len(op_selector.make_list_of_t(
|
||||
(a0, a0.op, b0), ignore_ops=True)), 2)
|
||||
|
||||
def test_get_generating_consuming(self):
|
||||
"""Test for get_generating_ops and get_consuming_ops."""
|
||||
g0 = ops_lib.Graph()
|
||||
with g0.as_default():
|
||||
a0 = constant_op.constant(1)
|
||||
b0 = constant_op.constant(2)
|
||||
c0 = math_ops.add(a0, b0)
|
||||
self.assertEqual(len(op_selector.get_generating_ops([a0, b0])), 2)
|
||||
self.assertEqual(len(op_selector.get_consuming_ops([a0, b0])), 1)
|
||||
self.assertEqual(len(op_selector.get_generating_ops([c0])), 1)
|
||||
self.assertEqual(op_selector.get_consuming_ops([c0]), [])
|
||||
|
||||
def test_backward_walk_ops(self):
|
||||
seed_ops = [self.h.op]
|
||||
# Include all ops except for self.g.op
|
||||
within_ops = [
|
||||
x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
|
||||
]
|
||||
# For the fn, exclude self.c.op.
|
||||
within_ops_fn = lambda op: op not in (self.c.op,)
|
||||
stop_at_ts = (self.f,)
|
||||
|
||||
with self.graph.as_default():
|
||||
# Backward walk only includes h since we stop at f and g is not within.
|
||||
ops = op_selector.get_backward_walk_ops(
|
||||
seed_ops,
|
||||
inclusive=True,
|
||||
within_ops=within_ops,
|
||||
within_ops_fn=within_ops_fn,
|
||||
stop_at_ts=stop_at_ts)
|
||||
self.assertEqual(set(ops), set([self.h.op]))
|
||||
|
||||
# If we do inclusive=False, the result is empty.
|
||||
ops = op_selector.get_backward_walk_ops(
|
||||
seed_ops,
|
||||
inclusive=False,
|
||||
within_ops=within_ops,
|
||||
within_ops_fn=within_ops_fn,
|
||||
stop_at_ts=stop_at_ts)
|
||||
self.assertEqual(set(ops), set())
|
||||
|
||||
# Removing stop_at_fs adds f.op, d.op.
|
||||
ops = op_selector.get_backward_walk_ops(
|
||||
seed_ops,
|
||||
inclusive=True,
|
||||
within_ops=within_ops,
|
||||
within_ops_fn=within_ops_fn)
|
||||
self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))
|
||||
|
||||
# Not using within_ops_fn adds back ops for a, b, c.
|
||||
ops = op_selector.get_backward_walk_ops(
|
||||
seed_ops, inclusive=True, within_ops=within_ops)
|
||||
self.assertEqual(
|
||||
set(ops),
|
||||
set([
|
||||
self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op
|
||||
]))
|
||||
|
||||
# Vanially backward search via self.h.op includes everything excpet e.op.
|
||||
ops = op_selector.get_backward_walk_ops(seed_ops, inclusive=True)
|
||||
self.assertEqual(
|
||||
set(ops),
|
||||
set([
|
||||
self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
|
||||
self.h.op
|
||||
]))
|
||||
|
||||
|
||||
# When this file is executed as a script, hand control to the test runner.
if __name__ == "__main__":
  test.main()
|
Loading…
Reference in New Issue