Reactivate SavedModel tests that were broken by b/134660903, b/134662234.
PiperOrigin-RevId: 252996041
This commit is contained in:
parent
a2c10678f6
commit
6237fcb1a9
@ -73,14 +73,12 @@ class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
|
||||
|
||||
NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = (
|
||||
("", dict()),
|
||||
# TODO(b/134662234): Re-enable this case when fixed.
|
||||
# ("_with_retraining", dict(
|
||||
# retrain=True,
|
||||
# regularization_loss_multiplier=2, # Test impact of b/134528831.
|
||||
# )),
|
||||
("_with_retraining", dict(
|
||||
retrain=True,
|
||||
regularization_loss_multiplier=2, # Test impact of b/134528831.
|
||||
)),
|
||||
("_with_mirrored_strategy", dict(
|
||||
# TODO(b/134662234): Add back retrain=True when fixed.
|
||||
# retrain=True, # That's the relevant case for distribution.
|
||||
retrain=True, # That's the relevant case for distribution.
|
||||
use_mirrored_strategy=True,
|
||||
)),
|
||||
)
|
||||
|
@ -118,8 +118,7 @@ def main(argv):
|
||||
FLAGS.input_saved_model_dir,
|
||||
FLAGS.retrain,
|
||||
FLAGS.regularization_loss_multiplier)
|
||||
model = make_classifier(feature_extractor,
|
||||
dropout_rate=0.0) # TODO(b/134660903): Remove.
|
||||
model = make_classifier(feature_extractor)
|
||||
|
||||
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
||||
optimizer=tf.keras.optimizers.SGD(),
|
||||
|
Loading…
Reference in New Issue
Block a user