Reactivate SavedModel tests that were broken by b/134660903, b/134662234.
PiperOrigin-RevId: 252996041
This commit is contained in:
parent
a2c10678f6
commit
6237fcb1a9
@ -73,14 +73,12 @@ class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = (
|
NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = (
|
||||||
("", dict()),
|
("", dict()),
|
||||||
# TODO(b/134662234): Re-enable this case when fixed.
|
("_with_retraining", dict(
|
||||||
# ("_with_retraining", dict(
|
retrain=True,
|
||||||
# retrain=True,
|
regularization_loss_multiplier=2, # Test impact of b/134528831.
|
||||||
# regularization_loss_multiplier=2, # Test impact of b/134528831.
|
)),
|
||||||
# )),
|
|
||||||
("_with_mirrored_strategy", dict(
|
("_with_mirrored_strategy", dict(
|
||||||
# TODO(b/134662234): Add back retrain=True when fixed.
|
retrain=True, # That's the relevant case for distribution.
|
||||||
# retrain=True, # That's the relevant case for distribution.
|
|
||||||
use_mirrored_strategy=True,
|
use_mirrored_strategy=True,
|
||||||
)),
|
)),
|
||||||
)
|
)
|
||||||
|
@ -118,8 +118,7 @@ def main(argv):
|
|||||||
FLAGS.input_saved_model_dir,
|
FLAGS.input_saved_model_dir,
|
||||||
FLAGS.retrain,
|
FLAGS.retrain,
|
||||||
FLAGS.regularization_loss_multiplier)
|
FLAGS.regularization_loss_multiplier)
|
||||||
model = make_classifier(feature_extractor,
|
model = make_classifier(feature_extractor)
|
||||||
dropout_rate=0.0) # TODO(b/134660903): Remove.
|
|
||||||
|
|
||||||
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
||||||
optimizer=tf.keras.optimizers.SGD(),
|
optimizer=tf.keras.optimizers.SGD(),
|
||||||
|
Loading…
Reference in New Issue
Block a user