Migrate function tests to use tf.config APIs

Using the tf.config APIs allows the tests to run in v2 mode.
Additionally, fixed various lint issues.

PiperOrigin-RevId: 257002615
Authored by Gaurav Jain on 2019-07-08 10:06:26 -07:00, committed by TensorFlower Gardener
parent 48753b6b1e
commit 6e9c8371fd
2 changed files with 41 additions and 29 deletions
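The substance of the migration: rather than passing config_pb2.ConfigProto(device_count={'CPU': 4}) to ops.enable_eager_execution (a v1-only path), each test class now splits the first physical CPU into four virtual devices in setUp through the tf.config-backed tensorflow.python.framework.config module. Below is a minimal sketch of the resulting pattern, condensed from the hunks that follow; the ExampleTest name is illustrative, and the tensorflow.python.platform.test import is assumed to match what these test files already use.

from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.platform import test  # assumed import, mirrors test.TestCase/test.main() usage below


class ExampleTest(test.TestCase):  # illustrative class name, not part of the diff

  def setUp(self):
    super(ExampleTest, self).setUp()
    # Expose four virtual CPUs backed by the first physical CPU. This replaces
    # the v1-only ConfigProto(device_count={'CPU': 4}) that used to be passed
    # to ops.enable_eager_execution() in __main__.
    cpus = config.list_physical_devices('CPU')
    config.set_virtual_device_configuration(cpus[0], [
        context.VirtualDeviceConfiguration(),
        context.VirtualDeviceConfiguration(),
        context.VirtualDeviceConfiguration(),
        context.VirtualDeviceConfiguration(),
    ])


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()

Virtual devices have to be configured before the runtime initializes its devices, so the configuration runs in setUp, ahead of any ops in each test body.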


@@ -19,11 +19,11 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -49,6 +49,17 @@ _COS_DERIVATIVES = [math_ops.cos,
class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(FunctionGradientsTest, self).setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_virtual_device_configuration(cpus[0], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration()
])
def testGraphModeWithGradients(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
@@ -215,7 +226,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(-math_ops.sin(x), gg)
def testSymGradGatherNd(self):
with ops.Graph().as_default(), self.cached_session() as sess:
with ops.Graph().as_default(), self.cached_session():
@def_function.function
def f(x):
@@ -897,6 +908,5 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
ops.enable_eager_execution()
test.main()


@@ -20,7 +20,7 @@ from __future__ import print_function
import collections
import functools
import itertools
from multiprocessing.pool import ThreadPool
import multiprocessing.pool
import sys
import weakref
@@ -31,10 +31,12 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -51,7 +53,6 @@ from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.layers import convolutional
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
@@ -118,6 +119,17 @@ def _example_indexed_slices_without_dense_shape():
class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(FunctionTest, self).setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_virtual_device_configuration(cpus[0], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration()
])
def testBasic(self):
matmul = def_function.function(math_ops.matmul)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
@@ -410,7 +422,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def stateless(x):
return math_ops.multiply(2.0, x)
pool = ThreadPool()
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
outputs = [float(out) for out in pool.map(stateless, inputs)]
expected = [float(2.0 * x) for x in inputs]
@@ -423,12 +435,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
del x
return math_ops.multiply(2.0, 2.0)
pool = ThreadPool()
pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object.
outputs = [
float(out)
for out in pool.map(stateless, [object() for _ in range(100)])
]
objects = [object() for _ in range(100)]
outputs = [float(out) for out in pool.map(stateless, objects)]
expected = [4.0] * 100
self.assertSequenceEqual(outputs, expected)
@@ -440,7 +450,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def stateful(x):
v.assign(x)
pool = ThreadPool()
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(0.0)] * 100
pool.map(stateful, inputs)
self.assertEqual(float(v.read_value()), 0.0)
@@ -454,7 +464,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
del x
return v.assign(0.0)
pool = ThreadPool()
pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object.
pool.map(stateful, [object() for _ in range(100)])
self.assertEqual(float(v.read_value()), 0.0)
@@ -998,7 +1008,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
@test_util.run_gpu_only
def testFunctionOnDevice(self):
x = constant_op.constant([1.]).gpu()
@@ -1250,11 +1259,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
# `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
del model.call
# Note: The ConfigProto below unfortunately only configures graph
# construction. Eager's configuration is controlled in `__main__`.
@test_util.run_in_graph_and_eager_modes(
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
@test_util.run_v1_only('b/120545219')
@test_util.run_in_graph_and_eager_modes
def testDeviceAnnotationsRespected(self):
def multi_device_fn():
@@ -1291,9 +1296,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
@test_util.run_in_graph_and_eager_modes(
config=config_pb2.ConfigProto(device_count={'CPU': 2}))
@test_util.run_v1_only('b/120545219')
@test_util.run_in_graph_and_eager_modes
def testCallingGraphFunctionOnDifferentDevice(self):
def func():
@@ -1441,7 +1444,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
relaxed_shapes = (
list(defined._function_cache.arg_relaxed_shapes.values())[0])
self.assertEqual(len(relaxed_shapes), 1)
self.assertLen(relaxed_shapes, 1)
relaxed_shape = relaxed_shapes[0]
# pylint: enable=protected-access
self.assertEqual(relaxed_shape.rank, 1)
@@ -1672,7 +1675,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined(array_ops.ones([2, 1]))
# Wrong number of arguments.
with self.assertRaisesRegexp(TypeError, 'Received 2 argument\(s\)'):
with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegexp(ValueError,
'Structure of Python function inputs.*'):
@@ -2203,10 +2206,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
rewrites.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions(
rewrite_options=rewrites, build_cost_model=1)
config = config_pb2.ConfigProto(graph_options=graph_options)
config_proto = config_pb2.ConfigProto(graph_options=graph_options)
with context.graph_mode(), self.cached_session(
config=config, graph=ops.Graph(), use_gpu=True):
config=config_proto, graph=ops.Graph(), use_gpu=True):
@function.defun_with_attributes(
attributes={
@@ -3180,6 +3183,5 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
ops.enable_eager_execution()
test.main()
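For comparison, outside of TensorFlow's internal test modules the equivalent virtual-CPU setup would go through the public tf.config surface of that era. The experimental endpoint names below are an assumption of this note (they reflect the TF 1.14/2.0 API as I recall it), not part of the diff:

import tensorflow as tf  # assumed public-API equivalent, not used by the internal tests

# Split the first physical CPU into four logical devices, replacing any
# ConfigProto(device_count={'CPU': 4}) usage.
cpus = tf.config.experimental.list_physical_devices('CPU')
tf.config.experimental.set_virtual_device_configuration(
    cpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration() for _ in range(4)])
# After this, device strings '/cpu:0' through '/cpu:3' resolve in both graph
# and eager code without any per-test ConfigProto.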