Automated rollback of commit 7e1deb16c6

PiperOrigin-RevId: 254183487
This commit is contained in:
A. Unique TensorFlower 2019-06-20 06:18:36 -07:00 committed by TensorFlower Gardener
parent 8052ffb5b6
commit 9f06292332
9 changed files with 800 additions and 143 deletions

View File

@ -26,6 +26,7 @@ py_library(
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:op_selector",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"@six_archive//:six",

View File

@ -24,7 +24,11 @@ from six import iteritems
from six import string_types
from tensorflow.contrib.graph_editor import util
from tensorflow.python.ops import op_selector
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.util import deprecation
__all__ = [
"can_be_regex",
@ -452,6 +456,10 @@ def get_forward_walk_ops(seed_ops,
return result
@deprecation.deprecated(
"2019-06-06",
"Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.",
warn_once=True)
def get_backward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
@ -479,46 +487,13 @@ def get_backward_walk_ops(seed_ops,
TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
`tf.Operation`.
"""
if not util.is_iterable(seed_ops):
seed_ops = [seed_ops]
if not seed_ops:
return []
if isinstance(seed_ops[0], tf_ops.Tensor):
ts = util.make_list_of_t(seed_ops, allow_graph=False)
seed_ops = util.get_generating_ops(ts)
else:
seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
seed_ops = frozenset(util.make_list_of_op(seed_ops))
if within_ops:
within_ops = util.make_list_of_op(within_ops, allow_graph=False)
within_ops = frozenset(within_ops)
seed_ops &= within_ops
def is_within(op):
return (within_ops is None or op in within_ops) and (
within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
while wave:
new_wave = set()
for op in wave:
for new_t in op.inputs:
if new_t in stop_at_ts:
continue
if new_t.op not in result and is_within(new_t.op):
new_wave.add(new_t.op)
if control_inputs:
for new_op in op.control_inputs:
if new_op not in result and is_within(new_op):
new_wave.add(new_op)
util.concatenate_unique(result, new_wave)
wave = new_wave
if not inclusive:
result = [op for op in result if op not in seed_ops]
return result
return op_selector.get_backward_walk_ops(
seed_ops,
inclusive=inclusive,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
stop_at_ts=stop_at_ts,
control_inputs=control_inputs)
def get_walks_intersection_ops(forward_seed_ops,

View File

@ -3079,6 +3079,13 @@ py_library(
],
)
py_library(
name = "op_selector",
srcs = ["ops/op_selector.py"],
srcs_version = "PY2AND3",
deps = [":framework_ops"],
)
py_library(
name = "math_ops",
srcs = ["ops/math_ops.py"],
@ -3878,6 +3885,20 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
py_test(
name = "op_selector_test",
srcs = ["ops/op_selector_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":client_testlib",
":constant_op",
":framework_ops",
":math_ops",
":op_selector",
],
)
cuda_py_test(
name = "gradient_checker_v2_test",
size = "medium",

View File

@ -611,8 +611,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
"//tensorflow/python:framework_ops",
"//tensorflow/python:op_selector",
"@six_archive//:six",
],
)

View File

@ -25,11 +25,11 @@ import six
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
def _graph_inputs(op):
return [x.op for x in op.inputs] + list(op.control_inputs)
UnliftableError = op_selector.UnliftableError
def _as_operation(op_or_tensor):
@ -38,106 +38,10 @@ def _as_operation(op_or_tensor):
return op_or_tensor
class UnliftableError(Exception):
"""Raised if a Tensor cannot be lifted from the graph."""
# Prevent autograph from rewriting this error.
ag_pass_through = True
def _constant_inputs(op_or_tensor):
return all(_as_operation(i).type == u"Const"
and not _as_operation(i).control_inputs
for i in _graph_inputs(_as_operation(op_or_tensor)))
def _path_from(from_op, tensor, sources):
"""Find one path from `from_op` to `tensor`, ignoring `sources`.
Args:
from_op: A `tf.Operation`.
tensor: A `tf.Operation` or `tf.Tensor`.
sources: A list of `tf.Tensor`.
Returns:
A python string containing the path, or "??" if none is found.
"""
visited_ops = set([x.op for x in sources])
ops_to_visit = [_as_operation(tensor)]
some_op_output = {}
while ops_to_visit:
op = ops_to_visit.pop()
if op in visited_ops:
continue
visited_ops.add(op)
if op == from_op:
path_op = op
path = [path_op]
final_op = _as_operation(tensor)
while path_op != final_op:
path_op = some_op_output[path_op]
path.append(path_op)
return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)])
else:
for inp in _graph_inputs(op):
if inp not in visited_ops and inp not in sources:
some_op_output[inp] = op
ops_to_visit.append(inp)
return "??"
def _map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
op_outputs, add_sources):
"""Walk a Graph and capture the subgraph between init_tensor and sources.
Note: This function mutates visited_ops and op_outputs.
Arguments:
init_tensor: A Tensor or Operation where the subgraph terminates.
sources: A set of Tensors where subgraph extraction should stop.
disallowed_placeholders: An optional set of ops which may not appear in the
lifted graph. Defaults to all placeholders.
visited_ops: A set of operations which were visited in a prior pass.
op_outputs: A defaultdict containing the outputs of an op which are to be
copied into the new subgraph.
add_sources: A boolean indicating whether placeholders which are not in
sources should be allowed.
Returns:
The set of placeholders upon which init_tensor depends and are not in
sources.
Raises:
UnliftableError: if init_tensor depends on a placeholder which is not in
sources and add_sources is False.
"""
ops_to_visit = [_as_operation(init_tensor)]
extra_sources = set()
while ops_to_visit:
op = ops_to_visit.pop()
if op in visited_ops:
continue
visited_ops.add(op)
should_raise = False
if disallowed_placeholders is not None and op in disallowed_placeholders:
should_raise = True
elif op.type == "Placeholder":
if disallowed_placeholders is None and not add_sources:
should_raise = True
extra_sources.update(op.outputs)
if should_raise:
raise UnliftableError(
"Unable to lift tensor %s because it depends transitively on "
"placeholder %s via at least one path, e.g.: %s"
% (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
for inp in _graph_inputs(op):
op_outputs[inp].add(op)
if inp not in visited_ops and inp not in (sources or extra_sources):
ops_to_visit.append(inp)
return extra_sources
for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))
# Represents an input to `copied_op` which must be updated once
@ -323,7 +227,7 @@ def lift_to_graph(init_tensors, graph, sources=None,
# First we extract the subgraph between init_tensors and sources.
for init_tensor in init_tensors:
sources.update(_map_subgraph(
sources.update(op_selector.map_subgraph(
init_tensor=init_tensor,
sources=sources,
disallowed_placeholders=disallowed_placeholders,
@ -345,7 +249,7 @@ def lift_to_graph(init_tensors, graph, sources=None,
continue
marked_ops.add(op)
ops_to_copy.append(op)
for inp in _graph_inputs(op):
for inp in op_selector.graph_inputs(op):
# Don't lift the TPUReplicateMetadata nodes out of the function, because
# it has no registered kernels.
if inp.name == "TPUReplicateMetadata":
@ -422,3 +326,4 @@ def lift_to_graph(init_tensors, graph, sources=None,
# pylint: enable=protected-access
return op_map

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@ -34,6 +35,12 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
VAR_OP_TYPES = [
"VariableV2",
"VarHandleOp",
]
def copy_handle_data(source_t, target_t):
"""Copies HandleData for variant and resource type tensors if available.
@ -163,6 +170,30 @@ def custom_gradient(f):
return tf_decorator.make_decorator(f, decorated)
def get_variable_by_name(var_name):
candidate_vars = ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES, scope=var_name)
assert len(candidate_vars) == 1
return candidate_vars[0]
def get_dependent_variables(input_ops, output_ops):
"""Finds variables involved in the subgraph b/w input_ops and output_ops."""
# avoids the edge-case when input_ops == output_ops.
output_ops = nest.map_structure(gen_array_ops.identity, output_ops)
inbetween_ops = op_selector.get_backward_walk_ops(
seed_ops=output_ops,
stop_at_ts=input_ops,
inclusive=False,
only_differentiable=True)
var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
var_names = (op.name for op in var_ops)
tf_vars = [get_variable_by_name(var_name) for var_name in var_names]
return tf_vars
def _graph_mode_decorator(f, *args, **kwargs):
"""Implement custom gradient decorator for graph mode."""
# TODO(rsepassi): Add support for kwargs
@ -191,7 +222,10 @@ def _graph_mode_decorator(f, *args, **kwargs):
"with `use_resource=False`.")
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
variables = list(set(tape.watched_variables()) - set(args))
tf1_variables = get_dependent_variables(input_ops=args, output_ops=result)
eager_variables = list(set(tape.watched_variables()) - set(args))
variables = list(set(tf1_variables + eager_variables))
grad_argspec = tf_inspect.getfullargspec(grad_fn)
variables_in_signature = ("variables" in grad_argspec.args or
grad_argspec.varkw)

View File

@ -872,6 +872,93 @@ class ResourceCondTest(test_util.TensorFlowTestCase):
self.assertTrue(None not in grads)
class GetDependentVariablesTest(test_util.TensorFlowTestCase):
def testNoVariables(self):
with ops.Graph().as_default():
func = lambda x: array_ops.identity(x) + 5.0
input_t = constant_op.constant(2.0)
result_t = func(input_t)
dependent_vars = custom_gradient.get_dependent_variables(
[input_t], [result_t])
# There are no variables.
self.assertEqual(dependent_vars, [])
def testVariablesOutside(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0)
var = variables.Variable(init)
# The variable is closed over. It should be found.
func = lambda x: array_ops.identity(x) + 5.0 + var
input_t = constant_op.constant(2.0)
result_t = func(input_t)
dependent_vars = custom_gradient.get_dependent_variables(
[input_t], [result_t])
self.assertEqual(dependent_vars, [var])
def testVariablesOutsideButDSeparated(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0)
var = variables.Variable(init)
# The variable is d-separated by the inputs. It should not be found.
input_t = array_ops.identity(var) * 5.0
func = lambda x: array_ops.identity(x) + 5.0
result_t = func(input_t)
dependent_vars = custom_gradient.get_dependent_variables(
[input_t], [result_t])
self.assertEqual(dependent_vars, [])
def testVariablesOutsideAndNonDifferentiable(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0, shape=(5,))
var = variables.Variable(init, shape=(5,))
def _Func(x):
# non-differentiable dependency on var.
# the variable should not be found.
y = array_ops.ones_like(var)
return array_ops.identity(x) + 5.0 + y
input_t = constant_op.constant(2.0)
result_t = _Func(input_t)
dependent_vars = custom_gradient.get_dependent_variables(
[input_t], [result_t])
self.assertEqual(dependent_vars, [])
def testVariablesOutsideAndCustomGradient(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0, shape=(5,))
var = variables.Variable(init, shape=(5,))
@custom_gradient.custom_gradient
def _MyOnesLike(x):
"""Dummy version of ones_like which defines a gradient."""
output = array_ops.ones_like(x)
def _Grad(dy):
return array_ops.identity(dy)
return output, _Grad
def _Func(x):
# non-differentiable operation with custom gradient.
# The variable should be found.
y = _MyOnesLike(var)
return array_ops.identity(x) + 5.0 + y
input_t = constant_op.constant(2.0)
result_t = _Func(input_t)
dependent_vars = custom_gradient.get_dependent_variables(
[input_t], [result_t])
self.assertEqual(dependent_vars, [var])
class CustomGradientTest(test_util.TensorFlowTestCase):
def testCustomGradientTrivial(self):
@ -954,6 +1041,42 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
dw = sess.run(math_ops.reduce_sum(grads[1]))
self.assertEqual(12., dw)
def testCustomGradientWithVariablesNoFalsePositives(self):
@custom_gradient.custom_gradient
def F(x):
out = core_layers.dense(x, 3, use_bias=False)
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
self.assertEqual(1, len(variables))
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
return grads[0], [array_ops.ones((3, 3))]
return out, Grad
with ops.Graph().as_default():
with variable_scope.variable_scope("f", use_resource=True) as vs:
a = array_ops.ones((2, 4))
# Variabes in these layers shouldn't be picked up by the decorator.
b = core_layers.dense(a, 3, use_bias=False)
c = core_layers.dense(b, 3, use_bias=False)
x = core_layers.dense(b, 3, use_bias=False) + c
# Only the variables used in F.
y = F(x)
all_vars = vs.global_variables()
assert len(all_vars) == 4
grads = gradients.gradients(y, [x] + all_vars)
_, var_grads = grads[0], grads[1:]
for g in grads:
self.assertIsNotNone(g)
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
dw = sess.run(math_ops.reduce_sum(var_grads[-1]))
self.assertEqual(9., dw)
def testCustomGradientWithVariablesEager(self):
with context.eager_mode():
layer = core_layers.Dense(4, use_bias=False)

View File

@ -0,0 +1,418 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for selecting ops in a graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
def is_differentiable(op):
try:
return ops._gradient_registry.lookup(op.op_def.name) is not None # pylint: disable=protected-access
except LookupError:
return False
def is_iterable(obj):
"""Return true if the object is iterable."""
if isinstance(obj, ops.Tensor):
return False
try:
_ = iter(obj)
except Exception: # pylint: disable=broad-except
return False
return True
def concatenate_unique(la, lb):
"""Add all the elements of `lb` to `la` if they are not there already.
The elements added to `la` maintain ordering with respect to `lb`.
Args:
la: List of Python objects.
lb: List of Python objects.
Returns:
`la`: The list `la` with missing elements from `lb`.
"""
la_set = set(la)
for l in lb:
if l not in la_set:
la.append(l)
la_set.add(l)
return la
def get_tensors(graph):
"""get all the tensors which are input or output of an op in the graph.
Args:
graph: a `tf.Graph`.
Returns:
A list of `tf.Tensor`.
Raises:
TypeError: if graph is not a `tf.Graph`.
"""
if not isinstance(graph, ops.Graph):
raise TypeError("Expected a graph, got: {}".format(type(graph)))
ts = []
for op in graph.get_operations():
ts += op.outputs
return ts
def get_unique_graph(tops, check_types=None, none_if_empty=False):
"""Return the unique graph used by the all the elements in tops.
Args:
tops: list of elements to check (usually a list of tf.Operation and/or
tf.Tensor). Or a tf.Graph.
check_types: check that the element in tops are of given type(s). If None,
the types (tf.Operation, tf.Tensor) are used.
none_if_empty: don't raise an error if tops is an empty list, just return
None.
Returns:
The unique graph used by all the tops.
Raises:
TypeError: if tops is not a iterable of tf.Operation.
ValueError: if the graph is not unique.
"""
if isinstance(tops, ops.Graph):
return tops
if not is_iterable(tops):
raise TypeError("{} is not iterable".format(type(tops)))
if check_types is None:
check_types = (ops.Operation, ops.Tensor)
elif not is_iterable(check_types):
check_types = (check_types,)
g = None
for op in tops:
if not isinstance(op, check_types):
raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
t) for t in check_types]), type(op)))
if g is None:
g = op.graph
elif g._graph_key != op.graph._graph_key: # pylint: disable=protected-access
raise ValueError("Operation {} does not belong to given graph".format(op))
if g is None and not none_if_empty:
raise ValueError("Can't find the unique graph of an empty list")
return g
def check_graphs(*args):
"""Check that all the element in args belong to the same graph.
Args:
*args: a list of object with a obj.graph property.
Raises:
ValueError: if all the elements do not belong to the same graph.
"""
graph = None
for i, sgv in enumerate(args):
if graph is None and sgv.graph is not None:
graph = sgv.graph
elif sgv.graph is not None and sgv.graph is not graph:
raise ValueError("Argument[{}]: Wrong graph!".format(i))
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
"""Convert ts to a list of `tf.Tensor`.
Args:
ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
check_graph: if `True` check if all the tensors belong to the same graph.
allow_graph: if `False` a `tf.Graph` cannot be converted.
ignore_ops: if `True`, silently ignore `tf.Operation`.
Returns:
A newly created list of `tf.Tensor`.
Raises:
TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
if `check_graph` is `True`, if all the ops do not belong to the same graph.
"""
if isinstance(ts, ops.Graph):
if allow_graph:
return get_tensors(ts)
else:
raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
else:
if not is_iterable(ts):
ts = [ts]
if not ts:
return []
if check_graph:
check_types = None if ignore_ops else ops.Tensor
get_unique_graph(ts, check_types=check_types)
return [t for t in ts if isinstance(t, ops.Tensor)]
def get_generating_ops(ts):
"""Return all the generating ops of the tensors in `ts`.
Args:
ts: a list of `tf.Tensor`
Returns:
A list of all the generating `tf.Operation` of the tensors in `ts`.
Raises:
TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
"""
ts = make_list_of_t(ts, allow_graph=False)
return [t.op for t in ts]
def get_consuming_ops(ts):
"""Return all the consuming ops of the tensors in ts.
Args:
ts: a list of `tf.Tensor`
Returns:
A list of all the consuming `tf.Operation` of the tensors in `ts`.
Raises:
TypeError: if ts cannot be converted to a list of `tf.Tensor`.
"""
ts = make_list_of_t(ts, allow_graph=False)
tops = []
for t in ts:
for op in t.consumers():
if op not in tops:
tops.append(op)
return tops
def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
"""Convert ops to a list of `tf.Operation`.
Args:
tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
operation.
check_graph: if `True` check if all the operations belong to the same graph.
allow_graph: if `False` a `tf.Graph` cannot be converted.
ignore_ts: if True, silently ignore `tf.Tensor`.
Returns:
A newly created list of `tf.Operation`.
Raises:
TypeError: if tops cannot be converted to a list of `tf.Operation` or,
if `check_graph` is `True`, if all the ops do not belong to the
same graph.
"""
if isinstance(tops, ops.Graph):
if allow_graph:
return tops.get_operations()
else:
raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
else:
if not is_iterable(tops):
tops = [tops]
if not tops:
return []
if check_graph:
check_types = None if ignore_ts else ops.Operation
get_unique_graph(tops, check_types=check_types)
return [op for op in tops if isinstance(op, ops.Operation)]
def _get_inputs(op, only_differentiable):
op_inputs = op.inputs
if only_differentiable:
return op_inputs if is_differentiable(op) else []
else:
return op_inputs
def get_backward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
within_ops_fn=None,
stop_at_ts=(),
control_inputs=False,
only_differentiable=False):
"""Do a backward graph walk and return all the visited ops.
Args:
seed_ops: an iterable of operations from which the backward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the generators of those tensors.
inclusive: if True the given seed_ops are also part of the resulting set.
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
within_ops_fn: if provided, a function on ops that should return True iff
the op is within the graph traversal. This can be used along within_ops,
in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_inputs: if True, control inputs will be used while moving backward.
only_differentiable: if True, only traverse ops which are differentiable.
This includes natively differentiable ops, or ops with custom gradients.
Returns:
A Python set of all the `tf.Operation` behind `seed_ops`.
Raises:
TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
`tf.Operation`.
"""
control_inputs = control_inputs and (not only_differentiable)
if not is_iterable(seed_ops):
seed_ops = [seed_ops]
if not seed_ops:
return []
if isinstance(seed_ops[0], ops.Tensor):
ts = make_list_of_t(seed_ops, allow_graph=False)
seed_ops = get_generating_ops(ts)
else:
seed_ops = make_list_of_op(seed_ops, allow_graph=False)
stop_at_ts = frozenset(make_list_of_t(stop_at_ts))
seed_ops = frozenset(make_list_of_op(seed_ops))
if within_ops:
within_ops = make_list_of_op(within_ops, allow_graph=False)
within_ops = frozenset(within_ops)
seed_ops &= within_ops
def is_within(op):
return (within_ops is None or op in within_ops) and (
within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
while wave:
new_wave = set()
for op in wave:
for new_t in _get_inputs(op, only_differentiable=only_differentiable):
if new_t in stop_at_ts:
continue
if new_t.op not in result and is_within(new_t.op):
new_wave.add(new_t.op)
if control_inputs:
for new_op in op.control_inputs:
if new_op not in result and is_within(new_op):
new_wave.add(new_op)
concatenate_unique(result, new_wave)
wave = new_wave
if not inclusive:
result = [op for op in result if op not in seed_ops]
return result
class UnliftableError(Exception):
"""Raised if a Tensor cannot be lifted from the graph."""
# Prevent autograph from rewriting this error.
ag_pass_through = True
def _as_operation(op_or_tensor):
if isinstance(op_or_tensor, ops.Tensor):
return op_or_tensor.op
return op_or_tensor
def graph_inputs(op):
return [x.op for x in op.inputs] + list(op.control_inputs)
def _path_from(from_op, tensor, sources):
"""Find one path from `from_op` to `tensor`, ignoring `sources`.
Args:
from_op: A `tf.Operation`.
tensor: A `tf.Operation` or `tf.Tensor`.
sources: A list of `tf.Tensor`.
Returns:
A python string containing the path, or "??" if none is found.
"""
if isinstance(from_op, ops.Tensor):
from_op = from_op.op
visited_ops = set([x.op for x in sources])
ops_to_visit = [_as_operation(tensor)]
some_op_output = {}
while ops_to_visit:
op = ops_to_visit.pop()
if op in visited_ops:
continue
visited_ops.add(op)
if op == from_op:
path_op = op
path = [path_op]
final_op = _as_operation(tensor)
while path_op != final_op:
path_op = some_op_output[path_op]
path.append(path_op)
return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)])
else:
for inp in graph_inputs(op):
if inp not in visited_ops and inp not in sources:
some_op_output[inp] = op
ops_to_visit.append(inp)
return "??"
# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
op_outputs, add_sources):
"""Walk a Graph and capture the subgraph between init_tensor and sources.
Note: This function mutates visited_ops and op_outputs.
Arguments:
init_tensor: A Tensor or Operation where the subgraph terminates.
sources: A set of Tensors where subgraph extraction should stop.
disallowed_placeholders: An optional set of ops which may not appear in the
lifted graph. Defaults to all placeholders.
visited_ops: A set of operations which were visited in a prior pass.
op_outputs: A defaultdict containing the outputs of an op which are to be
copied into the new subgraph.
add_sources: A boolean indicating whether placeholders which are not in
sources should be allowed.
Returns:
The set of placeholders upon which init_tensor depends and are not in
sources.
Raises:
UnliftableError: if init_tensor depends on a placeholder which is not in
sources and add_sources is False.
"""
ops_to_visit = [_as_operation(init_tensor)]
extra_sources = set()
while ops_to_visit:
op = ops_to_visit.pop()
if op in visited_ops:
continue
visited_ops.add(op)
should_raise = False
if disallowed_placeholders is not None and op in disallowed_placeholders:
should_raise = True
elif op.type == "Placeholder":
if disallowed_placeholders is None and not add_sources:
should_raise = True
extra_sources.update(op.outputs)
if should_raise:
raise UnliftableError(
"Unable to lift tensor %s because it depends transitively on "
"placeholder %s via at least one path, e.g.: %s"
% (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
for inp in graph_inputs(op):
op_outputs[inp].add(op)
if inp not in visited_ops and inp not in (sources or extra_sources):
ops_to_visit.append(inp)
return extra_sources

View File

@ -0,0 +1,180 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for op_selector.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.platform import test
class SelectTest(test.TestCase):
def setUp(self):
self.graph = ops_lib.Graph()
with self.graph.as_default():
self.a = constant_op.constant([1., 1.], shape=[2], name="a")
with ops_lib.name_scope("foo"):
self.b = constant_op.constant([2., 2.], shape=[2], name="b")
self.c = math_ops.add(self.a, self.b, name="c")
self.d = constant_op.constant([3., 3.], shape=[2], name="d")
with ops_lib.name_scope("bar"):
self.e = math_ops.add(self.c, self.d, name="e")
self.f = math_ops.add(self.c, self.d, name="f")
self.g = math_ops.add(self.c, self.a, name="g")
with ops_lib.control_dependencies([self.c.op]):
self.h = math_ops.add(self.f, self.g, name="h")
def test_is_iterable(self):
"""Test for is_iterable."""
self.assertTrue(op_selector.is_iterable([0, 1, 2]))
self.assertFalse(op_selector.is_iterable(3))
def test_unique_graph(self):
"""Test for check_graphs and get_unique_graph."""
g0 = ops_lib.Graph()
with g0.as_default():
a0 = constant_op.constant(1)
b0 = constant_op.constant(2)
g1 = ops_lib.Graph()
with g1.as_default():
a1 = constant_op.constant(1)
b1 = constant_op.constant(2)
# Same graph, should be fine.
self.assertIsNone(op_selector.check_graphs(a0, b0))
# Two different graphs, should assert.
with self.assertRaises(ValueError):
op_selector.check_graphs(a0, b0, a1, b1)
# a0 and b0 belongs to the same graph, should be fine.
self.assertEqual(op_selector.get_unique_graph([a0, b0]), g0)
# Different graph, should raise an error.
with self.assertRaises(ValueError):
op_selector.get_unique_graph([a0, b0, a1, b1])
def test_unique_graph_func_graph(self):
"""Test for get_unique_graph with FuncGraph."""
outer = ops_lib.Graph()
with outer.as_default():
k1 = constant_op.constant(1)
inner = func_graph.FuncGraph("inner")
inner._graph_key = outer._graph_key
with inner.as_default():
k2 = constant_op.constant(2)
unique_graph = op_selector.get_unique_graph([k1, k2])
self.assertEqual(unique_graph._graph_key, inner._graph_key)
def test_make_list_of_op(self):
"""Test for make_list_of_op."""
g0 = ops_lib.Graph()
with g0.as_default():
a0 = constant_op.constant(1)
b0 = constant_op.constant(2)
# Should extract the ops from the graph.
self.assertEqual(len(op_selector.make_list_of_op(g0)), 2)
# Should extract the ops from the tuple.
self.assertEqual(len(op_selector.make_list_of_op((a0.op, b0.op))), 2)
def test_make_list_of_t(self):
"""Test for make_list_of_t."""
g0 = ops_lib.Graph()
with g0.as_default():
a0 = constant_op.constant(1)
b0 = constant_op.constant(2)
c0 = math_ops.add(a0, b0) # pylint: disable=unused-variable
# Should extract the tensors from tre graph.
self.assertEqual(len(op_selector.make_list_of_t(g0)), 3)
# Should extract the tensors from the tuple
self.assertEqual(len(op_selector.make_list_of_t((a0, b0))), 2)
# Should extract the tensors and ignore the ops.
self.assertEqual(
len(op_selector.make_list_of_t(
(a0, a0.op, b0), ignore_ops=True)), 2)
def test_get_generating_consuming(self):
"""Test for get_generating_ops and get_consuming_ops."""
g0 = ops_lib.Graph()
with g0.as_default():
a0 = constant_op.constant(1)
b0 = constant_op.constant(2)
c0 = math_ops.add(a0, b0)
self.assertEqual(len(op_selector.get_generating_ops([a0, b0])), 2)
self.assertEqual(len(op_selector.get_consuming_ops([a0, b0])), 1)
self.assertEqual(len(op_selector.get_generating_ops([c0])), 1)
self.assertEqual(op_selector.get_consuming_ops([c0]), [])
def test_backward_walk_ops(self):
seed_ops = [self.h.op]
# Include all ops except for self.g.op
within_ops = [
x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
]
# For the fn, exclude self.c.op.
within_ops_fn = lambda op: op not in (self.c.op,)
stop_at_ts = (self.f,)
with self.graph.as_default():
# Backward walk only includes h since we stop at f and g is not within.
ops = op_selector.get_backward_walk_ops(
seed_ops,
inclusive=True,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
stop_at_ts=stop_at_ts)
self.assertEqual(set(ops), set([self.h.op]))
# If we do inclusive=False, the result is empty.
ops = op_selector.get_backward_walk_ops(
seed_ops,
inclusive=False,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
stop_at_ts=stop_at_ts)
self.assertEqual(set(ops), set())
# Removing stop_at_fs adds f.op, d.op.
ops = op_selector.get_backward_walk_ops(
seed_ops,
inclusive=True,
within_ops=within_ops,
within_ops_fn=within_ops_fn)
self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))
# Not using within_ops_fn adds back ops for a, b, c.
ops = op_selector.get_backward_walk_ops(
seed_ops, inclusive=True, within_ops=within_ops)
self.assertEqual(
set(ops),
set([
self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op
]))
# Vanially backward search via self.h.op includes everything excpet e.op.
ops = op_selector.get_backward_walk_ops(seed_ops, inclusive=True)
self.assertEqual(
set(ops),
set([
self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
self.h.op
]))
if __name__ == "__main__":
test.main()