From 18297126c0bc13c7045841f5a7a99f3da68176f4 Mon Sep 17 00:00:00 2001
From: Geoffrey Irving <geoffreyi@google.com>
Date: Wed, 10 Feb 2016 11:48:34 -0800
Subject: [PATCH] Fix tf.test for PEP-8 and document

tf.test now has appropriate snake case function names (get_temp_dir and
is_built_with_cuda) and has normal toplevel module documentation.

Also fix a bug in make_all.
Change: 114351269
---
 RELEASE.md                                    |  11 +-
 tensorflow/g3doc/api_docs/python/index.md     |   8 +
 tensorflow/g3doc/api_docs/python/test.md      | 158 ++++++++++++++++++
 tensorflow/python/framework/docs.py           |   5 +-
 .../python/framework/gen_docs_combined.py     |   6 +-
 tensorflow/python/framework/test_util.py      |   2 +-
 .../kernel_tests/dense_update_ops_test.py     |   2 +-
 .../python/kernel_tests/fft_ops_test.py       |  10 +-
 .../python/kernel_tests/pooling_ops_test.py   |  14 +-
 tensorflow/python/platform/test.py            |  78 ++++++++-
 tensorflow/python/training/saver_test.py      |   2 +-
 tensorflow/python/util/all_util.py            |   3 +-
 12 files changed, 270 insertions(+), 29 deletions(-)
 create mode 100644 tensorflow/g3doc/api_docs/python/test.md

diff --git a/RELEASE.md b/RELEASE.md
index b8146610c80..3f3d66dc6e0 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -28,15 +28,18 @@
 * `GraphOptions.skip_common_subexpression_elimination` has been removed. All
   graph optimizer options are now specified via
   `GraphOptions.OptimizerOptions`.
-* ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were
-  renamed TF_ASSERT_OK, TF_EXPECT_OK.  The existing macros are currently
-  maintained for short-term compatibility but will be removed.
+* `ASSERT_OK` / `EXPECT_OK` macros conflicted with external projects, so they
+  were renamed `TF_ASSERT_OK`, `TF_EXPECT_OK`.  The existing macros are
+  currently maintained for short-term compatibility but will be removed.
 * The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
   just the final state instead of the list of all states.
 * `tf.image.random_crop(image, [height, width])` is now
   `tf.random_crop(image, [height, width, depth])`, and `tf.random_crop` works
   for any rank (not just 3-D images).  The C++ `RandomCrop` op has been replaced
   with pure Python.
+* Renamed `tf.test.GetTempDir` and `tf.test.IsBuiltWithCuda` to
+  `tf.test.get_temp_dir` and `tf.test.is_built_with_cuda` for PEP-8
+  compatibility.
 
 
 ## Bug fixes
@@ -71,7 +74,7 @@
 
 ## Backwards-incompatible changes
 
-* tf.nn.fixed_unigram_candidate_sampler changed its default 'distortion'
+* `tf.nn.fixed_unigram_candidate_sampler` changed its default 'distortion'
   attribute from 0.0 to 1.0. This was a bug in the original release
   that is now fixed.
 
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index d79e110727c..6a4f1662dfa 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -405,3 +405,11 @@
 * **[Wraps python functions](../../api_docs/python/script_ops.md)**:
   * [`py_func`](../../api_docs/python/script_ops.md#py_func)
 
+* **[Testing](../../api_docs/python/test.md)**:
+  * [`assert_equal_graph_def`](../../api_docs/python/test.md#assert_equal_graph_def)
+  * [`compute_gradient`](../../api_docs/python/test.md#compute_gradient)
+  * [`compute_gradient_error`](../../api_docs/python/test.md#compute_gradient_error)
+  * [`get_temp_dir`](../../api_docs/python/test.md#get_temp_dir)
+  * [`is_built_with_cuda`](../../api_docs/python/test.md#is_built_with_cuda)
+  * [`main`](../../api_docs/python/test.md#main)
+
diff --git a/tensorflow/g3doc/api_docs/python/test.md b/tensorflow/g3doc/api_docs/python/test.md
new file mode 100644
index 00000000000..1b5f409b1cb
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/test.md
@@ -0,0 +1,158 @@
+<!-- This file is machine generated: DO NOT EDIT! -->
+
+# Testing
+[TOC]
+
+## Unit tests
+
+TensorFlow provides a convenience class inheriting from `unittest.TestCase`
+which adds methods relevant to TensorFlow tests.  Here is an example:
+
+    import tensorflow as tf
+
+
+    class SquareTest(tf.test.TestCase):
+
+      def testSquare(self):
+        with self.test_session():
+          x = tf.square([2, 3])
+          self.assertAllEqual(x.eval(), [4, 9])
+
+
+    if __name__ == '__main__':
+      tf.test.main()
+
+
+`tf.test.TestCase` inherits from `unittest.TestCase` but adds a few additional
+methods.  We will document these methods soon.
+
+- - -
+
+### `tf.test.main()` {#main}
+
+Runs all unit tests.
+
+
+
+## Utilities
+
+- - -
+
+### `tf.test.assert_equal_graph_def(actual, expected)` {#assert_equal_graph_def}
+
+Asserts that two `GraphDef`s are (mostly) the same.
+
+Compares two `GraphDef` protos for equality, ignoring versions and ordering of
+nodes, attrs, and control inputs.  Node names are used to match up nodes
+between the graphs, so the naming of nodes must be consistent.
+
+##### Args:
+
+
+*  <b>`actual`</b>: The `GraphDef` we have.
+*  <b>`expected`</b>: The `GraphDef` we expected.
+
+##### Raises:
+
+
+*  <b>`AssertionError`</b>: If the `GraphDef`s do not match.
+*  <b>`TypeError`</b>: If either argument is not a `GraphDef`.
+
+
+- - -
+
+### `tf.test.get_temp_dir()` {#get_temp_dir}
+
+Returns a temporary directory for use during tests.
+
+There is no need to delete the directory after the test.
+
+##### Returns:
+
+  The temporary directory.
+
+
+- - -
+
+### `tf.test.is_built_with_cuda()` {#is_built_with_cuda}
+
+Returns whether TensorFlow was built with CUDA (GPU) support.
+
+
+
+## Gradient checking
+
+[`compute_gradient`](#compute_gradient) and
+[`compute_gradient_error`](#compute_gradient_error) perform numerical
+differentiation of graphs for comparison against registered analytic gradients.
+
+- - -
+
+### `tf.test.compute_gradient(x, x_shape, y, y_shape, x_init_value=None, delta=0.001, init_targets=None)` {#compute_gradient}
+
+Computes and returns the theoretical and numerical Jacobian.
+
+##### Args:
+
+
+*  <b>`x`</b>: a tensor or list of tensors
+*  <b>`x_shape`</b>: the dimensions of x as a tuple or an array of ints. If x is a list,
+  then this is the list of shapes.
+
+*  <b>`y`</b>: a tensor
+*  <b>`y_shape`</b>: the dimensions of y as a tuple or an array of ints.
+*  <b>`x_init_value`</b>: (optional) a numpy array of the same shape as "x"
+    representing the initial value of x. If x is a list, this should be a list
+    of numpy arrays.  If this is none, the function will pick a random tensor
+    as the initial value.
+*  <b>`delta`</b>: (optional) the amount of perturbation.
+*  <b>`init_targets`</b>: list of targets to run to initialize model params.
+    TODO(mrry): remove this argument.
+
+##### Returns:
+
+  Two 2-d numpy arrays representing the theoretical and numerical
+  Jacobian for dy/dx. Each has "x_size" rows and "y_size" columns
+  where "x_size" is the number of elements in x and "y_size" is the
+  number of elements in y. If x is a list, returns a list of two numpy arrays.
+
+
+- - -
+
+### `tf.test.compute_gradient_error(x, x_shape, y, y_shape, x_init_value=None, delta=0.001, init_targets=None)` {#compute_gradient_error}
+
+Computes the gradient error.
+
+Computes the maximum error for dy/dx between the computed Jacobian and the
+numerically estimated Jacobian.
+
+This function will modify the tensors passed in as it adds more operations
+and hence changing the consumers of the operations of the input tensors.
+
+This function adds operations to the current session. To compute the error
+using a particular device, such as a GPU, use the standard methods for
+setting a device (e.g. using with sess.graph.device() or setting a device
+function in the session constructor).
+
+##### Args:
+
+
+*  <b>`x`</b>: a tensor or list of tensors
+*  <b>`x_shape`</b>: the dimensions of x as a tuple or an array of ints. If x is a list,
+  then this is the list of shapes.
+
+*  <b>`y`</b>: a tensor
+*  <b>`y_shape`</b>: the dimensions of y as a tuple or an array of ints.
+*  <b>`x_init_value`</b>: (optional) a numpy array of the same shape as "x"
+    representing the initial value of x. If x is a list, this should be a list
+    of numpy arrays.  If this is none, the function will pick a random tensor
+    as the initial value.
+*  <b>`delta`</b>: (optional) the amount of perturbation.
+*  <b>`init_targets`</b>: list of targets to run to initialize model params.
+    TODO(mrry): Remove this argument.
+
+##### Returns:
+
+  The maximum error in between the two Jacobians.
+
+
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
index 8b4ef789cd5..aefbffdc470 100644
--- a/tensorflow/python/framework/docs.py
+++ b/tensorflow/python/framework/docs.py
@@ -97,6 +97,7 @@ class Index(Document):
           print("  * %s" % link, file=f)
         print("", file=f)
 
+
 def collect_members(module_to_name):
   """Collect all symbols from a list of modules.
 
@@ -108,9 +109,11 @@ def collect_members(module_to_name):
   """
   members = {}
   for module, module_name in module_to_name.items():
+    all_names = getattr(module, "__all__", None)
     for name, member in inspect.getmembers(module):
       if ((inspect.isfunction(member) or inspect.isclass(member)) and
-          not _always_drop_symbol_re.match(name)):
+          not _always_drop_symbol_re.match(name) and
+          (all_names is None or name in all_names)):
         fullname = '%s.%s' % (module_name, name)
         if name in members:
           other_fullname, other_member = members[name]
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 3ade2abd050..3f6ab24901c 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -50,7 +50,8 @@ def get_module_to_name():
           tf.image: "tf.image",
           tf.nn: "tf.nn",
           tf.train: "tf.train",
-          tf.python_io: "tf.python_io",}
+          tf.python_io: "tf.python_io",
+          tf.test: "tf.test",}
 
 def all_libraries(module_to_name, members, documented):
   # A list of (filename, docs.Library) pairs representing the individual files
@@ -113,6 +114,7 @@ def all_libraries(module_to_name, members, documented):
                                "FeatureList", "FeatureLists",
                                "RankingExample", "SequenceExample"]),
       library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT),
+      library("test", "Testing", tf.test),
   ]
 
 _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
@@ -121,7 +123,7 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
                    "BaseSession", "NameAttrList", "AttrValue",
                    "TensorArray", "OptimizerOptions",
                    "CollectionDef", "MetaGraphDef", "QueueRunnerDef",
-                   "SaverDef", "VariableDef",
+                   "SaverDef", "VariableDef", "TestCase",
                   ]
 
 def main(unused_argv):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index c004e8eb2ef..8eb8c53a29d 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -107,7 +107,7 @@ def IsGoogleCudaEnabled():
 
 
 class TensorFlowTestCase(googletest.TestCase):
-  """Root class for tests that need to test tensor flow.
+  """Base class for tests that need to test TensorFlow.
   """
 
   def __init__(self, methodName="runTest"):
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index a3c8e92a86a..e51d14a6ee4 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -67,7 +67,7 @@ class AssignOpTest(tf.test.TestCase):
       var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
       self.assertAllEqual(x - y, var_value)
       self.assertAllEqual(x - y, op_value)
-      if tf.test.IsBuiltWithCuda() and dtype in [np.float32, np.float64]:
+      if tf.test.is_built_with_cuda() and dtype in [np.float32, np.float64]:
         var_value, op_value = self._initAssignFetch(x, y, use_gpu=True)
         self.assertAllEqual(y, var_value)
         self.assertAllEqual(y, op_value)
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index c2ec9c8743d..fa7d6c7680e 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -41,7 +41,7 @@ class FFT2DOpsTest(tf.test.TestCase):
     return np.fft.ifft2(x)
 
   def _Compare(self, x):
-    if tf.test.IsBuiltWithCuda():
+    if tf.test.is_built_with_cuda():
       # GPU/Forward
       self.assertAllClose(
           self._npFFT2D(x),
@@ -74,13 +74,13 @@ class FFT2DOpsTest(tf.test.TestCase):
       self._Compare(gen(shape))
 
   def testEmpty(self):
-    if tf.test.IsBuiltWithCuda():
+    if tf.test.is_built_with_cuda():
       x = np.zeros([40, 0]).astype(np.complex64)
       self.assertEqual(x.shape, self._tfFFT2D(x).shape)
       self.assertEqual(x.shape, self._tfIFFT2D(x).shape)
 
   def testError(self):
-    if tf.test.IsBuiltWithCuda():
+    if tf.test.is_built_with_cuda():
       x = np.zeros([1, 2, 3]).astype(np.complex64)
       with self.assertRaisesOpError("Input is not a matrix"):
         self._tfFFT2D(x)
@@ -107,14 +107,14 @@ class FFT2DOpsTest(tf.test.TestCase):
     self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2)
 
   def testGrad_Simple(self):
-    if tf.test.IsBuiltWithCuda():
+    if tf.test.is_built_with_cuda():
       re = np.array([[1., 0.], [0., 1.]]).astype(np.float32)
       im = np.array([[0., 0.], [0., 0.]]).astype(np.float32)
       self._checkGrad(tf.fft2d, re, im, use_gpu=True)
       self._checkGrad(tf.ifft2d, re, im, use_gpu=True)
 
   def testGrad_Random(self):
-    if tf.test.IsBuiltWithCuda():
+    if tf.test.is_built_with_cuda():
       shape = (4, 8)
       np.random.seed(54321)
       re = np.random.rand(*shape).astype(np.float32) * 2 - 1
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 427d83b2106..66e8cbab8bc 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -309,7 +309,7 @@ class PoolingTest(tf.test.TestCase):
     self._testDepthwiseMaxPoolInvalidConfig(
         [1, 2, 2, 4], [1, 1, 1, 3],
         [1, 1, 1, 3], "evenly divide")
-    if tf.test.IsBuiltWithCuda():
+    if tf.test.is_built_with_cuda():
       with self.test_session(use_gpu=True):
         t = tf.constant(1.0, shape=[1, 2, 2, 4])
         with self.assertRaisesOpError("for CPU devices"):
@@ -359,7 +359,7 @@ class PoolingTest(tf.test.TestCase):
 
   def testMaxPoolingWithArgmax(self):
     # MaxPoolWithArgMax is implemented only on GPU.
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
     tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
     with self.test_session(use_gpu=True) as sess:
@@ -377,7 +377,7 @@ class PoolingTest(tf.test.TestCase):
 
   def testMaxPoolingGradWithArgmax(self):
     # MaxPoolWithArgMax is implemented only on GPU.
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
     orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
     tensor_input = [11.0, 12.0, 13.0, 14.0]
@@ -641,7 +641,7 @@ class PoolingTest(tf.test.TestCase):
         window_rows=2, window_cols=2, row_stride=1, col_stride=1,
         padding="VALID", use_gpu=False)
 
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
 
     # Test the GPU implementation that uses cudnn for now.
@@ -675,7 +675,7 @@ class PoolingTest(tf.test.TestCase):
         window_rows=2, window_cols=2, row_stride=1, col_stride=1,
         padding="VALID", use_gpu=False)
 
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
 
     # Test the GPU implementation that uses cudnn for now.
@@ -813,7 +813,7 @@ class PoolingTest(tf.test.TestCase):
 def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
   def Test(self):
     # MaxPoolWithArgMax is implemented only on GPU.
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
     self._CompareMaxPoolingFwd(input_size, filter_size, strides, padding)
   return Test
@@ -822,7 +822,7 @@ def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
 def GetMaxPoolGradTest(input_size, filter_size, output_size, strides, padding):
   def Test(self):
     # MaxPoolWithArgMax is implemented only on GPU.
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
     self._CompareMaxPoolingBk(input_size, output_size,
                               filter_size, strides, padding)
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 919455b495f..d2b9d1f9747 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -13,20 +13,86 @@
 # limitations under the License.
 # ==============================================================================
 
+# pylint: disable=g-short-docstring-punctuation
+"""## Unit tests
+
+TensorFlow provides a convenience class inheriting from `unittest.TestCase`
+which adds methods relevant to TensorFlow tests.  Here is an example:
+
+    import tensorflow as tf
+
+
+    class SquareTest(tf.test.TestCase):
+
+      def testSquare(self):
+        with self.test_session():
+          x = tf.square([2, 3])
+          self.assertAllEqual(x.eval(), [4, 9])
+
+
+    if __name__ == '__main__':
+      tf.test.main()
+
+
+`tf.test.TestCase` inherits from `unittest.TestCase` but adds a few additional
+methods.  We will document these methods soon.
+
+@@main
+
+## Utilities
+
+@@assert_equal_graph_def
+@@get_temp_dir
+@@is_built_with_cuda
+
+## Gradient checking
+
+[`compute_gradient`](#compute_gradient) and
+[`compute_gradient_error`](#compute_gradient_error) perform numerical
+differentiation of graphs for comparison against registered analytic gradients.
+
+@@compute_gradient
+@@compute_gradient_error
+
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.util.all_util import make_all
+
 # pylint: disable=unused-import
-from tensorflow.python.platform.googletest import GetTempDir
-from tensorflow.python.platform.googletest import main
-from tensorflow.python.platform.googletest import StubOutForTesting
 from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
-from tensorflow.python.framework.test_util import IsGoogleCudaEnabled as IsBuiltWithCuda
 from tensorflow.python.framework.test_util import assert_equal_graph_def
 
 from tensorflow.python.kernel_tests.gradient_checker import compute_gradient_error
 from tensorflow.python.kernel_tests.gradient_checker import compute_gradient
-
-get_temp_dir = GetTempDir
 # pylint: enable=unused-import
+
+
+def main():
+  """Runs all unit tests."""
+  return googletest.main()
+
+
+def get_temp_dir():
+  """Returns a temporary directory for use during tests.
+
+  There is no need to delete the directory after the test.
+
+  Returns:
+    The temporary directory.
+  """
+  return googletest.GetTempDir()
+
+
+def is_built_with_cuda():
+  """Returns whether TensorFlow was built with CUDA (GPU) support."""
+  return test_util.IsGoogleCudaEnabled()
+
+
+__all__ = make_all(__name__)
+# TODO(irving,vrv): Remove once TestCase is documented
+__all__.append('TestCase')
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 4ef97af88cc..dd034cdb5d1 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -218,7 +218,7 @@ class SaverTest(tf.test.TestCase):
     self._SaveAndLoad("var1", 1.1, 2.2, save_path)
 
   def testGPU(self):
-    if not tf.test.IsBuiltWithCuda():
+    if not tf.test.is_built_with_cuda():
       return
     save_path = os.path.join(self.get_temp_dir(), "gpu")
     with tf.Session("", graph=tf.Graph()) as sess:
diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py
index e88e4ea8475..d6d6c731f26 100644
--- a/tensorflow/python/util/all_util.py
+++ b/tensorflow/python/util/all_util.py
@@ -21,7 +21,7 @@ from __future__ import print_function
 import re
 import sys
 
-_reference_pattern = re.compile(r'^@@(\w+)$')
+_reference_pattern = re.compile(r'^@@(\w+)$', flags=re.MULTILINE)
 
 
 def make_all(module_name):
@@ -39,4 +39,5 @@ def make_all(module_name):
   doc = sys.modules[module_name].__doc__
   return [m.group(1) for m in _reference_pattern.finditer(doc)]
 
+
 __all__ = ['make_all']