Make tf.make_template DistributionStrategy aware.
PiperOrigin-RevId: 238197434
This commit is contained in:
parent
5392565fb8
commit
444a8bc05f
@ -1181,6 +1181,21 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "template_mirrored_strategy_test",
|
||||
size = "small",
|
||||
srcs = ["template_mirrored_strategy_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python/distribute:distribute_lib",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:template",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tridiagonal_solve_op_test",
|
||||
size = "medium",
|
||||
|
@ -0,0 +1,52 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for make_template used with MirroredStrategy."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import template
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TemplateMirroredStrategyTest(test.TestCase):
  """Exercises `template.make_template` when run under `MirroredStrategy`."""

  @test_util.run_deprecated_v1
  def test_merge_call(self):
    # The templated function creates two scalar variables and issues a
    # `merge_call` on the current replica context between the two creations.

    def fn():
      first = variable_scope.get_variable(
          "var1", shape=[], initializer=init_ops.constant_initializer(21.))
      ds_context.get_replica_context().merge_call(lambda _: ())
      second = variable_scope.get_variable(
          "var2", shape=[], initializer=init_ops.constant_initializer(2.))
      return first * second

    temp = template.make_template("my_template", fn)

    # Run the template once per device and collect the per-replica results.
    devices = ["/cpu:0", "/gpu:0"]
    strategy = mirrored_strategy.MirroredStrategy(devices)
    per_replica = strategy.experimental_run_v2(temp)
    out = strategy.unwrap(per_replica)

    init_op = variables.global_variables_initializer()
    self.evaluate(init_op)

    # Each of the two replicas should produce 21 * 2.
    expected = [42., 42.]
    self.assertAllEqual(expected, self.evaluate(out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -292,30 +292,31 @@ class Template(trackable.Trackable):
|
||||
self._variable_scope = vs
|
||||
else:
|
||||
self._variable_scope = None
|
||||
# This variable keeps track of whether the template has been called yet,
|
||||
# which is not the same as whether the scope has been created.
|
||||
# This variable keeps track of whether the template has been called to
|
||||
# completion, which is not the same as whether the scope has been created.
|
||||
self._variables_created = False
|
||||
# `MirroredStrategy` builds the graph with multiple threads. If a
|
||||
# `merge_call` happens within a template, multiple calls may be in progress
|
||||
# simultaneously. This variable keeps track of whether any call of the
|
||||
# template has started.
|
||||
self._first_call = True
|
||||
|
||||
def _call_func(self, args, kwargs):
|
||||
try:
|
||||
vars_at_start = len(
|
||||
ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||
trainable_at_start = len(
|
||||
ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
|
||||
if self._variables_created:
|
||||
result = self._func(*args, **kwargs)
|
||||
else:
|
||||
# The first time we run, restore variables if necessary (via
|
||||
# Trackable).
|
||||
with trackable_util.capture_dependencies(template=self):
|
||||
result = self._func(*args, **kwargs)
|
||||
vars_at_start = len(
|
||||
ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||
trainable_at_start = len(
|
||||
ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
|
||||
|
||||
result = self._func(*args, **kwargs)
|
||||
|
||||
if self._variables_created:
|
||||
# Variables were previously created, implying this is not the first
|
||||
# time the template has been called. Check to make sure that no new
|
||||
# trainable variables were created this time around.
|
||||
trainable_variables = ops.get_collection_ref(
|
||||
ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||
|
||||
# If a variable that we intend to train is created as a side effect
|
||||
# of creating a template, then that is almost certainly an error.
|
||||
if trainable_at_start != len(trainable_variables):
|
||||
@ -333,8 +334,19 @@ class Template(trackable.Trackable):
|
||||
"the first time, perhaps you used tf.Variable when you "
|
||||
"meant tf.get_variable: %s",
|
||||
variables[vars_at_start:])
|
||||
else:
|
||||
elif self._first_call:
|
||||
self._first_call = False
|
||||
try:
|
||||
# The first time we run, restore variables if necessary (via
|
||||
# Trackable).
|
||||
with trackable_util.capture_dependencies(template=self):
|
||||
result = self._func(*args, **kwargs)
|
||||
except:
|
||||
self._first_call = True
|
||||
raise
|
||||
self._variables_created = True
|
||||
else: # We are calling the template in parallel from another thread.
|
||||
result = self._func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as exc:
|
||||
# Reraise the exception, but append the original definition to the
|
||||
@ -354,9 +366,9 @@ class Template(trackable.Trackable):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self._variable_scope:
|
||||
# Only reuse variables if they were already created.
|
||||
# Only reuse variables if not on first call.
|
||||
with variable_scope.variable_scope(
|
||||
self._variable_scope, reuse=self._variables_created):
|
||||
self._variable_scope, reuse=not self._first_call):
|
||||
return self._call_func(args, kwargs)
|
||||
else:
|
||||
# The scope was not created at construction time, so create it here.
|
||||
|
Loading…
Reference in New Issue
Block a user