diff --git a/tensorflow/python/keras/layers/dense_attention_test.py b/tensorflow/python/keras/layers/dense_attention_test.py index 504c4ab6984..85780900593 100644 --- a/tensorflow/python/keras/layers/dense_attention_test.py +++ b/tensorflow/python/keras/layers/dense_attention_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.eager import context -from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import dense_attention @@ -361,7 +360,6 @@ class AttentionTest(test.TestCase, parameterized.TestCase): attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1])) self.assertAllClose(1., attention_layer.scale.value()) - @test_util.deprecated_graph_mode_only def test_scale_init_graph(self): """Tests that scale initializes to 1 when use_scale=True.""" with self.cached_session() as sess: diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py index 350cfe6a09c..9a9d174a64f 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py @@ -106,19 +106,20 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 self.assertAllClose([3.], self.evaluate(var)) - @test_util.deprecated_graph_mode_only def testFixedLossScaleAppliedToLossWithGetGradients(self): - var = variables.Variable([2.0]) - opt = gradient_descent.SGD(1.0) - loss_scale = 10. - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) - grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale) - loss = grad_check_fn(var) - run_op = opt.get_gradients(loss, [var]) - self.evaluate(variables.global_variables_initializer()) - # This will cause an assertion to run, as - # mp_test_util.create_identity_with_grad_check_fn added an assertion op. - self.evaluate(run_op) + with ops.Graph().as_default(): + var = variables.Variable([2.0]) + opt = gradient_descent.SGD(1.0) + loss_scale = 10. + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + grad_check_fn = mp_test_util.create_identity_with_grad_check_fn( + loss_scale) + loss = grad_check_fn(var) + run_op = opt.get_gradients(loss, [var]) + self.evaluate(variables.global_variables_initializer()) + # This will cause an assertion to run, as + # mp_test_util.create_identity_with_grad_check_fn added an assertion op. + self.evaluate(run_op) def testGetScaledLoss(self): opt = gradient_descent.SGD(2.0) diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py index 465ace7f264..0765afb4db7 100644 --- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py @@ -23,7 +23,7 @@ from tensorflow.python import data from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import config -from tensorflow.python.framework import test_util +from tensorflow.python.framework import ops from tensorflow.python.keras.utils import multi_gpu_utils from tensorflow.python.keras.utils import np_utils from tensorflow.python.platform import test @@ -38,7 +38,7 @@ def check_if_compatible_devices(gpus=2): return False return True -@test_util.run_all_in_deprecated_graph_mode_only + class TestMultiGPUModel(test.TestCase): def __init__(self, methodName='runTest'): # pylint: disable=invalid-name @@ -161,7 +161,7 @@ class TestMultiGPUModel(test.TestCase): if not check_if_compatible_devices(gpus=gpus): self.skipTest('multi gpu only') - with self.cached_session(): + with ops.Graph().as_default(), self.cached_session(): input_shape = (num_samples,) + shape x_train = np.random.randint(0, 255, input_shape) y_train = np.random.randint(0, num_classes, (input_shape[0],))