Remove tag run_deprecated_v1.
PiperOrigin-RevId: 343834715 Change-Id: I88a2e30f4800c5fec645dc6744be23bdcac66b80
This commit is contained in:
parent
0e8565529f
commit
e14bc4bbd9
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import template
|
||||
@ -29,28 +30,29 @@ from tensorflow.python.platform import test
|
||||
|
||||
class TemplateMirroredStrategyTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_tfrt("Strategy not supported yet.")
|
||||
def test_merge_call(self):
|
||||
if not test.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
with ops.Graph().as_default():
|
||||
# The test is testing a v1 only function.
|
||||
if not test.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
def fn():
|
||||
var1 = variable_scope.get_variable(
|
||||
"var1", shape=[], initializer=init_ops.constant_initializer(21.))
|
||||
ds_context.get_replica_context().merge_call(lambda _: ())
|
||||
var2 = variable_scope.get_variable(
|
||||
"var2", shape=[], initializer=init_ops.constant_initializer(2.))
|
||||
return var1 * var2
|
||||
def fn():
|
||||
var1 = variable_scope.get_variable(
|
||||
"var1", shape=[], initializer=init_ops.constant_initializer(21.))
|
||||
ds_context.get_replica_context().merge_call(lambda _: ())
|
||||
var2 = variable_scope.get_variable(
|
||||
"var2", shape=[], initializer=init_ops.constant_initializer(2.))
|
||||
return var1 * var2
|
||||
|
||||
temp = template.make_template("my_template", fn)
|
||||
temp = template.make_template("my_template", fn)
|
||||
|
||||
strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
|
||||
out = strategy.experimental_local_results(
|
||||
strategy.run(temp))
|
||||
strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
|
||||
out = strategy.experimental_local_results(
|
||||
strategy.run(temp))
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([42., 42.], self.evaluate(out))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([42., 42.], self.evaluate(out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user