Migrate function tests to use tf.config APIs

Using the tf.config APIs allow the test to run in v2 mode.
Additionally fixed various lint issues.

PiperOrigin-RevId: 257002615
This commit is contained in:
Gaurav Jain 2019-07-08 10:06:26 -07:00 committed by TensorFlower Gardener
parent 48753b6b1e
commit 6e9c8371fd
2 changed files with 41 additions and 29 deletions

View File

@ -19,11 +19,11 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -49,6 +49,17 @@ _COS_DERIVATIVES = [math_ops.cos,
class FunctionGradientsTest(test.TestCase, parameterized.TestCase): class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(FunctionGradientsTest, self).setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_virtual_device_configuration(cpus[0], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration()
])
def testGraphModeWithGradients(self): def testGraphModeWithGradients(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v') v = resource_variable_ops.ResourceVariable(1.0, name='v')
@ -215,7 +226,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(-math_ops.sin(x), gg) self.assertAllClose(-math_ops.sin(x), gg)
def testSymGradGatherNd(self): def testSymGradGatherNd(self):
with ops.Graph().as_default(), self.cached_session() as sess: with ops.Graph().as_default(), self.cached_session():
@def_function.function @def_function.function
def f(x): def f(x):
@ -897,6 +908,5 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
ops.enable_eager_execution( ops.enable_eager_execution()
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
test.main() test.main()

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import collections import collections
import functools import functools
import itertools import itertools
from multiprocessing.pool import ThreadPool import multiprocessing.pool
import sys import sys
import weakref import weakref
@ -31,10 +31,12 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -51,7 +53,6 @@ from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.layers import convolutional from tensorflow.python.layers import convolutional
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
@ -118,6 +119,17 @@ def _example_indexed_slices_without_dense_shape():
class FunctionTest(test.TestCase, parameterized.TestCase): class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(FunctionTest, self).setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_virtual_device_configuration(cpus[0], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration()
])
def testBasic(self): def testBasic(self):
matmul = def_function.function(math_ops.matmul) matmul = def_function.function(math_ops.matmul)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
@ -410,7 +422,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def stateless(x): def stateless(x):
return math_ops.multiply(2.0, x) return math_ops.multiply(2.0, x)
pool = ThreadPool() pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(1.0 * x) for x in range(100)] inputs = [constant_op.constant(1.0 * x) for x in range(100)]
outputs = [float(out) for out in pool.map(stateless, inputs)] outputs = [float(out) for out in pool.map(stateless, inputs)]
expected = [float(2.0 * x) for x in inputs] expected = [float(2.0 * x) for x in inputs]
@ -423,12 +435,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
del x del x
return math_ops.multiply(2.0, 2.0) return math_ops.multiply(2.0, 2.0)
pool = ThreadPool() pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object. # `pool.map` below instantiates 100 functions, one for each object.
outputs = [ objects = [object() for _ in range(100)]
float(out) outputs = [float(out) for out in pool.map(stateless, objects)]
for out in pool.map(stateless, [object() for _ in range(100)])
]
expected = [4.0] * 100 expected = [4.0] * 100
self.assertSequenceEqual(outputs, expected) self.assertSequenceEqual(outputs, expected)
@ -440,7 +450,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def stateful(x): def stateful(x):
v.assign(x) v.assign(x)
pool = ThreadPool() pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(0.0)] * 100 inputs = [constant_op.constant(0.0)] * 100
pool.map(stateful, inputs) pool.map(stateful, inputs)
self.assertEqual(float(v.read_value()), 0.0) self.assertEqual(float(v.read_value()), 0.0)
@ -454,7 +464,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
del x del x
return v.assign(0.0) return v.assign(0.0)
pool = ThreadPool() pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object. # `pool.map` below instantiates 100 functions, one for each object.
pool.map(stateful, [object() for _ in range(100)]) pool.map(stateful, [object() for _ in range(100)])
self.assertEqual(float(v.read_value()), 0.0) self.assertEqual(float(v.read_value()), 0.0)
@ -998,7 +1008,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
for (input_component, output_component) in zip(input_flat, output_flat): for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component) self.assertAllEqual(input_component, output_component)
@test_util.run_gpu_only @test_util.run_gpu_only
def testFunctionOnDevice(self): def testFunctionOnDevice(self):
x = constant_op.constant([1.]).gpu() x = constant_op.constant([1.]).gpu()
@ -1250,11 +1259,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
# `Function` --(instancemethod on `MiniModel`)--> `MiniModel` # `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
del model.call del model.call
# Note: The ConfigProto below unfortunately only configures graph @test_util.run_in_graph_and_eager_modes
# construction. Eager's configuration is controlled in `__main__`.
@test_util.run_in_graph_and_eager_modes(
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
@test_util.run_v1_only('b/120545219')
def testDeviceAnnotationsRespected(self): def testDeviceAnnotationsRespected(self):
def multi_device_fn(): def multi_device_fn():
@ -1291,9 +1296,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
@test_util.run_in_graph_and_eager_modes( @test_util.run_in_graph_and_eager_modes
config=config_pb2.ConfigProto(device_count={'CPU': 2}))
@test_util.run_v1_only('b/120545219')
def testCallingGraphFunctionOnDifferentDevice(self): def testCallingGraphFunctionOnDifferentDevice(self):
def func(): def func():
@ -1441,7 +1444,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertLen(defined._function_cache.arg_relaxed_shapes, 1) self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
relaxed_shapes = ( relaxed_shapes = (
list(defined._function_cache.arg_relaxed_shapes.values())[0]) list(defined._function_cache.arg_relaxed_shapes.values())[0])
self.assertEqual(len(relaxed_shapes), 1) self.assertLen(relaxed_shapes, 1)
relaxed_shape = relaxed_shapes[0] relaxed_shape = relaxed_shapes[0]
# pylint: enable=protected-access # pylint: enable=protected-access
self.assertEqual(relaxed_shape.rank, 1) self.assertEqual(relaxed_shape.rank, 1)
@ -1672,7 +1675,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined(array_ops.ones([2, 1])) defined(array_ops.ones([2, 1]))
# Wrong number of arguments. # Wrong number of arguments.
with self.assertRaisesRegexp(TypeError, 'Received 2 argument\(s\)'): with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
defined(array_ops.ones([2]), array_ops.ones([2])) defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
'Structure of Python function inputs.*'): 'Structure of Python function inputs.*'):
@ -2203,10 +2206,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
rewrites.min_graph_nodes = -1 rewrites.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions( graph_options = config_pb2.GraphOptions(
rewrite_options=rewrites, build_cost_model=1) rewrite_options=rewrites, build_cost_model=1)
config = config_pb2.ConfigProto(graph_options=graph_options) config_proto = config_pb2.ConfigProto(graph_options=graph_options)
with context.graph_mode(), self.cached_session( with context.graph_mode(), self.cached_session(
config=config, graph=ops.Graph(), use_gpu=True): config=config_proto, graph=ops.Graph(), use_gpu=True):
@function.defun_with_attributes( @function.defun_with_attributes(
attributes={ attributes={
@ -3180,6 +3183,5 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
ops.enable_eager_execution( ops.enable_eager_execution()
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
test.main() test.main()