Split tensorflow.python.tpu module doctests into different target.
These TPU tests do not yet run in OSS. PiperOrigin-RevId: 310959419 Change-Id: I2a1662e52f25da9c4c58c018c83729dc6da9008d
This commit is contained in:
parent
74226e113a
commit
4727d0180f
@ -828,7 +828,7 @@ class TPUEmbedding(object):
|
|||||||
... end_learning_rate=0.0)
|
... end_learning_rate=0.0)
|
||||||
>>> wordpiece_table_config = TableConfig(
|
>>> wordpiece_table_config = TableConfig(
|
||||||
... vocabulary_size=119547,
|
... vocabulary_size=119547,
|
||||||
... dimension=768,
|
... dimension=256,
|
||||||
... learning_rate_fn=learning_rate_fn)
|
... learning_rate_fn=learning_rate_fn)
|
||||||
>>> wordpiece_feature_config = FeatureConfig(
|
>>> wordpiece_feature_config = FeatureConfig(
|
||||||
... table_id='bert/embeddings/word_embeddings',
|
... table_id='bert/embeddings/word_embeddings',
|
||||||
@ -846,11 +846,11 @@ class TPUEmbedding(object):
|
|||||||
... batch_size=128,
|
... batch_size=128,
|
||||||
... mode=TRAINING,
|
... mode=TRAINING,
|
||||||
... optimization_parameters=optimization_parameters,
|
... optimization_parameters=optimization_parameters,
|
||||||
... device_config=DeviceConfig(
|
... master='')
|
||||||
... num_cores=64, num_hosts=4, job_name='tpu_worker'))
|
|
||||||
>>> with tf.Graph().as_default():
|
>>> with tf.Graph().as_default():
|
||||||
... init_tpu_op = tf.compat.v1.tpu.initialize_system(
|
... init_tpu_op = tf.compat.v1.tpu.initialize_system(
|
||||||
... embedding_config=tpu_embedding.config_proto, job='tpu_worker')
|
... embedding_config=tpu_embedding.config_proto)
|
||||||
|
... tf.compat.v1.Session().run(init_tpu_op)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
|
# TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# Doc generator
|
# Doc generator
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||||
|
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow:__subpackages__"],
|
default_visibility = ["//tensorflow:__subpackages__"],
|
||||||
@ -22,6 +23,7 @@ py_library(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "tf_doctest",
|
name = "tf_doctest",
|
||||||
srcs = ["tf_doctest.py"],
|
srcs = ["tf_doctest.py"],
|
||||||
|
args = ["--module_prefix_skip=tpu."],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss_py2",
|
"no_oss_py2",
|
||||||
@ -40,6 +42,28 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tpu_py_test(
|
||||||
|
name = "tf_doctest_tpu",
|
||||||
|
srcs = ["tf_doctest.py"],
|
||||||
|
args = ["--module=tpu."],
|
||||||
|
disable_experimental = True,
|
||||||
|
disable_v3 = True,
|
||||||
|
main = "tf_doctest.py",
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"no_oss",
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tf_doctest_lib",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python/keras/preprocessing",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "tf_doctest_test",
|
name = "tf_doctest_test",
|
||||||
srcs = ["tf_doctest_test.py"],
|
srcs = ["tf_doctest_test.py"],
|
||||||
|
@ -43,6 +43,8 @@ tf.keras.preprocessing = preprocessing
|
|||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
|
flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
|
||||||
|
flags.DEFINE_list('module_prefix_skip', [],
|
||||||
|
'A list of modules to ignore when resolving modules.')
|
||||||
flags.DEFINE_boolean('list', None,
|
flags.DEFINE_boolean('list', None,
|
||||||
'List all the modules in the core package imported.')
|
'List all the modules in the core package imported.')
|
||||||
flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
|
flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
|
||||||
@ -50,6 +52,7 @@ flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
|
|||||||
flags.mark_flags_as_mutual_exclusive(['module', 'file'])
|
flags.mark_flags_as_mutual_exclusive(['module', 'file'])
|
||||||
flags.mark_flags_as_mutual_exclusive(['list', 'file'])
|
flags.mark_flags_as_mutual_exclusive(['list', 'file'])
|
||||||
|
|
||||||
|
# Both --module and --module_prefix_skip are relative to PACKAGE.
|
||||||
PACKAGE = 'tensorflow.python.'
|
PACKAGE = 'tensorflow.python.'
|
||||||
|
|
||||||
|
|
||||||
@ -140,6 +143,9 @@ def load_tests(unused_loader, tests, unused_ignore):
|
|||||||
tf_modules = get_module_and_inject_docstring(FLAGS.file)
|
tf_modules = get_module_and_inject_docstring(FLAGS.file)
|
||||||
|
|
||||||
for module in tf_modules:
|
for module in tf_modules:
|
||||||
|
if any(module.__name__.startswith(PACKAGE + prefix)
|
||||||
|
for prefix in FLAGS.module_prefix_skip):
|
||||||
|
continue
|
||||||
testcase = TfTestCase()
|
testcase = TfTestCase()
|
||||||
tests.addTests(
|
tests.addTests(
|
||||||
doctest.DocTestSuite(
|
doctest.DocTestSuite(
|
||||||
|
Loading…
Reference in New Issue
Block a user