Keras Fixit: Copy the util is_oss into multi_worker_callback_tf2_test.
PiperOrigin-RevId: 339382960 Change-Id: I47b7b4e9a5f7c1242c77f18c0d935815c13794fd
This commit is contained in:
parent
f34ada9573
commit
091f679cdf
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -35,6 +36,11 @@ from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _is_oss():
|
||||
"""Returns whether the test is run under OSS."""
|
||||
return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]
|
||||
|
||||
|
||||
def checkpoint_exists(filepath):
|
||||
"""Returns whether the checkpoint `filepath` refers to exists."""
|
||||
if filepath.endswith('.h5'):
|
||||
@ -183,7 +189,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
def proc_model_checkpoint_works_with_same_file_path(test_obj,
|
||||
saving_filepath):
|
||||
if multi_process_runner.is_oss():
|
||||
if _is_oss():
|
||||
test_obj.skipTest('TODO(b/170838633): Failing in OSS')
|
||||
model, _, train_ds, steps = _model_setup(test_obj, file_format='')
|
||||
num_epoch = 4
|
||||
|
Loading…
Reference in New Issue
Block a user