Fork the tf internal ops._get_graph_from_inputs method into Keras so that we do not depend on the internal tf method and use only public apis.
PiperOrigin-RevId: 338805001 Change-Id: I6fec78be84995c8da1fe56a19f5d6cfa7d11d55c
This commit is contained in:
parent
49c60f3ba9
commit
1872370488
@ -41,6 +41,7 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device_spec
|
||||
@ -585,9 +586,109 @@ def eager_learning_phase_scope(value):
|
||||
del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
|
||||
|
||||
|
||||
def _current_graph(op_input_list):
|
||||
"""Return the graph members of `op_input_list`, or the current graph."""
|
||||
return ops._get_graph_from_inputs(op_input_list)
|
||||
def _as_graph_element(obj):
|
||||
"""Convert `obj` to a graph element if possible, otherwise return `None`.
|
||||
|
||||
Args:
|
||||
obj: Object to convert.
|
||||
|
||||
Returns:
|
||||
The result of `obj._as_graph_element()` if that method is available;
|
||||
otherwise `None`.
|
||||
"""
|
||||
conv_fn = getattr(obj, '_as_graph_element', None)
|
||||
if conv_fn and callable(conv_fn):
|
||||
return conv_fn()
|
||||
return None
|
||||
|
||||
|
||||
def _assert_same_graph(original_item, item):
|
||||
"""Fail if the 2 items are from different graphs.
|
||||
|
||||
Args:
|
||||
original_item: Original item to check against.
|
||||
item: Item to check.
|
||||
|
||||
Raises:
|
||||
ValueError: if graphs do not match.
|
||||
"""
|
||||
original_graph = getattr(original_item, 'graph', None)
|
||||
graph = getattr(item, 'graph', None)
|
||||
if original_graph and graph and original_graph is not graph:
|
||||
raise ValueError(
|
||||
'%s must be from the same graph as %s (graphs are %s and %s).' %
|
||||
(item, original_item, graph, original_graph))
|
||||
|
||||
|
||||
def _current_graph(op_input_list, graph=None):
|
||||
"""Returns the appropriate graph to use for the given inputs.
|
||||
|
||||
This library method provides a consistent algorithm for choosing the graph
|
||||
in which an Operation should be constructed:
|
||||
|
||||
1. If the default graph is being used to construct a function, we
|
||||
use the default graph.
|
||||
2. If the "graph" is specified explicitly, we validate that all of the inputs
|
||||
in "op_input_list" are compatible with that graph.
|
||||
3. Otherwise, we attempt to select a graph from the first Operation-
|
||||
or Tensor-valued input in "op_input_list", and validate that all other
|
||||
such inputs are in the same graph.
|
||||
4. If the graph was not specified and it could not be inferred from
|
||||
"op_input_list", we attempt to use the default graph.
|
||||
|
||||
Args:
|
||||
op_input_list: A list of inputs to an operation, which may include `Tensor`,
|
||||
`Operation`, and other objects that may be converted to a graph element.
|
||||
graph: (Optional) The explicit graph to use.
|
||||
|
||||
Raises:
|
||||
TypeError: If op_input_list is not a list or tuple, or if graph is not a
|
||||
Graph.
|
||||
ValueError: If a graph is explicitly passed and not all inputs are from it,
|
||||
or if the inputs are from multiple graphs, or we could not find a graph
|
||||
and there was no default graph.
|
||||
|
||||
Returns:
|
||||
The appropriate graph to use for the given inputs.
|
||||
|
||||
"""
|
||||
current_default_graph = ops.get_default_graph()
|
||||
if current_default_graph.building_function:
|
||||
return current_default_graph
|
||||
|
||||
op_input_list = tuple(op_input_list) # Handle generators correctly
|
||||
if graph and not isinstance(graph, ops.Graph):
|
||||
raise TypeError('Input graph needs to be a Graph: %s' % (graph,))
|
||||
|
||||
# 1. We validate that all of the inputs are from the same graph. This is
|
||||
# either the supplied graph parameter, or the first one selected from one
|
||||
# the graph-element-valued inputs. In the latter case, we hold onto
|
||||
# that input in original_graph_element so we can provide a more
|
||||
# informative error if a mismatch is found.
|
||||
original_graph_element = None
|
||||
for op_input in op_input_list:
|
||||
# Determine if this is a valid graph_element.
|
||||
# TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
|
||||
# up.
|
||||
if (isinstance(op_input, (
|
||||
ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
|
||||
((not isinstance(op_input, ops.Tensor))
|
||||
or type(op_input) == ops.Tensor)): # pylint: disable=unidiomatic-typecheck
|
||||
graph_element = op_input
|
||||
else:
|
||||
graph_element = _as_graph_element(op_input)
|
||||
|
||||
if graph_element is not None:
|
||||
if not graph:
|
||||
original_graph_element = graph_element
|
||||
graph = getattr(graph_element, 'graph', None)
|
||||
elif original_graph_element is not None:
|
||||
_assert_same_graph(original_graph_element, graph_element)
|
||||
elif graph_element.graph is not graph:
|
||||
raise ValueError('%s is not from the passed-in graph.' % graph_element)
|
||||
|
||||
# 2. If all else fails, we use the default graph, which is always there.
|
||||
return graph or current_default_graph
|
||||
|
||||
|
||||
def _get_session(op_input_list=()):
|
||||
|
@ -688,7 +688,7 @@ class OptimizerV2(trackable.Trackable):
|
||||
# If the current context is graph mode or any of the update ops are
|
||||
# symbolic then the step update should be carried out under a graph
|
||||
# context. (eager updates execute immediately)
|
||||
with ops._get_graph_from_inputs(update_ops).as_default(): # pylint: disable=protected-access
|
||||
with backend._current_graph(update_ops).as_default(): # pylint: disable=protected-access
|
||||
with ops.control_dependencies([control_flow_ops.group(update_ops)]):
|
||||
return self._iterations.assign_add(1, read_value=False)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user