Split tensorflow.python.tpu module doctests into different target.
These TPU tests do not yet run in OSS. PiperOrigin-RevId: 310959419 Change-Id: I2a1662e52f25da9c4c58c018c83729dc6da9008d
This commit is contained in:
parent
74226e113a
commit
4727d0180f
@ -828,7 +828,7 @@ class TPUEmbedding(object):
|
||||
... end_learning_rate=0.0)
|
||||
>>> wordpiece_table_config = TableConfig(
|
||||
... vocabulary_size=119547,
|
||||
... dimension=768,
|
||||
... dimension=256,
|
||||
... learning_rate_fn=learning_rate_fn)
|
||||
>>> wordpiece_feature_config = FeatureConfig(
|
||||
... table_id='bert/embeddings/word_embeddings',
|
||||
@ -846,11 +846,11 @@ class TPUEmbedding(object):
|
||||
... batch_size=128,
|
||||
... mode=TRAINING,
|
||||
... optimization_parameters=optimization_parameters,
|
||||
... device_config=DeviceConfig(
|
||||
... num_cores=64, num_hosts=4, job_name='tpu_worker'))
|
||||
... master='')
|
||||
>>> with tf.Graph().as_default():
|
||||
... init_tpu_op = tf.compat.v1.tpu.initialize_system(
|
||||
... embedding_config=tpu_embedding.config_proto, job='tpu_worker')
|
||||
... embedding_config=tpu_embedding.config_proto)
|
||||
... tf.compat.v1.Session().run(init_tpu_op)
|
||||
"""
|
||||
|
||||
# TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
|
||||
|
@ -2,6 +2,7 @@
|
||||
# Doc generator
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:__subpackages__"],
|
||||
@ -22,6 +23,7 @@ py_library(
|
||||
py_test(
|
||||
name = "tf_doctest",
|
||||
srcs = ["tf_doctest.py"],
|
||||
args = ["--module_prefix_skip=tpu."],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
@ -40,6 +42,28 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tpu_py_test(
|
||||
name = "tf_doctest_tpu",
|
||||
srcs = ["tf_doctest.py"],
|
||||
args = ["--module=tpu."],
|
||||
disable_experimental = True,
|
||||
disable_v3 = True,
|
||||
main = "tf_doctest.py",
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
":tf_doctest_lib",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/keras/preprocessing",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tf_doctest_test",
|
||||
srcs = ["tf_doctest_test.py"],
|
||||
|
@ -43,6 +43,8 @@ tf.keras.preprocessing = preprocessing
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
|
||||
flags.DEFINE_list('module_prefix_skip', [],
|
||||
'A list of modules to ignore when resolving modules.')
|
||||
flags.DEFINE_boolean('list', None,
|
||||
'List all the modules in the core package imported.')
|
||||
flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
|
||||
@ -50,6 +52,7 @@ flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
|
||||
flags.mark_flags_as_mutual_exclusive(['module', 'file'])
|
||||
flags.mark_flags_as_mutual_exclusive(['list', 'file'])
|
||||
|
||||
# Both --module and --module_prefix_skip are relative to PACKAGE.
|
||||
PACKAGE = 'tensorflow.python.'
|
||||
|
||||
|
||||
@ -140,6 +143,9 @@ def load_tests(unused_loader, tests, unused_ignore):
|
||||
tf_modules = get_module_and_inject_docstring(FLAGS.file)
|
||||
|
||||
for module in tf_modules:
|
||||
if any(module.__name__.startswith(PACKAGE + prefix)
|
||||
for prefix in FLAGS.module_prefix_skip):
|
||||
continue
|
||||
testcase = TfTestCase()
|
||||
tests.addTests(
|
||||
doctest.DocTestSuite(
|
||||
|
Loading…
Reference in New Issue
Block a user