Keras Fixit: Copy the util is_oss into multi_worker_callback_tf2_test.
PiperOrigin-RevId: 339382960 Change-Id: I47b7b4e9a5f7c1242c77f18c0d935815c13794fd
This commit is contained in:
parent
f34ada9573
commit
091f679cdf
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
@ -35,6 +36,11 @@ from tensorflow.python.lib.io import file_io
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
def _is_oss():
|
||||||
|
"""Returns whether the test is run under OSS."""
|
||||||
|
return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_exists(filepath):
|
def checkpoint_exists(filepath):
|
||||||
"""Returns whether the checkpoint `filepath` refers to exists."""
|
"""Returns whether the checkpoint `filepath` refers to exists."""
|
||||||
if filepath.endswith('.h5'):
|
if filepath.endswith('.h5'):
|
||||||
@ -183,7 +189,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
|
|
||||||
def proc_model_checkpoint_works_with_same_file_path(test_obj,
|
def proc_model_checkpoint_works_with_same_file_path(test_obj,
|
||||||
saving_filepath):
|
saving_filepath):
|
||||||
if multi_process_runner.is_oss():
|
if _is_oss():
|
||||||
test_obj.skipTest('TODO(b/170838633): Failing in OSS')
|
test_obj.skipTest('TODO(b/170838633): Failing in OSS')
|
||||||
model, _, train_ds, steps = _model_setup(test_obj, file_format='')
|
model, _, train_ds, steps = _model_setup(test_obj, file_format='')
|
||||||
num_epoch = 4
|
num_epoch = 4
|
||||||
|
Loading…
Reference in New Issue
Block a user