Deprecated tf.Session removed in bidirectional_sequence_rnn_test.py

This commit is contained in:
Siju Samuel 2019-07-10 14:44:06 +05:30
parent 6cdf3802a9
commit c98c853676

View File

@ -206,7 +206,7 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
x, prediction, output_class = self.buildModel( x, prediction, output_class = self.buildModel(
fw_rnn_layer, bw_rnn_layer, is_dynamic_rnn, True, use_sequence_length) fw_rnn_layer, bw_rnn_layer, is_dynamic_rnn, True, use_sequence_length)
new_sess = tf.Session(config=CONFIG) new_sess = tf.compat.v1.Session(config=CONFIG)
saver = tf.train.Saver() saver = tf.train.Saver()
saver.restore(new_sess, model_dir) saver.restore(new_sess, model_dir)
return x, prediction, output_class, new_sess return x, prediction, output_class, new_sess
@ -265,7 +265,7 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
return result return result
def testStaticRnnMultiRnnCell(self): def testStaticRnnMultiRnnCell(self):
sess = tf.Session(config=CONFIG) sess = tf.compat.v1.Session(config=CONFIG)
x, prediction, output_class = self.buildModel( x, prediction, output_class = self.buildModel(
self.buildRnnLayer(), self.buildRnnLayer(), False, is_inference=False) self.buildRnnLayer(), self.buildRnnLayer(), False, is_inference=False)
@ -282,7 +282,7 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
def testStaticRnnMultiRnnCellWithSequenceLength(self): def testStaticRnnMultiRnnCellWithSequenceLength(self):
sess = tf.Session(config=CONFIG) sess = tf.compat.v1.Session(config=CONFIG)
x, prediction, output_class = self.buildModel( x, prediction, output_class = self.buildModel(
self.buildRnnLayer(), self.buildRnnLayer(),
@ -309,7 +309,7 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
@test_util.enable_control_flow_v2 @test_util.enable_control_flow_v2
def testDynamicRnnMultiRnnCell(self): def testDynamicRnnMultiRnnCell(self):
sess = tf.Session(config=CONFIG) sess = tf.compat.v1.Session(config=CONFIG)
x, prediction, output_class = self.buildModel( x, prediction, output_class = self.buildModel(
self.buildRnnLayer(), self.buildRnnLayer(), True, is_inference=False) self.buildRnnLayer(), self.buildRnnLayer(), True, is_inference=False)
@ -331,7 +331,7 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
@test_util.enable_control_flow_v2 @test_util.enable_control_flow_v2
def testDynamicRnnMultiRnnCellWithSequenceLength(self): def testDynamicRnnMultiRnnCellWithSequenceLength(self):
sess = tf.Session(config=CONFIG) sess = tf.compat.v1.Session(config=CONFIG)
x, prediction, output_class = self.buildModel( x, prediction, output_class = self.buildModel(
self.buildRnnLayer(), self.buildRnnLayer(),