Make tf.make_template DistributionStrategy aware.

PiperOrigin-RevId: 238197434
Chris Jones 2019-03-13 03:40:24 -07:00 committed by TensorFlower Gardener
parent 5392565fb8
commit 444a8bc05f
3 changed files with 95 additions and 16 deletions
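For context: in TF1 graph mode, MirroredStrategy traces the replica function on one thread per device, and a merge_call suspends a replica thread until all replicas reach the merge point. If the replica function is a tf.make_template, a second replica's call can therefore start while the first call is still in flight; that interleaving is what this commit makes work. A minimal sketch of the now-supported pattern, distilled from the new test below (it uses the same internal endpoints as the test; the device list is illustrative):

# Sketch (not part of the commit): a template whose function spans a
# merge_call, run under MirroredStrategy.
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope

def fn():
  # Variable creation brackets a merge_call, so another replica thread can
  # enter this template before the current call has completed.
  var = variable_scope.get_variable(
      "var", shape=[], initializer=init_ops.constant_initializer(1.))
  ds_context.get_replica_context().merge_call(lambda _: ())
  return var + 1.

temp = template.make_template("shared", fn)
strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
# Before this commit, the second replica thread could fail here: it saw
# _variables_created == False (no call had finished yet) and tried to
# re-create "var" in a non-reusing scope.
out = strategy.unwrap(strategy.experimental_run_v2(temp))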

tensorflow/python/kernel_tests/BUILD

@@ -1181,6 +1181,21 @@ tf_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "template_mirrored_strategy_test",
+    size = "small",
+    srcs = ["template_mirrored_strategy_test.py"],
+    additional_deps = [
+        "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:template",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+    ],
+)
+
 tf_py_test(
     name = "tridiagonal_solve_op_test",
     size = "medium",

tensorflow/python/kernel_tests/template_mirrored_strategy_test.py

@@ -0,0 +1,52 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for make_template used with MirroredStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class TemplateMirroredStrategyTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def test_merge_call(self):
+    def fn():
+      var1 = variable_scope.get_variable(
+          "var1", shape=[], initializer=init_ops.constant_initializer(21.))
+      ds_context.get_replica_context().merge_call(lambda _: ())
+      var2 = variable_scope.get_variable(
+          "var2", shape=[], initializer=init_ops.constant_initializer(2.))
+      return var1 * var2
+
+    temp = template.make_template("my_template", fn)
+    strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
+    out = strategy.unwrap(strategy.experimental_run_v2(temp))
+    self.evaluate(variables.global_variables_initializer())
+    self.assertAllEqual([42., 42.], self.evaluate(out))
+
+
+if __name__ == "__main__":
+  test.main()

tensorflow/python/ops/template.py

@@ -292,30 +292,31 @@ class Template(trackable.Trackable):
       self._variable_scope = vs
     else:
       self._variable_scope = None
-    # This variable keeps track of whether the template has been called yet,
-    # which is not the same as whether the scope has been created.
+    # This variable keeps track of whether the template has been called to
+    # completion, which is not the same as whether the scope has been created.
     self._variables_created = False
+    # `MirroredStrategy` builds the graph with multiple threads. If a
+    # `merge_call` happens within a template, multiple calls may be in progress
+    # simultaneously. This variable keeps track of whether any call of the
+    # template has started.
+    self._first_call = True
 
   def _call_func(self, args, kwargs):
     try:
-      vars_at_start = len(
-          ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
-      trainable_at_start = len(
-          ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
       if self._variables_created:
-        result = self._func(*args, **kwargs)
-      else:
-        # The first time we run, restore variables if necessary (via
-        # Trackable).
-        with trackable_util.capture_dependencies(template=self):
-          result = self._func(*args, **kwargs)
+        vars_at_start = len(
+            ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
+        trainable_at_start = len(
+            ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
+
+        result = self._func(*args, **kwargs)
 
-      if self._variables_created:
         # Variables were previously created, implying this is not the first
         # time the template has been called. Check to make sure that no new
         # trainable variables were created this time around.
         trainable_variables = ops.get_collection_ref(
             ops.GraphKeys.TRAINABLE_VARIABLES)
         # If a variable that we intend to train is created as a side effect
         # of creating a template, then that is almost certainly an error.
         if trainable_at_start != len(trainable_variables):
@@ -333,8 +334,19 @@ class Template(trackable.Trackable):
                      "the first time, perhaps you used tf.Variable when you "
                      "meant tf.get_variable: %s",
                      variables[vars_at_start:])
-      else:
+      elif self._first_call:
+        self._first_call = False
+        try:
+          # The first time we run, restore variables if necessary (via
+          # Trackable).
+          with trackable_util.capture_dependencies(template=self):
+            result = self._func(*args, **kwargs)
+        except:
+          self._first_call = True
+          raise
         self._variables_created = True
+      else:  # We are calling the template in parallel from another thread.
+        result = self._func(*args, **kwargs)
       return result
     except Exception as exc:
       # Reraise the exception, but append the original definition to the
@@ -354,9 +366,9 @@ class Template(trackable.Trackable):
 
   def __call__(self, *args, **kwargs):
     if self._variable_scope:
-      # Only reuse variables if they were already created.
+      # Only reuse variables if not on first call.
       with variable_scope.variable_scope(
-          self._variable_scope, reuse=self._variables_created):
+          self._variable_scope, reuse=not self._first_call):
         return self._call_func(args, kwargs)
     else:
       # The scope was not created at construction time, so create it here.
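Taken together, the change swaps the reuse key in __call__ from "a call has completed" (self._variables_created) to "a call has started" (not self._first_call), and _call_func gains a third branch for calls that run in parallel with a still-unfinished first call. The bare except in the first-call branch restores _first_call so that a failed first call does not permanently lock the template into reuse mode. The following self-contained toy (plain Python with hypothetical names, no TensorFlow) forces the interleaving the new test exercises and shows why keying reuse on completion fails:

import threading

class ToyTemplate(object):
  """Toy model of Template's two flags (illustration only, not TF code)."""

  def __init__(self, key_reuse_on_completion):
    self._key_reuse_on_completion = key_reuse_on_completion
    self._variables_created = False   # flips once a call has *completed*
    self._first_call = True           # flips as soon as any call *starts*
    self._vars = {}                   # stands in for the variable store
    self._parked = threading.Event()  # first call reached its merge_call
    self._resume = threading.Event()  # lets the parked call finish

  def _get_variable(self, name, reuse):
    # Create-or-reuse rule, like variable_scope.get_variable.
    if reuse:
      return self._vars[name]
    if name in self._vars:
      raise ValueError("Variable %s already exists." % name)
    self._vars[name] = object()
    return self._vars[name]

  def __call__(self, park):
    if self._key_reuse_on_completion:
      reuse = self._variables_created  # pre-commit behavior
    else:
      reuse = not self._first_call     # post-commit behavior
    self._first_call = False
    value = self._get_variable("var", reuse)
    if park:                           # simulate being parked in merge_call
      self._parked.set()
      self._resume.wait()
    self._variables_created = True
    return value

def race(toy):
  """Replica 1 parks mid-call; replica 2 then enters, as in the new test."""
  errors = []
  replica1 = threading.Thread(target=toy, kwargs={"park": True})
  replica1.start()
  toy._parked.wait()  # replica 2 starts while replica 1 is mid-call
  try:
    toy(park=False)
  except ValueError as e:
    errors.append(str(e))
  toy._resume.set()
  replica1.join()
  return errors

print(race(ToyTemplate(key_reuse_on_completion=True)))   # pre-fix: creation error
print(race(ToyTemplate(key_reuse_on_completion=False)))  # post-fix: []

With key_reuse_on_completion=True, the second caller sees the "completed" flag still unset, takes the creation path, and collides with the existing variable, which mirrors the pre-fix failure; keying on not _first_call sends every call after the very first started one down the reuse path instead.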