Update keras test to use public TF test API.

PiperOrigin-RevId: 339346336
Change-Id: I96d8c98cd9122c301d2bbf8358f5d0f1e090f670
This commit is contained in:
Scott Zhu 2020-10-27 15:45:04 -07:00 committed by TensorFlower Gardener
parent b176d45d99
commit 8ae2236a6a
5 changed files with 10 additions and 10 deletions

View File

@ -24,7 +24,7 @@ import time
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
def save_and_load_benchmark(app):
@ -34,7 +34,7 @@ def save_and_load_benchmark(app):
model = app(weights=None)
model_name = app.__name__
tmp_dir = googletest.GetTempDir()
tmp_dir = test.get_temp_dir()
gfile.MakeDirs(tmp_dir)
save_dir = tempfile.mkdtemp(dir=tmp_dir)

View File

@ -28,7 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
class KerasParameterizedTest(keras_parameterized.TestCase):
@ -600,4 +600,4 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
self.assertEqual(arg, True)
if __name__ == "__main__":
googletest.main()
test.main()

View File

@ -32,7 +32,7 @@ from tensorflow.python.keras import combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
def _maybe_serialized(lr_decay, serialize_and_deserialize):
@ -510,4 +510,4 @@ class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase,
self.evaluate(decayed_lr(step))
if __name__ == "__main__":
googletest.main()
test.main()

View File

@ -26,7 +26,7 @@ from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2 import legacy_learning_rate_decay as learning_rate_decay
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
@ -478,4 +478,4 @@ class NoisyLinearCosineDecayTest(keras_parameterized.TestCase):
if __name__ == "__main__":
googletest.main()
test.main()

View File

@ -28,7 +28,7 @@ from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
@ -301,4 +301,4 @@ class FilterTopKTest(test_util.TensorFlowTestCase, parameterized.TestCase):
if __name__ == '__main__':
googletest.main()
test.main()