Fix tf.test for PEP-8 and document

tf.test now has appropriate snake case function names (get_temp_dir and
is_built_with_cuda) and has normal toplevel module documentation.

Also fix a bug in make_all.
Change: 114351269
This commit is contained in:
Geoffrey Irving 2016-02-10 11:48:34 -08:00 committed by TensorFlower Gardener
parent cab186e7a6
commit 18297126c0
12 changed files with 270 additions and 29 deletions

View File

@ -28,15 +28,18 @@
* `GraphOptions.skip_common_subexpression_elimination` has been removed. All * `GraphOptions.skip_common_subexpression_elimination` has been removed. All
graph optimizer options are now specified via graph optimizer options are now specified via
`GraphOptions.OptimizerOptions`. `GraphOptions.OptimizerOptions`.
* ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were * `ASSERT_OK` / `EXPECT_OK` macros conflicted with external projects, so they
renamed TF_ASSERT_OK, TF_EXPECT_OK. The existing macros are currently were renamed `TF_ASSERT_OK`, `TF_EXPECT_OK`. The existing macros are
maintained for short-term compatibility but will be removed. currently maintained for short-term compatibility but will be removed.
* The non-public `nn.rnn` and the various `nn.seq2seq` methods now return * The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
just the final state instead of the list of all states. just the final state instead of the list of all states.
* `tf.image.random_crop(image, [height, width])` is now * `tf.image.random_crop(image, [height, width])` is now
`tf.random_crop(image, [height, width, depth])`, and `tf.random_crop` works `tf.random_crop(image, [height, width, depth])`, and `tf.random_crop` works
for any rank (not just 3-D images). The C++ `RandomCrop` op has been replaced for any rank (not just 3-D images). The C++ `RandomCrop` op has been replaced
with pure Python. with pure Python.
* Renamed `tf.test.GetTempDir` and `tf.test.IsBuiltWithCuda` to
`tf.test.get_temp_dir` and `tf.test.is_built_with_cuda` for PEP-8
compatibility.
## Bug fixes ## Bug fixes
@ -71,7 +74,7 @@
## Backwards-incompatible changes ## Backwards-incompatible changes
* tf.nn.fixed_unigram_candidate_sampler changed its default 'distortion' * `tf.nn.fixed_unigram_candidate_sampler` changed its default 'distortion'
attribute from 0.0 to 1.0. This was a bug in the original release attribute from 0.0 to 1.0. This was a bug in the original release
that is now fixed. that is now fixed.

View File

@ -405,3 +405,11 @@
* **[Wraps python functions](../../api_docs/python/script_ops.md)**: * **[Wraps python functions](../../api_docs/python/script_ops.md)**:
* [`py_func`](../../api_docs/python/script_ops.md#py_func) * [`py_func`](../../api_docs/python/script_ops.md#py_func)
* **[Testing](../../api_docs/python/test.md)**:
* [`assert_equal_graph_def`](../../api_docs/python/test.md#assert_equal_graph_def)
* [`compute_gradient`](../../api_docs/python/test.md#compute_gradient)
* [`compute_gradient_error`](../../api_docs/python/test.md#compute_gradient_error)
* [`get_temp_dir`](../../api_docs/python/test.md#get_temp_dir)
* [`is_built_with_cuda`](../../api_docs/python/test.md#is_built_with_cuda)
* [`main`](../../api_docs/python/test.md#main)

View File

@ -0,0 +1,158 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Testing
[TOC]
## Unit tests
TensorFlow provides a convenience class inheriting from `unittest.TestCase`
which adds methods relevant to TensorFlow tests. Here is an example:
import tensorflow as tf
class SquareTest(tf.test.TestCase):
def testSquare(self):
with self.test_session():
x = tf.square([2, 3])
self.assertAllEqual(x.eval(), [4, 9])
if __name__ == '__main__':
tf.test.main()
`tf.test.TestCase` inherits from `unittest.TestCase` but adds a few additional
methods. We will document these methods soon.
- - -
### `tf.test.main()` {#main}
Runs all unit tests.
## Utilities
- - -
### `tf.test.assert_equal_graph_def(actual, expected)` {#assert_equal_graph_def}
Asserts that two `GraphDef`s are (mostly) the same.
Compares two `GraphDef` protos for equality, ignoring versions and ordering of
nodes, attrs, and control inputs. Node names are used to match up nodes
between the graphs, so the naming of nodes must be consistent.
##### Args:
* <b>`actual`</b>: The `GraphDef` we have.
* <b>`expected`</b>: The `GraphDef` we expected.
##### Raises:
* <b>`AssertionError`</b>: If the `GraphDef`s do not match.
* <b>`TypeError`</b>: If either argument is not a `GraphDef`.
- - -
### `tf.test.get_temp_dir()` {#get_temp_dir}
Returns a temporary directory for use during tests.
There is no need to delete the directory after the test.
##### Returns:
The temporary directory.
- - -
### `tf.test.is_built_with_cuda()` {#is_built_with_cuda}
Returns whether TensorFlow was built with CUDA (GPU) support.
## Gradient checking
[`compute_gradient`](#compute_gradient) and
[`compute_gradient_error`](#compute_gradient_error) perform numerical
differentiation of graphs for comparison against registered analytic gradients.
- - -
### `tf.test.compute_gradient(x, x_shape, y, y_shape, x_init_value=None, delta=0.001, init_targets=None)` {#compute_gradient}
Computes and returns the theoretical and numerical Jacobian.
##### Args:
* <b>`x`</b>: a tensor or list of tensors
* <b>`x_shape`</b>: the dimensions of x as a tuple or an array of ints. If x is a list,
then this is the list of shapes.
* <b>`y`</b>: a tensor
* <b>`y_shape`</b>: the dimensions of y as a tuple or an array of ints.
* <b>`x_init_value`</b>: (optional) a numpy array of the same shape as "x"
representing the initial value of x. If x is a list, this should be a list
of numpy arrays. If this is none, the function will pick a random tensor
as the initial value.
* <b>`delta`</b>: (optional) the amount of perturbation.
* <b>`init_targets`</b>: list of targets to run to initialize model params.
TODO(mrry): remove this argument.
##### Returns:
Two 2-d numpy arrays representing the theoretical and numerical
Jacobian for dy/dx. Each has "x_size" rows and "y_size" columns
where "x_size" is the number of elements in x and "y_size" is the
number of elements in y. If x is a list, returns a list of two numpy arrays.
- - -
### `tf.test.compute_gradient_error(x, x_shape, y, y_shape, x_init_value=None, delta=0.001, init_targets=None)` {#compute_gradient_error}
Computes the gradient error.
Computes the maximum error for dy/dx between the computed Jacobian and the
numerically estimated Jacobian.
This function will modify the tensors passed in as it adds more operations
and hence changing the consumers of the operations of the input tensors.
This function adds operations to the current session. To compute the error
using a particular device, such as a GPU, use the standard methods for
setting a device (e.g. using with sess.graph.device() or setting a device
function in the session constructor).
##### Args:
* <b>`x`</b>: a tensor or list of tensors
* <b>`x_shape`</b>: the dimensions of x as a tuple or an array of ints. If x is a list,
then this is the list of shapes.
* <b>`y`</b>: a tensor
* <b>`y_shape`</b>: the dimensions of y as a tuple or an array of ints.
* <b>`x_init_value`</b>: (optional) a numpy array of the same shape as "x"
representing the initial value of x. If x is a list, this should be a list
of numpy arrays. If this is none, the function will pick a random tensor
as the initial value.
* <b>`delta`</b>: (optional) the amount of perturbation.
* <b>`init_targets`</b>: list of targets to run to initialize model params.
TODO(mrry): Remove this argument.
##### Returns:
The maximum error in between the two Jacobians.

View File

@ -97,6 +97,7 @@ class Index(Document):
print(" * %s" % link, file=f) print(" * %s" % link, file=f)
print("", file=f) print("", file=f)
def collect_members(module_to_name): def collect_members(module_to_name):
"""Collect all symbols from a list of modules. """Collect all symbols from a list of modules.
@ -108,9 +109,11 @@ def collect_members(module_to_name):
""" """
members = {} members = {}
for module, module_name in module_to_name.items(): for module, module_name in module_to_name.items():
all_names = getattr(module, "__all__", None)
for name, member in inspect.getmembers(module): for name, member in inspect.getmembers(module):
if ((inspect.isfunction(member) or inspect.isclass(member)) and if ((inspect.isfunction(member) or inspect.isclass(member)) and
not _always_drop_symbol_re.match(name)): not _always_drop_symbol_re.match(name) and
(all_names is None or name in all_names)):
fullname = '%s.%s' % (module_name, name) fullname = '%s.%s' % (module_name, name)
if name in members: if name in members:
other_fullname, other_member = members[name] other_fullname, other_member = members[name]

View File

@ -50,7 +50,8 @@ def get_module_to_name():
tf.image: "tf.image", tf.image: "tf.image",
tf.nn: "tf.nn", tf.nn: "tf.nn",
tf.train: "tf.train", tf.train: "tf.train",
tf.python_io: "tf.python_io",} tf.python_io: "tf.python_io",
tf.test: "tf.test",}
def all_libraries(module_to_name, members, documented): def all_libraries(module_to_name, members, documented):
# A list of (filename, docs.Library) pairs representing the individual files # A list of (filename, docs.Library) pairs representing the individual files
@ -113,6 +114,7 @@ def all_libraries(module_to_name, members, documented):
"FeatureList", "FeatureLists", "FeatureList", "FeatureLists",
"RankingExample", "SequenceExample"]), "RankingExample", "SequenceExample"]),
library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT), library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT),
library("test", "Testing", tf.test),
] ]
_hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
@ -121,7 +123,7 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
"BaseSession", "NameAttrList", "AttrValue", "BaseSession", "NameAttrList", "AttrValue",
"TensorArray", "OptimizerOptions", "TensorArray", "OptimizerOptions",
"CollectionDef", "MetaGraphDef", "QueueRunnerDef", "CollectionDef", "MetaGraphDef", "QueueRunnerDef",
"SaverDef", "VariableDef", "SaverDef", "VariableDef", "TestCase",
] ]
def main(unused_argv): def main(unused_argv):

View File

@ -107,7 +107,7 @@ def IsGoogleCudaEnabled():
class TensorFlowTestCase(googletest.TestCase): class TensorFlowTestCase(googletest.TestCase):
"""Root class for tests that need to test tensor flow. """Base class for tests that need to test TensorFlow.
""" """
def __init__(self, methodName="runTest"): def __init__(self, methodName="runTest"):

View File

@ -67,7 +67,7 @@ class AssignOpTest(tf.test.TestCase):
var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False) var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
self.assertAllEqual(x - y, var_value) self.assertAllEqual(x - y, var_value)
self.assertAllEqual(x - y, op_value) self.assertAllEqual(x - y, op_value)
if tf.test.IsBuiltWithCuda() and dtype in [np.float32, np.float64]: if tf.test.is_built_with_cuda() and dtype in [np.float32, np.float64]:
var_value, op_value = self._initAssignFetch(x, y, use_gpu=True) var_value, op_value = self._initAssignFetch(x, y, use_gpu=True)
self.assertAllEqual(y, var_value) self.assertAllEqual(y, var_value)
self.assertAllEqual(y, op_value) self.assertAllEqual(y, op_value)

View File

@ -41,7 +41,7 @@ class FFT2DOpsTest(tf.test.TestCase):
return np.fft.ifft2(x) return np.fft.ifft2(x)
def _Compare(self, x): def _Compare(self, x):
if tf.test.IsBuiltWithCuda(): if tf.test.is_built_with_cuda():
# GPU/Forward # GPU/Forward
self.assertAllClose( self.assertAllClose(
self._npFFT2D(x), self._npFFT2D(x),
@ -74,13 +74,13 @@ class FFT2DOpsTest(tf.test.TestCase):
self._Compare(gen(shape)) self._Compare(gen(shape))
def testEmpty(self): def testEmpty(self):
if tf.test.IsBuiltWithCuda(): if tf.test.is_built_with_cuda():
x = np.zeros([40, 0]).astype(np.complex64) x = np.zeros([40, 0]).astype(np.complex64)
self.assertEqual(x.shape, self._tfFFT2D(x).shape) self.assertEqual(x.shape, self._tfFFT2D(x).shape)
self.assertEqual(x.shape, self._tfIFFT2D(x).shape) self.assertEqual(x.shape, self._tfIFFT2D(x).shape)
def testError(self): def testError(self):
if tf.test.IsBuiltWithCuda(): if tf.test.is_built_with_cuda():
x = np.zeros([1, 2, 3]).astype(np.complex64) x = np.zeros([1, 2, 3]).astype(np.complex64)
with self.assertRaisesOpError("Input is not a matrix"): with self.assertRaisesOpError("Input is not a matrix"):
self._tfFFT2D(x) self._tfFFT2D(x)
@ -107,14 +107,14 @@ class FFT2DOpsTest(tf.test.TestCase):
self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2) self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2)
def testGrad_Simple(self): def testGrad_Simple(self):
if tf.test.IsBuiltWithCuda(): if tf.test.is_built_with_cuda():
re = np.array([[1., 0.], [0., 1.]]).astype(np.float32) re = np.array([[1., 0.], [0., 1.]]).astype(np.float32)
im = np.array([[0., 0.], [0., 0.]]).astype(np.float32) im = np.array([[0., 0.], [0., 0.]]).astype(np.float32)
self._checkGrad(tf.fft2d, re, im, use_gpu=True) self._checkGrad(tf.fft2d, re, im, use_gpu=True)
self._checkGrad(tf.ifft2d, re, im, use_gpu=True) self._checkGrad(tf.ifft2d, re, im, use_gpu=True)
def testGrad_Random(self): def testGrad_Random(self):
if tf.test.IsBuiltWithCuda(): if tf.test.is_built_with_cuda():
shape = (4, 8) shape = (4, 8)
np.random.seed(54321) np.random.seed(54321)
re = np.random.rand(*shape).astype(np.float32) * 2 - 1 re = np.random.rand(*shape).astype(np.float32) * 2 - 1

View File

@ -309,7 +309,7 @@ class PoolingTest(tf.test.TestCase):
self._testDepthwiseMaxPoolInvalidConfig( self._testDepthwiseMaxPoolInvalidConfig(
[1, 2, 2, 4], [1, 1, 1, 3], [1, 2, 2, 4], [1, 1, 1, 3],
[1, 1, 1, 3], "evenly divide") [1, 1, 1, 3], "evenly divide")
if tf.test.IsBuiltWithCuda(): if tf.test.is_built_with_cuda():
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
t = tf.constant(1.0, shape=[1, 2, 2, 4]) t = tf.constant(1.0, shape=[1, 2, 2, 4])
with self.assertRaisesOpError("for CPU devices"): with self.assertRaisesOpError("for CPU devices"):
@ -359,7 +359,7 @@ class PoolingTest(tf.test.TestCase):
def testMaxPoolingWithArgmax(self): def testMaxPoolingWithArgmax(self):
# MaxPoolWithArgMax is implemented only on GPU. # MaxPoolWithArgMax is implemented only on GPU.
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0] tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
with self.test_session(use_gpu=True) as sess: with self.test_session(use_gpu=True) as sess:
@ -377,7 +377,7 @@ class PoolingTest(tf.test.TestCase):
def testMaxPoolingGradWithArgmax(self): def testMaxPoolingGradWithArgmax(self):
# MaxPoolWithArgMax is implemented only on GPU. # MaxPoolWithArgMax is implemented only on GPU.
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0] orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
tensor_input = [11.0, 12.0, 13.0, 14.0] tensor_input = [11.0, 12.0, 13.0, 14.0]
@ -641,7 +641,7 @@ class PoolingTest(tf.test.TestCase):
window_rows=2, window_cols=2, row_stride=1, col_stride=1, window_rows=2, window_cols=2, row_stride=1, col_stride=1,
padding="VALID", use_gpu=False) padding="VALID", use_gpu=False)
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
# Test the GPU implementation that uses cudnn for now. # Test the GPU implementation that uses cudnn for now.
@ -675,7 +675,7 @@ class PoolingTest(tf.test.TestCase):
window_rows=2, window_cols=2, row_stride=1, col_stride=1, window_rows=2, window_cols=2, row_stride=1, col_stride=1,
padding="VALID", use_gpu=False) padding="VALID", use_gpu=False)
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
# Test the GPU implementation that uses cudnn for now. # Test the GPU implementation that uses cudnn for now.
@ -813,7 +813,7 @@ class PoolingTest(tf.test.TestCase):
def GetMaxPoolFwdTest(input_size, filter_size, strides, padding): def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
def Test(self): def Test(self):
# MaxPoolWithArgMax is implemented only on GPU. # MaxPoolWithArgMax is implemented only on GPU.
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
self._CompareMaxPoolingFwd(input_size, filter_size, strides, padding) self._CompareMaxPoolingFwd(input_size, filter_size, strides, padding)
return Test return Test
@ -822,7 +822,7 @@ def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
def GetMaxPoolGradTest(input_size, filter_size, output_size, strides, padding): def GetMaxPoolGradTest(input_size, filter_size, output_size, strides, padding):
def Test(self): def Test(self):
# MaxPoolWithArgMax is implemented only on GPU. # MaxPoolWithArgMax is implemented only on GPU.
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
self._CompareMaxPoolingBk(input_size, output_size, self._CompareMaxPoolingBk(input_size, output_size,
filter_size, strides, padding) filter_size, strides, padding)

View File

@ -13,20 +13,86 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""## Unit tests
TensorFlow provides a convenience class inheriting from `unittest.TestCase`
which adds methods relevant to TensorFlow tests. Here is an example:
import tensorflow as tf
class SquareTest(tf.test.TestCase):
def testSquare(self):
with self.test_session():
x = tf.square([2, 3])
self.assertAllEqual(x.eval(), [4, 9])
if __name__ == '__main__':
tf.test.main()
`tf.test.TestCase` inherits from `unittest.TestCase` but adds a few additional
methods. We will document these methods soon.
@@main
## Utilities
@@assert_equal_graph_def
@@get_temp_dir
@@is_built_with_cuda
## Gradient checking
[`compute_gradient`](#compute_gradient) and
[`compute_gradient_error`](#compute_gradient_error) perform numerical
differentiation of graphs for comparison against registered analytic gradients.
@@compute_gradient
@@compute_gradient_error
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.util.all_util import make_all
# pylint: disable=unused-import # pylint: disable=unused-import
from tensorflow.python.platform.googletest import GetTempDir
from tensorflow.python.platform.googletest import main
from tensorflow.python.platform.googletest import StubOutForTesting
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.framework.test_util import IsGoogleCudaEnabled as IsBuiltWithCuda
from tensorflow.python.framework.test_util import assert_equal_graph_def from tensorflow.python.framework.test_util import assert_equal_graph_def
from tensorflow.python.kernel_tests.gradient_checker import compute_gradient_error from tensorflow.python.kernel_tests.gradient_checker import compute_gradient_error
from tensorflow.python.kernel_tests.gradient_checker import compute_gradient from tensorflow.python.kernel_tests.gradient_checker import compute_gradient
get_temp_dir = GetTempDir
# pylint: enable=unused-import # pylint: enable=unused-import
def main():
"""Runs all unit tests."""
return googletest.main()
def get_temp_dir():
"""Returns a temporary directory for use during tests.
There is no need to delete the directory after the test.
Returns:
The temporary directory.
"""
return googletest.GetTempDir()
def is_built_with_cuda():
"""Returns whether TensorFlow was built with CUDA (GPU) support."""
return test_util.IsGoogleCudaEnabled()
__all__ = make_all(__name__)
# TODO(irving,vrv): Remove once TestCase is documented
__all__.append('TestCase')

View File

@ -218,7 +218,7 @@ class SaverTest(tf.test.TestCase):
self._SaveAndLoad("var1", 1.1, 2.2, save_path) self._SaveAndLoad("var1", 1.1, 2.2, save_path)
def testGPU(self): def testGPU(self):
if not tf.test.IsBuiltWithCuda(): if not tf.test.is_built_with_cuda():
return return
save_path = os.path.join(self.get_temp_dir(), "gpu") save_path = os.path.join(self.get_temp_dir(), "gpu")
with tf.Session("", graph=tf.Graph()) as sess: with tf.Session("", graph=tf.Graph()) as sess:

View File

@ -21,7 +21,7 @@ from __future__ import print_function
import re import re
import sys import sys
_reference_pattern = re.compile(r'^@@(\w+)$') _reference_pattern = re.compile(r'^@@(\w+)$', flags=re.MULTILINE)
def make_all(module_name): def make_all(module_name):
@ -39,4 +39,5 @@ def make_all(module_name):
doc = sys.modules[module_name].__doc__ doc = sys.modules[module_name].__doc__
return [m.group(1) for m in _reference_pattern.finditer(doc)] return [m.group(1) for m in _reference_pattern.finditer(doc)]
__all__ = ['make_all'] __all__ = ['make_all']