Keras Fixit: Copy the util is_oss into multi_worker_callback_tf2_test.

PiperOrigin-RevId: 339382960
Change-Id: I47b7b4e9a5f7c1242c77f18c0d935815c13794fd
This commit is contained in:
Rick Chao 2020-10-27 20:12:04 -07:00 committed by TensorFlower Gardener
parent f34ada9573
commit 091f679cdf

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import json
import os
import sys
from absl.testing import parameterized
@ -35,6 +36,11 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test
def _is_oss():
"""Returns whether the test is run under OSS."""
return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]
def checkpoint_exists(filepath):
"""Returns whether the checkpoint `filepath` refers to exists."""
if filepath.endswith('.h5'):
@ -183,7 +189,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
def proc_model_checkpoint_works_with_same_file_path(test_obj,
saving_filepath):
if multi_process_runner.is_oss():
if _is_oss():
test_obj.skipTest('TODO(b/170838633): Failing in OSS')
model, _, train_ds, steps = _model_setup(test_obj, file_format='')
num_epoch = 4