fixit for server_lib container test.

PiperOrigin-RevId: 325079700
Change-Id: Iad7abfb5493c568cd31bc9a88886aa04ef9ddcd6
This commit is contained in:
Zhenyu Tan 2020-08-05 12:42:11 -07:00 committed by TensorFlower Gardener
parent c2e594440e
commit 1ae71606d0

View File

@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.python.client import session
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@ -33,7 +33,6 @@ class SameVariablesClearContainerTest(test.TestCase):
# TODO(b/34465411): Starting multiple servers with different configurations
# in the same test is flaky. Move this test case back into
# "server_lib_test.py" when this is no longer the case.
@test_util.run_deprecated_v1
def testSameVariablesClearContainer(self):
# Starts two servers with different names so they map to different
# resource "containers".
@ -47,36 +46,37 @@ class SameVariablesClearContainerTest(test.TestCase):
}, protocol="grpc", start=True)
# Creates a graph with 2 variables.
v0 = variables.Variable(1.0, name="v0")
v1 = variables.Variable(2.0, name="v0")
with ops.Graph().as_default():
v0 = variables.Variable(1.0, name="v0")
v1 = variables.Variable(2.0, name="v0")
# Initializes the variables. Verifies that the values are correct.
sess_0 = session.Session(server0.target)
sess_1 = session.Session(server1.target)
sess_0.run(v0.initializer)
sess_1.run(v1.initializer)
self.assertAllEqual(1.0, sess_0.run(v0))
self.assertAllEqual(2.0, sess_1.run(v1))
# Initializes the variables. Verifies that the values are correct.
sess_0 = session.Session(server0.target)
sess_1 = session.Session(server1.target)
sess_0.run(v0.initializer)
sess_1.run(v1.initializer)
self.assertAllEqual(1.0, sess_0.run(v0))
self.assertAllEqual(2.0, sess_1.run(v1))
# Resets container "local0". Verifies that v0 is no longer initialized.
session.Session.reset(server0.target, ["local0"])
sess = session.Session(server0.target)
with self.assertRaises(errors_impl.FailedPreconditionError):
self.evaluate(v0)
# Reinitializes v0 for the following test.
self.evaluate(v0.initializer)
# Resets container "local0". Verifies that v0 is no longer initialized.
session.Session.reset(server0.target, ["local0"])
_ = session.Session(server0.target)
with self.assertRaises(errors_impl.FailedPreconditionError):
self.evaluate(v0)
# Reinitializes v0 for the following test.
self.evaluate(v0.initializer)
# Verifies that v1 is still valid.
self.assertAllEqual(2.0, sess_1.run(v1))
# Verifies that v1 is still valid.
self.assertAllEqual(2.0, sess_1.run(v1))
# Resets container "local1". Verifies that v1 is no longer initialized.
session.Session.reset(server1.target, ["local1"])
sess = session.Session(server1.target)
with self.assertRaises(errors_impl.FailedPreconditionError):
self.evaluate(v1)
# Verifies that v0 is still valid.
sess = session.Session(server0.target)
self.assertAllEqual(1.0, self.evaluate(v0))
# Resets container "local1". Verifies that v1 is no longer initialized.
session.Session.reset(server1.target, ["local1"])
_ = session.Session(server1.target)
with self.assertRaises(errors_impl.FailedPreconditionError):
self.evaluate(v1)
# Verifies that v0 is still valid.
_ = session.Session(server0.target)
self.assertAllEqual(1.0, self.evaluate(v0))
if __name__ == "__main__":