diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 6f798fcdb06..af7930224b5 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -6,6 +6,7 @@ load(
     "//tensorflow/tools/test:performance.bzl",
     "tf_py_logged_benchmark",
 )
+load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
 
 cc_library(
     name = "pywrap_tfe_lib",
@@ -523,6 +524,19 @@ tf_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "def_function_xla_test",
+    srcs = ["def_function_xla_test.py"],
+    tags = ["no_pip"],
+    deps = [
+        ":def_function",
+        "//tensorflow/compiler/tests:xla_test",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+    ],
+)
+
 py_library(
     name = "wrap_function",
     srcs = ["wrap_function.py"],
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 897a38e7d34..43cb07e79d7 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -25,7 +25,6 @@ import weakref
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as function_lib
 from tensorflow.python.eager import lift_to_graph
-from tensorflow.python.eager import tape
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
@@ -57,8 +56,6 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
                constraint=None,
                add_initializers_to=None,
                lifted_initializer_graph=None,
-               lifted_all_initializers=None,
-               lifted_placeholders=None,
                **unused_kwargs):
     """Creates a variable.
 
@@ -90,13 +87,9 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
         (which must have the same shape). Constraints are not safe to
         use when doing asynchronous distributed training.
       add_initializers_to: if not None and not in legacy graph mode, the
-        initializer tensor will be added to this map instead of adding the
+        initializer tensor will be added to this map in addition to adding the
         assignment to the function.
       lifted_initializer_graph: FuncGraph to try to lift initializers to.
-      lifted_all_initializers: list with one boolean element, which will be
-        set to False if we cannot lift this initializer to the above graph.
-      lifted_placeholders: placeholders for resource handles lifted out of
-        this graph.
 
     Raises:
       ValueError: If the initial value is not specified, or does not have a
@@ -174,7 +167,6 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
         with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
           self._initializer_op = resource_variable_ops.assign_variable_op(
               self._handle, lifted_initializer, name=n)
-          assign = self._initializer_op
         with ops.name_scope("Read"), ops.colocate_with(self._handle):
           # Manually assign reads to the handle's device to avoid log
           # messages.
@@ -185,32 +177,21 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
       else:
         if add_initializers_to is not None:
           add_initializers_to[self] = initial_value
-          assign = None
-        else:
-          def assign_fn():
-            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
-              resource_variable_ops.assign_variable_op(
-                  self._handle,
-                  initial_value,
-                  name=n)
-              # Returning values to keep tf.cond happy.
-              return ops.convert_to_tensor(1)
-          def not_assign_fn():
-            return ops.convert_to_tensor(0)
-          # Note: this cond is always guaranteed to run because we're inside a
-          # defun which will insert automatic control dependencies.
-          assign = control_flow_ops.cond(
-              resource_variable_ops.var_is_initialized_op(self._handle),
-              not_assign_fn, assign_fn)
-        if lifted_initializer_graph is not None and assign is not None:
-          try:
-            handle_placeholder = ops.convert_to_tensor(self._handle)
-            op_map = lift_to_graph.lift_to_graph(
-                assign, lifted_initializer_graph,
-                sources=[handle_placeholder])
-            lifted_placeholders.append((self._handle, op_map[handle_placeholder]))
-          except ValueError:
-            lifted_all_initializers[0] = False
+        def assign_fn():
+          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+            resource_variable_ops.assign_variable_op(
+                self._handle,
+                initial_value,
+                name=n)
+            # Returning values to keep tf.cond happy.
+            return ops.convert_to_tensor(1)
+        def not_assign_fn():
+          return ops.convert_to_tensor(0)
+        # Note: this cond is always guaranteed to run because we're inside a
+        # defun which will insert automatic control dependencies.
+        control_flow_ops.cond(
+            resource_variable_ops.var_is_initialized_op(self._handle),
+            not_assign_fn, assign_fn)
 
     # After the handle has been created, set up a way to clean it up when
     # executing eagerly. We'll hold the only reference to the deleter, so that
@@ -340,16 +321,12 @@ class Function(object):
     created_variables = []
     lifted_initializer_graph = func_graph_module.FuncGraph("initializer")
-    lifted_all_initializers = [True]
-    lifted_placeholders = []
 
     def variable_capturing_scope(unused_next_creator, **kwds):
       """Creates UnliftedInitializerVariables and saves references to them."""
       v = UnliftedInitializerVariable(
           add_initializers_to=add_initializers_to,
-          lifted_initializer_graph=lifted_initializer_graph,
-          lifted_all_initializers=lifted_all_initializers,
-          lifted_placeholders=lifted_placeholders, **kwds)
+          lifted_initializer_graph=lifted_initializer_graph, **kwds)
       created_variables.append(weakref.ref(v))
       return v
 
@@ -359,11 +336,9 @@ class Function(object):
     # Force the definition of the function for these arguments
     self._lifted_initializer_graph = lifted_initializer_graph
     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
-    self._lifted_placeholders = lifted_placeholders
     self._concrete_stateful_fn = (
         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
             *args, **kwds))
-    self._lifted_all_initializers = lifted_all_initializers[0]
 
     def invalid_creator_scope(*unused_args, **unused_kwds):
       """Disables variable creation."""
@@ -390,21 +365,22 @@ class Function(object):
       return results
 
     # This is the first call of __call__, so we have to initialize.
-    self._initialize(args, kwds)
-    if self._lifted_all_initializers and self._lifted_placeholders:
-      with ops.init_scope():
-        handles, placeholders = zip(*self._lifted_placeholders)
-        if context.executing_eagerly():
-          lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
-              "initializer" + str(ops.uid()),
-              self._lifted_initializer_graph,
-              placeholders, [], {})
-          with tape.stop_recording():
-            lifted_fn.call(context.context(), list(handles))
-      return self._stateless_fn(*args, **kwds)
-    canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
-
-    if not self._created_variables:
+    initializer_map = {}
+    self._initialize(args, kwds, add_initializers_to=initializer_map)
+    if self._created_variables:
+      try:
+        # Attempt to initialize variables eagerly and without conds by lifting
+        # out initialization graphs. This is the only initialization strategy
+        # compatible with XLA at the moment.
+        self._initialize_uninitialized_variables(initializer_map)
+      except lift_to_graph.UnliftableError:
+        pass  # Fall through to cond-based initialization.
+      else:
+        # Lifting succeeded, so variables are initialized and we can run the
+        # stateless function.
+        return self._stateless_fn(*args, **kwds)
+    else:
+      canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
       # If we did not create any variables the trace we have is good enough.
       return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
 
@@ -459,6 +435,9 @@ class Function(object):
           functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                             inner_args, inner_kwds))
 
+    # We've created variables and are unable to lift the initialization graphs,
+    # so we fall back to initializing with conds while running the function.
+    canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
     return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
 
   @property
@@ -474,6 +453,23 @@ class Function(object):
   def function_spec(self):
     return self._function_spec
 
+  def _initialize_uninitialized_variables(self, initializer_map):
+    """Make and call a `ConcreteFunction` which initializes variables."""
+
+    # Note: using defun here avoids an infinite recursion.
+    @function_lib.defun
+    def initialize_variables():
+      for v, init in initializer_map.items():
+        with ops.init_scope():
+          if resource_variable_ops.var_is_initialized_op(v.handle):
+            # Ignore variables which are already initialized at trace time.
+            continue
+        v.assign(lift_to_graph.lift_to_graph(
+            init, ops.get_default_graph())[init])
+
+    with ops.init_scope():
+      return initialize_variables.get_concrete_function()()
+
   def get_initialization_function(self, *args, **kwargs):
     """Returns a `ConcreteFunction` which initializes this function's variables.
 
@@ -482,6 +478,9 @@ class Function(object):
     function which does not depend on the concrete values of the inputs
     to this function.
 
+    Note that running this function will overwrite any values currently
+    assigned to variables, for example values restored from a checkpoint.
+
     Args:
       *args: arguments to the underlying python callable.
       **kwargs: keyword arguments to the python callable.
@@ -626,7 +625,9 @@ class Function(object):
     """
     assert context.executing_eagerly()
     if self._stateful_fn is None:
-      self.get_initialization_function(*args, **kwargs)()
+      initializer_map = {}
+      self._initialize(args, kwargs, add_initializers_to=initializer_map)
+      self._initialize_uninitialized_variables(initializer_map)
 
     if self._created_variables:
       # In this case we have created variables on the first call, so we run the
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 912198dfcfb..4a47f67fd16 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -22,6 +22,7 @@ import weakref
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import lift_to_graph
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -208,7 +209,7 @@ class DefFunctionTest(test.TestCase):
           state.append(variables.Variable(2.0 * x))
         return state[0] * x
 
-      with self.assertRaises(ValueError):
+      with self.assertRaises(lift_to_graph.UnliftableError):
         fn(constant_op.constant(3.0))
 
   def testMethod(self):
@@ -343,6 +344,45 @@ class DefFunctionTest(test.TestCase):
     f()
     self.assertEqual(created_variables, captured_variables)
 
+  def testVarAlreadyInitializedNoClobbering(self):
+    v_holder = []
+
+    @def_function.function
+    def add_var(x):
+      if not v_holder:
+        v = variables.Variable([1., 2.])
+        v_holder.append(v)
+        already_initialized = variables.Variable(3.)
+        with ops.init_scope():
+          already_initialized.assign(10.)
+        v_holder.append(already_initialized)
+      return v_holder[0] + v_holder[1] + x
+
+    add_var.get_concrete_function(constant_op.constant(2.))
+    self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
+
+  def testInitializationInNestedCall(self):
+    v_holder = []
+
+    @def_function.function
+    def add_var(x):
+      if not v_holder:
+        v = variables.Variable([1., 2.])
+        v_holder.append(v)
+        already_initialized = variables.Variable(3.)
+        with ops.init_scope():
+          already_initialized.assign(10.)
+        v_holder.append(already_initialized)
+      return v_holder[0] + v_holder[1] + x
+
+    @def_function.function
+    def wrapper(x):
+      return add_var(x)
+
+    self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
+    v_holder[1].assign(11.)
+    self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
diff --git a/tensorflow/python/eager/def_function_xla_test.py b/tensorflow/python/eager/def_function_xla_test.py
new file mode 100644
index 00000000000..9115d8a6943
--- /dev/null
+++ b/tensorflow/python/eager/def_function_xla_test.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DefFunctionTests(xla_test.XLATestCase):
+
+  def testVarInitializedInFunction(self):
+    with self.test_scope():
+      v_holder = []
+
+      @def_function.function
+      def add_var(x):
+        if not v_holder:
+          v = variables.Variable([1., 2.])
+          v_holder.append(v)
+          already_initialized = variables.Variable(3.)
+          with ops.init_scope():
+            already_initialized.assign(10.)
+          v_holder.append(already_initialized)
+        return v_holder[0] + v_holder[1] + x
+
+      self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
+
+
+if __name__ == "__main__":
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py
index 2e9d24f61ea..ad62e6d10ac 100644
--- a/tensorflow/python/eager/lift_to_graph.py
+++ b/tensorflow/python/eager/lift_to_graph.py
@@ -35,6 +35,11 @@ def _as_operation(op_or_tensor):
   return op_or_tensor
 
 
+class UnliftableError(Exception):
+  """Raised if a Tensor cannot be lifted from the graph."""
+  pass
+
+
 def lift_to_graph(init_tensor, graph, sources=None):
   """Copies the tensor and all its inputs recursively to the outer graph."""
   # Check that the initializer does not depend on any placeholders.
@@ -52,7 +57,7 @@
     # and placeholders the user might directly use to initialize
    # variables.
     if op.type == "Placeholder":
-      raise ValueError(
+      raise UnliftableError(
           "Unable to lift tensor", init_tensor,
          "because it depends transitively on placeholder ", op)
     for inp in _graph_inputs(op):
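
Notes (illustrative sketches, not part of the patch):

What the def_function.py change does in practice, condensed from the new tests
into one minimal sketch (assumes eager execution, as enabled in the test mains
above; `add_var` and `v_holder` are the tests' own names): variables created
on the first call have their initializers lifted out of the function graph and
run eagerly once, while a variable that is already initialized at trace time
keeps its value instead of being re-initialized.

    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables

    v_holder = []

    @def_function.function
    def add_var(x):
      if not v_holder:
        v_holder.append(variables.Variable([1., 2.]))
        already_initialized = variables.Variable(3.)
        with ops.init_scope():
          already_initialized.assign(10.)  # runs eagerly, during tracing
        v_holder.append(already_initialized)
      return v_holder[0] + v_holder[1] + x

    # _initialize_uninitialized_variables skips `already_initialized`: its
    # var_is_initialized_op is already True when initialize_variables is
    # traced, so the eagerly assigned 10. survives. The other variable is
    # initialized exactly once by its lifted initializer.
    print(add_var(constant_op.constant(2.)))  # expected: [13., 14.]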
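
And the retyped failure mode, as exercised by the updated def_function_test.py
expectation: an initial value computed from a function argument depends
transitively on a placeholder, so it cannot be lifted out of the function
graph, and lift_to_graph now signals this with the dedicated UnliftableError
rather than a generic ValueError. A hedged sketch; the legacy-graph-mode
wrapper is an assumption here, since under eager execution the new __call__
catches UnliftableError internally and falls back to cond-based initialization
instead of raising.

    from tensorflow.python.eager import def_function
    from tensorflow.python.eager import lift_to_graph
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables

    with ops.Graph().as_default():  # assumed: legacy graph mode
      state = []

      @def_function.function
      def fn(x):
        if not state:
          # 2.0 * x depends transitively on the placeholder for `x`, so this
          # initializer cannot be copied into the outer graph.
          state.append(variables.Variable(2.0 * x))
        return state[0] * x

      try:
        fn(constant_op.constant(3.0))
      except lift_to_graph.UnliftableError:
        print("initializer depends on a function input; cannot be lifted")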