Split tensorflow.python.tpu module doctests into different target.

These TPU tests do not yet run in OSS.

PiperOrigin-RevId: 310959419
Change-Id: I2a1662e52f25da9c4c58c018c83729dc6da9008d
This commit is contained in:
Revan Sopher 2020-05-11 11:50:00 -07:00 committed by TensorFlower Gardener
parent 74226e113a
commit 4727d0180f
3 changed files with 34 additions and 4 deletions

View File

@ -828,7 +828,7 @@ class TPUEmbedding(object):
... end_learning_rate=0.0) ... end_learning_rate=0.0)
>>> wordpiece_table_config = TableConfig( >>> wordpiece_table_config = TableConfig(
... vocabulary_size=119547, ... vocabulary_size=119547,
... dimension=768, ... dimension=256,
... learning_rate_fn=learning_rate_fn) ... learning_rate_fn=learning_rate_fn)
>>> wordpiece_feature_config = FeatureConfig( >>> wordpiece_feature_config = FeatureConfig(
... table_id='bert/embeddings/word_embeddings', ... table_id='bert/embeddings/word_embeddings',
@ -846,11 +846,11 @@ class TPUEmbedding(object):
... batch_size=128, ... batch_size=128,
... mode=TRAINING, ... mode=TRAINING,
... optimization_parameters=optimization_parameters, ... optimization_parameters=optimization_parameters,
... device_config=DeviceConfig( ... master='')
... num_cores=64, num_hosts=4, job_name='tpu_worker'))
>>> with tf.Graph().as_default(): >>> with tf.Graph().as_default():
... init_tpu_op = tf.compat.v1.tpu.initialize_system( ... init_tpu_op = tf.compat.v1.tpu.initialize_system(
... embedding_config=tpu_embedding.config_proto, job='tpu_worker') ... embedding_config=tpu_embedding.config_proto)
... tf.compat.v1.Session().run(init_tpu_op)
""" """
# TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that

View File

@ -2,6 +2,7 @@
# Doc generator # Doc generator
load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
package( package(
default_visibility = ["//tensorflow:__subpackages__"], default_visibility = ["//tensorflow:__subpackages__"],
@ -22,6 +23,7 @@ py_library(
py_test( py_test(
name = "tf_doctest", name = "tf_doctest",
srcs = ["tf_doctest.py"], srcs = ["tf_doctest.py"],
args = ["--module_prefix_skip=tpu."],
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"no_oss_py2", "no_oss_py2",
@ -40,6 +42,28 @@ py_test(
], ],
) )
tpu_py_test(
name = "tf_doctest_tpu",
srcs = ["tf_doctest.py"],
args = ["--module=tpu."],
disable_experimental = True,
disable_v3 = True,
main = "tf_doctest.py",
python_version = "PY3",
tags = [
"no_oss",
"noasan",
"nomsan",
"notsan",
],
deps = [
":tf_doctest_lib",
"//tensorflow:tensorflow_py",
"//tensorflow/python/keras/preprocessing",
"//third_party/py/numpy",
],
)
py_test( py_test(
name = "tf_doctest_test", name = "tf_doctest_test",
srcs = ["tf_doctest_test.py"], srcs = ["tf_doctest_test.py"],

View File

@ -43,6 +43,8 @@ tf.keras.preprocessing = preprocessing
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('module', None, 'A specific module to run doctest on.') flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
flags.DEFINE_list('module_prefix_skip', [],
'A list of modules to ignore when resolving modules.')
flags.DEFINE_boolean('list', None, flags.DEFINE_boolean('list', None,
'List all the modules in the core package imported.') 'List all the modules in the core package imported.')
flags.DEFINE_string('file', None, 'A specific file to run doctest on.') flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
@ -50,6 +52,7 @@ flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
flags.mark_flags_as_mutual_exclusive(['module', 'file']) flags.mark_flags_as_mutual_exclusive(['module', 'file'])
flags.mark_flags_as_mutual_exclusive(['list', 'file']) flags.mark_flags_as_mutual_exclusive(['list', 'file'])
# Both --module and --module_prefix_skip are relative to PACKAGE.
PACKAGE = 'tensorflow.python.' PACKAGE = 'tensorflow.python.'
@ -140,6 +143,9 @@ def load_tests(unused_loader, tests, unused_ignore):
tf_modules = get_module_and_inject_docstring(FLAGS.file) tf_modules = get_module_and_inject_docstring(FLAGS.file)
for module in tf_modules: for module in tf_modules:
if any(module.__name__.startswith(PACKAGE + prefix)
for prefix in FLAGS.module_prefix_skip):
continue
testcase = TfTestCase() testcase = TfTestCase()
tests.addTests( tests.addTests(
doctest.DocTestSuite( doctest.DocTestSuite(