Update optimizers to v2 in saved model test.
PiperOrigin-RevId: 258907912
This commit is contained in:
parent
238dcdfdee
commit
69359f86c9
@ -32,8 +32,11 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import training as model_lib
|
from tensorflow.python.keras.engine import training as model_lib
|
||||||
from tensorflow.python.keras.optimizer_v2 import adadelta
|
from tensorflow.python.keras.optimizer_v2 import adadelta
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import rmsprop
|
||||||
from tensorflow.python.keras.saving import saved_model_experimental as keras_saved_model
|
from tensorflow.python.keras.saving import saved_model_experimental as keras_saved_model
|
||||||
from tensorflow.python.keras.utils import mode_keys
|
from tensorflow.python.keras.utils import mode_keys
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
@ -44,7 +47,8 @@ from tensorflow.python.saved_model import model_utils
|
|||||||
from tensorflow.python.training import training as training_module
|
from tensorflow.python.training import training as training_module
|
||||||
|
|
||||||
|
|
||||||
class TestModelSavingandLoading(test.TestCase):
|
@keras_parameterized.run_all_keras_modes()
|
||||||
|
class TestModelSavingandLoading(parameterized.TestCase, test.TestCase):
|
||||||
|
|
||||||
def _save_model_dir(self, dirname='saved_model'):
|
def _save_model_dir(self, dirname='saved_model'):
|
||||||
temp_dir = self.get_temp_dir()
|
temp_dir = self.get_temp_dir()
|
||||||
@ -59,9 +63,11 @@ class TestModelSavingandLoading(test.TestCase):
|
|||||||
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
||||||
model.compile(
|
model.compile(
|
||||||
loss=keras.losses.MSE,
|
loss=keras.losses.MSE,
|
||||||
optimizer=keras.optimizers.RMSprop(lr=0.0001),
|
optimizer=rmsprop.RMSprop(lr=0.0001),
|
||||||
metrics=[keras.metrics.categorical_accuracy],
|
metrics=[keras.metrics.categorical_accuracy],
|
||||||
sample_weight_mode='temporal')
|
sample_weight_mode='temporal',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
x = np.random.random((1, 3))
|
x = np.random.random((1, 3))
|
||||||
y = np.random.random((1, 3, 3))
|
y = np.random.random((1, 3, 3))
|
||||||
model.train_on_batch(x, y)
|
model.train_on_batch(x, y)
|
||||||
@ -102,8 +108,10 @@ class TestModelSavingandLoading(test.TestCase):
|
|||||||
model = keras.models.Model(inputs, output)
|
model = keras.models.Model(inputs, output)
|
||||||
model.compile(
|
model.compile(
|
||||||
loss=keras.losses.MSE,
|
loss=keras.losses.MSE,
|
||||||
optimizer=keras.optimizers.RMSprop(lr=0.0001),
|
optimizer=rmsprop.RMSprop(lr=0.0001),
|
||||||
metrics=[keras.metrics.categorical_accuracy])
|
metrics=[keras.metrics.categorical_accuracy],
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
x = np.random.random((1, 3))
|
x = np.random.random((1, 3))
|
||||||
y = np.random.random((1, 3))
|
y = np.random.random((1, 3))
|
||||||
model.train_on_batch(x, y)
|
model.train_on_batch(x, y)
|
||||||
@ -159,7 +167,9 @@ class TestModelSavingandLoading(test.TestCase):
|
|||||||
loaded_model.compile(
|
loaded_model.compile(
|
||||||
loss='mse',
|
loss='mse',
|
||||||
optimizer=training_module.RMSPropOptimizer(0.1),
|
optimizer=training_module.RMSPropOptimizer(0.1),
|
||||||
metrics=['acc'])
|
metrics=['acc'],
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
y = loaded_model.predict(x)
|
y = loaded_model.predict(x)
|
||||||
self.assertAllClose(ref_y, y, atol=1e-05)
|
self.assertAllClose(ref_y, y, atol=1e-05)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user