# Reconstructed from a whitespace-mangled unified diff of
# tensorflow/python/keras/backend.py (commit 2493d32..b4d6bf6).
# NOTE(review): the same patch also contained a one-line companion hunk in
# tensorflow/python/keras/optimizer_v2/optimizer_v2.py that replaces
#   with ops._get_graph_from_inputs(update_ops).as_default():
# with
#   with backend._current_graph(update_ops).as_default():
# so that optimizer_v2 no longer reaches into the private ops API.

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops


def _as_graph_element(obj):
  """Convert `obj` to a graph element if possible, otherwise return `None`.

  Args:
    obj: Object to convert.

  Returns:
    The result of `obj._as_graph_element()` if that method is available;
    otherwise `None`.
  """
  conv_fn = getattr(obj, '_as_graph_element', None)
  if conv_fn and callable(conv_fn):
    return conv_fn()
  return None


def _assert_same_graph(original_item, item):
  """Fail if the 2 items are from different graphs.

  Args:
    original_item: Original item to check against.
    item: Item to check.

  Raises:
    ValueError: if graphs do not match.
  """
  original_graph = getattr(original_item, 'graph', None)
  graph = getattr(item, 'graph', None)
  # Only complain when BOTH items actually carry a graph; items without a
  # `.graph` attribute (e.g. eager tensors) are never considered mismatched.
  if original_graph and graph and original_graph is not graph:
    raise ValueError(
        '%s must be from the same graph as %s (graphs are %s and %s).' %
        (item, original_item, graph, original_graph))


def _current_graph(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  This library method provides a consistent algorithm for choosing the graph
  in which an Operation should be constructed:

  1. If the default graph is being used to construct a function, we
     use the default graph.
  2. If the "graph" is specified explicitly, we validate that all of the inputs
     in "op_input_list" are compatible with that graph.
  3. Otherwise, we attempt to select a graph from the first Operation-
     or Tensor-valued input in "op_input_list", and validate that all other
     such inputs are in the same graph.
  4. If the graph was not specified and it could not be inferred from
     "op_input_list", we attempt to use the default graph.

  Args:
    op_input_list: A list of inputs to an operation, which may include
      `Tensor`, `Operation`, and other objects that may be converted to a
      graph element.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If `op_input_list` is not iterable, or if `graph` is not a
      `Graph`.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  current_default_graph = ops.get_default_graph()
  if current_default_graph.building_function:
    return current_default_graph

  op_input_list = tuple(op_input_list)  # Handle generators correctly
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError('Input graph needs to be a Graph: %s' % (graph,))

  # 1. We validate that all of the inputs are from the same graph. This is
  # either the supplied graph parameter, or the first one selected from one
  # of the graph-element-valued inputs. In the latter case, we hold onto
  # that input in original_graph_element so we can provide a more
  # informative error if a mismatch is found.
  original_graph_element = None
  for op_input in op_input_list:
    # Determine if this is a valid graph_element.
    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean
    # this up.
    if (isinstance(op_input, (
        ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
        ((not isinstance(op_input, ops.Tensor))
         or type(op_input) == ops.Tensor)):  # pylint: disable=unidiomatic-typecheck
      graph_element = op_input
    else:
      graph_element = _as_graph_element(op_input)

    if graph_element is not None:
      if not graph:
        # First graph-valued input seen: adopt its graph and remember the
        # element for a more informative error message on later mismatches.
        original_graph_element = graph_element
        graph = getattr(graph_element, 'graph', None)
      elif original_graph_element is not None:
        _assert_same_graph(original_graph_element, graph_element)
      elif graph_element.graph is not graph:
        raise ValueError('%s is not from the passed-in graph.' % graph_element)

  # 2. If all else fails, we use the default graph, which is always there.
  return graph or current_default_graph