From 444a8bc05f73d820c6146c914759fda36240a34c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 13 Mar 2019 03:40:24 -0700 Subject: [PATCH] Make `tf.make_template` DistributionStrategy aware. PiperOrigin-RevId: 238197434 --- tensorflow/python/kernel_tests/BUILD | 15 ++++++ .../template_mirrored_strategy_test.py | 52 +++++++++++++++++++ tensorflow/python/ops/template.py | 44 ++++++++++------ 3 files changed, 95 insertions(+), 16 deletions(-) create mode 100644 tensorflow/python/kernel_tests/template_mirrored_strategy_test.py diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index cec51dd1e87..6b0c6a62642 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1181,6 +1181,21 @@ tf_py_test( ], ) +cuda_py_test( + name = "template_mirrored_strategy_test", + size = "small", + srcs = ["template_mirrored_strategy_test.py"], + additional_deps = [ + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:init_ops", + "//tensorflow/python:template", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + tf_py_test( name = "tridiagonal_solve_op_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py b/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py new file mode 100644 index 00000000000..de94212a9eb --- /dev/null +++ b/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py @@ -0,0 +1,52 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for make_template used with MirroredStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.framework import test_util +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import template +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class TemplateMirroredStrategyTest(test.TestCase): + + @test_util.run_deprecated_v1 + def test_merge_call(self): + def fn(): + var1 = variable_scope.get_variable( + "var1", shape=[], initializer=init_ops.constant_initializer(21.)) + ds_context.get_replica_context().merge_call(lambda _: ()) + var2 = variable_scope.get_variable( + "var2", shape=[], initializer=init_ops.constant_initializer(2.)) + return var1 * var2 + + temp = template.make_template("my_template", fn) + + strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"]) + out = strategy.unwrap(strategy.experimental_run_v2(temp)) + + self.evaluate(variables.global_variables_initializer()) + self.assertAllEqual([42., 42.], self.evaluate(out)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index ff4f23a0e75..3ca9799e431 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -292,30 +292,31 @@ class Template(trackable.Trackable): self._variable_scope = vs else: self._variable_scope = None - # This variable keeps track of whether the template has been called yet, - # which is not the same as whether the scope has been created. + # This variable keeps track of whether the template has been called to + # completion, which is not the same as whether the scope has been created. self._variables_created = False + # `MirroredStrategy` builds the graph with multiple threads. If a + # `merge_call` happens within a template, multiple calls may be in progress + # simultaneously. This variable keeps track of whether any call of the + # template has started. + self._first_call = True def _call_func(self, args, kwargs): try: - vars_at_start = len( - ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)) - trainable_at_start = len( - ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)) if self._variables_created: - result = self._func(*args, **kwargs) - else: - # The first time we run, restore variables if necessary (via - # Trackable). - with trackable_util.capture_dependencies(template=self): - result = self._func(*args, **kwargs) + vars_at_start = len( + ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)) + trainable_at_start = len( + ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)) + + result = self._func(*args, **kwargs) - if self._variables_created: # Variables were previously created, implying this is not the first # time the template has been called. Check to make sure that no new # trainable variables were created this time around. trainable_variables = ops.get_collection_ref( ops.GraphKeys.TRAINABLE_VARIABLES) + # If a variable that we intend to train is created as a side effect # of creating a template, then that is almost certainly an error. if trainable_at_start != len(trainable_variables): @@ -333,8 +334,19 @@ class Template(trackable.Trackable): "the first time, perhaps you used tf.Variable when you " "meant tf.get_variable: %s", variables[vars_at_start:]) - else: + elif self._first_call: + self._first_call = False + try: + # The first time we run, restore variables if necessary (via + # Trackable). + with trackable_util.capture_dependencies(template=self): + result = self._func(*args, **kwargs) + except: + self._first_call = True + raise self._variables_created = True + else: # We are calling the template in parallel from another thread. + result = self._func(*args, **kwargs) return result except Exception as exc: # Reraise the exception, but append the original definition to the @@ -354,9 +366,9 @@ class Template(trackable.Trackable): def __call__(self, *args, **kwargs): if self._variable_scope: - # Only reuse variables if they were already created. + # Only reuse variables if not on first call. with variable_scope.variable_scope( - self._variable_scope, reuse=self._variables_created): + self._variable_scope, reuse=not self._first_call): return self._call_func(args, kwargs) else: # The scope was not created at construction time, so create it here.