Remove tag run_deprecated_v1.

PiperOrigin-RevId: 343834715
Change-Id: I88a2e30f4800c5fec645dc6744be23bdcac66b80
This commit is contained in:
Xinyi Wang 2020-11-23 05:53:34 -08:00 committed by TensorFlower Gardener
parent 0e8565529f
commit e14bc4bbd9

View File

@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import template
@ -29,28 +30,29 @@ from tensorflow.python.platform import test
class TemplateMirroredStrategyTest(test.TestCase):
@test_util.run_deprecated_v1
@test_util.disable_tfrt("Strategy not supported yet.")
def test_merge_call(self):
if not test.is_gpu_available():
self.skipTest("No GPU available")
with ops.Graph().as_default():
# The test is testing a v1 only function.
if not test.is_gpu_available():
self.skipTest("No GPU available")
def fn():
var1 = variable_scope.get_variable(
"var1", shape=[], initializer=init_ops.constant_initializer(21.))
ds_context.get_replica_context().merge_call(lambda _: ())
var2 = variable_scope.get_variable(
"var2", shape=[], initializer=init_ops.constant_initializer(2.))
return var1 * var2
def fn():
var1 = variable_scope.get_variable(
"var1", shape=[], initializer=init_ops.constant_initializer(21.))
ds_context.get_replica_context().merge_call(lambda _: ())
var2 = variable_scope.get_variable(
"var2", shape=[], initializer=init_ops.constant_initializer(2.))
return var1 * var2
temp = template.make_template("my_template", fn)
temp = template.make_template("my_template", fn)
strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
out = strategy.experimental_local_results(
strategy.run(temp))
strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
out = strategy.experimental_local_results(
strategy.run(temp))
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([42., 42.], self.evaluate(out))
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([42., 42.], self.evaluate(out))
if __name__ == "__main__":