Fork the tf internal ops._get_graph_from_inputs method into Keras so that we do not depend on the internal tf method and use only public apis.

PiperOrigin-RevId: 338805001
Change-Id: I6fec78be84995c8da1fe56a19f5d6cfa7d11d55c
This commit is contained in:
Tomer Kaftan 2020-10-24 00:16:37 -07:00 committed by TensorFlower Gardener
parent 49c60f3ba9
commit 1872370488
2 changed files with 105 additions and 4 deletions

View File

@ -41,6 +41,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec
@ -585,9 +586,109 @@ def eager_learning_phase_scope(value):
del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
def _current_graph(op_input_list):
"""Return the graph members of `op_input_list`, or the current graph."""
return ops._get_graph_from_inputs(op_input_list)
def _as_graph_element(obj):
"""Convert `obj` to a graph element if possible, otherwise return `None`.
Args:
obj: Object to convert.
Returns:
The result of `obj._as_graph_element()` if that method is available;
otherwise `None`.
"""
conv_fn = getattr(obj, '_as_graph_element', None)
if conv_fn and callable(conv_fn):
return conv_fn()
return None
def _assert_same_graph(original_item, item):
"""Fail if the 2 items are from different graphs.
Args:
original_item: Original item to check against.
item: Item to check.
Raises:
ValueError: if graphs do not match.
"""
original_graph = getattr(original_item, 'graph', None)
graph = getattr(item, 'graph', None)
if original_graph and graph and original_graph is not graph:
raise ValueError(
'%s must be from the same graph as %s (graphs are %s and %s).' %
(item, original_item, graph, original_graph))
def _current_graph(op_input_list, graph=None):
"""Returns the appropriate graph to use for the given inputs.
This library method provides a consistent algorithm for choosing the graph
in which an Operation should be constructed:
1. If the default graph is being used to construct a function, we
use the default graph.
2. If the "graph" is specified explicitly, we validate that all of the inputs
in "op_input_list" are compatible with that graph.
3. Otherwise, we attempt to select a graph from the first Operation-
or Tensor-valued input in "op_input_list", and validate that all other
such inputs are in the same graph.
4. If the graph was not specified and it could not be inferred from
"op_input_list", we attempt to use the default graph.
Args:
op_input_list: A list of inputs to an operation, which may include `Tensor`,
`Operation`, and other objects that may be converted to a graph element.
graph: (Optional) The explicit graph to use.
Raises:
TypeError: If op_input_list is not a list or tuple, or if graph is not a
Graph.
ValueError: If a graph is explicitly passed and not all inputs are from it,
or if the inputs are from multiple graphs, or we could not find a graph
and there was no default graph.
Returns:
The appropriate graph to use for the given inputs.
"""
current_default_graph = ops.get_default_graph()
if current_default_graph.building_function:
return current_default_graph
op_input_list = tuple(op_input_list) # Handle generators correctly
if graph and not isinstance(graph, ops.Graph):
raise TypeError('Input graph needs to be a Graph: %s' % (graph,))
# 1. We validate that all of the inputs are from the same graph. This is
# either the supplied graph parameter, or the first one selected from one
# the graph-element-valued inputs. In the latter case, we hold onto
# that input in original_graph_element so we can provide a more
# informative error if a mismatch is found.
original_graph_element = None
for op_input in op_input_list:
# Determine if this is a valid graph_element.
# TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
# up.
if (isinstance(op_input, (
ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
((not isinstance(op_input, ops.Tensor))
or type(op_input) == ops.Tensor)): # pylint: disable=unidiomatic-typecheck
graph_element = op_input
else:
graph_element = _as_graph_element(op_input)
if graph_element is not None:
if not graph:
original_graph_element = graph_element
graph = getattr(graph_element, 'graph', None)
elif original_graph_element is not None:
_assert_same_graph(original_graph_element, graph_element)
elif graph_element.graph is not graph:
raise ValueError('%s is not from the passed-in graph.' % graph_element)
# 2. If all else fails, we use the default graph, which is always there.
return graph or current_default_graph
def _get_session(op_input_list=()):

View File

@ -688,7 +688,7 @@ class OptimizerV2(trackable.Trackable):
# If the current context is graph mode or any of the update ops are
# symbolic then the step update should be carried out under a graph
# context. (eager updates execute immediately)
with ops._get_graph_from_inputs(update_ops).as_default(): # pylint: disable=protected-access
with backend._current_graph(update_ops).as_default(): # pylint: disable=protected-access
with ops.control_dependencies([control_flow_ops.group(update_ops)]):
return self._iterations.assign_add(1, read_value=False)