Update keras test to use public TF test API.
PiperOrigin-RevId: 339346336 Change-Id: I96d8c98cd9122c301d2bbf8358f5d0f1e090f670
This commit is contained in:
parent
b176d45d99
commit
8ae2236a6a
@ -24,7 +24,7 @@ import time
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def save_and_load_benchmark(app):
|
||||
@ -34,7 +34,7 @@ def save_and_load_benchmark(app):
|
||||
model = app(weights=None)
|
||||
model_name = app.__name__
|
||||
|
||||
tmp_dir = googletest.GetTempDir()
|
||||
tmp_dir = test.get_temp_dir()
|
||||
gfile.MakeDirs(tmp_dir)
|
||||
save_dir = tempfile.mkdtemp(dir=tmp_dir)
|
||||
|
||||
|
@ -28,7 +28,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
@ -600,4 +600,4 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(arg, True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
test.main()
|
||||
|
@ -32,7 +32,7 @@ from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _maybe_serialized(lr_decay, serialize_and_deserialize):
|
||||
@ -510,4 +510,4 @@ class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase,
|
||||
self.evaluate(decayed_lr(step))
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
test.main()
|
||||
|
@ -26,7 +26,7 @@ from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.optimizer_v2 import legacy_learning_rate_decay as learning_rate_decay
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
||||
@ -478,4 +478,4 @@ class NoisyLinearCosineDecayTest(keras_parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
test.main()
|
||||
|
@ -28,7 +28,7 @@ from tensorflow.python.keras.utils import metrics_utils
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@ -301,4 +301,4 @@ class FilterTopKTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user