From 6e9c8371fdf2db6015884ac602c859362d23df0e Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Mon, 8 Jul 2019 10:06:26 -0700 Subject: [PATCH] Migrate function tests to use tf.config APIs Using the tf.config APIs allow the test to run in v2 mode. Additionally fixed various lint issues. PiperOrigin-RevId: 257002615 --- .../python/eager/function_gradients_test.py | 18 +++++-- tensorflow/python/eager/function_test.py | 52 ++++++++++--------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/tensorflow/python/eager/function_gradients_test.py b/tensorflow/python/eager/function_gradients_test.py index 09227c3a02e..3d9f22b7460 100644 --- a/tensorflow/python/eager/function_gradients_test.py +++ b/tensorflow/python/eager/function_gradients_test.py @@ -19,11 +19,11 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -49,6 +49,17 @@ _COS_DERIVATIVES = [math_ops.cos, class FunctionGradientsTest(test.TestCase, parameterized.TestCase): + def setUp(self): + super(FunctionGradientsTest, self).setUp() + cpus = config.list_physical_devices('CPU') + # Set 4 virtual CPUs + config.set_virtual_device_configuration(cpus[0], [ + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration() + ]) + def testGraphModeWithGradients(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') @@ -215,7 +226,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase): self.assertAllClose(-math_ops.sin(x), gg) def testSymGradGatherNd(self): - with ops.Graph().as_default(), self.cached_session() as sess: + with ops.Graph().as_default(), self.cached_session(): @def_function.function def f(x): @@ -897,6 +908,5 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': - ops.enable_eager_execution( - config=config_pb2.ConfigProto(device_count={'CPU': 4})) + ops.enable_eager_execution() test.main() diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 9b9d9e825df..5baafe2a3cb 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import collections import functools import itertools -from multiprocessing.pool import ThreadPool +import multiprocessing.pool import sys import weakref @@ -31,10 +31,12 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras from tensorflow.python.autograph.core import ag_ctx +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -51,7 +53,6 @@ from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.layers import convolutional -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import clip_ops @@ -118,6 +119,17 @@ def _example_indexed_slices_without_dense_shape(): class FunctionTest(test.TestCase, parameterized.TestCase): + def setUp(self): + super(FunctionTest, self).setUp() + cpus = config.list_physical_devices('CPU') + # Set 4 virtual CPUs + config.set_virtual_device_configuration(cpus[0], [ + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration() + ]) + def testBasic(self): matmul = def_function.function(math_ops.matmul) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -410,7 +422,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def stateless(x): return math_ops.multiply(2.0, x) - pool = ThreadPool() + pool = multiprocessing.pool.ThreadPool() inputs = [constant_op.constant(1.0 * x) for x in range(100)] outputs = [float(out) for out in pool.map(stateless, inputs)] expected = [float(2.0 * x) for x in inputs] @@ -423,12 +435,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): del x return math_ops.multiply(2.0, 2.0) - pool = ThreadPool() + pool = multiprocessing.pool.ThreadPool() # `pool.map` below instantiates 100 functions, one for each object. - outputs = [ - float(out) - for out in pool.map(stateless, [object() for _ in range(100)]) - ] + objects = [object() for _ in range(100)] + outputs = [float(out) for out in pool.map(stateless, objects)] expected = [4.0] * 100 self.assertSequenceEqual(outputs, expected) @@ -440,7 +450,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def stateful(x): v.assign(x) - pool = ThreadPool() + pool = multiprocessing.pool.ThreadPool() inputs = [constant_op.constant(0.0)] * 100 pool.map(stateful, inputs) self.assertEqual(float(v.read_value()), 0.0) @@ -454,7 +464,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): del x return v.assign(0.0) - pool = ThreadPool() + pool = multiprocessing.pool.ThreadPool() # `pool.map` below instantiates 100 functions, one for each object. pool.map(stateful, [object() for _ in range(100)]) self.assertEqual(float(v.read_value()), 0.0) @@ -998,7 +1008,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): for (input_component, output_component) in zip(input_flat, output_flat): self.assertAllEqual(input_component, output_component) - @test_util.run_gpu_only def testFunctionOnDevice(self): x = constant_op.constant([1.]).gpu() @@ -1250,11 +1259,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): # `Function` --(instancemethod on `MiniModel`)--> `MiniModel` del model.call - # Note: The ConfigProto below unfortunately only configures graph - # construction. Eager's configuration is controlled in `__main__`. - @test_util.run_in_graph_and_eager_modes( - config=config_pb2.ConfigProto(device_count={'CPU': 4})) - @test_util.run_v1_only('b/120545219') + @test_util.run_in_graph_and_eager_modes def testDeviceAnnotationsRespected(self): def multi_device_fn(): @@ -1291,9 +1296,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) - @test_util.run_in_graph_and_eager_modes( - config=config_pb2.ConfigProto(device_count={'CPU': 2})) - @test_util.run_v1_only('b/120545219') + @test_util.run_in_graph_and_eager_modes def testCallingGraphFunctionOnDifferentDevice(self): def func(): @@ -1441,7 +1444,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertLen(defined._function_cache.arg_relaxed_shapes, 1) relaxed_shapes = ( list(defined._function_cache.arg_relaxed_shapes.values())[0]) - self.assertEqual(len(relaxed_shapes), 1) + self.assertLen(relaxed_shapes, 1) relaxed_shape = relaxed_shapes[0] # pylint: enable=protected-access self.assertEqual(relaxed_shape.rank, 1) @@ -1672,7 +1675,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined(array_ops.ones([2, 1])) # Wrong number of arguments. - with self.assertRaisesRegexp(TypeError, 'Received 2 argument\(s\)'): + with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'): defined(array_ops.ones([2]), array_ops.ones([2])) with self.assertRaisesRegexp(ValueError, 'Structure of Python function inputs.*'): @@ -2203,10 +2206,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): rewrites.min_graph_nodes = -1 graph_options = config_pb2.GraphOptions( rewrite_options=rewrites, build_cost_model=1) - config = config_pb2.ConfigProto(graph_options=graph_options) + config_proto = config_pb2.ConfigProto(graph_options=graph_options) with context.graph_mode(), self.cached_session( - config=config, graph=ops.Graph(), use_gpu=True): + config=config_proto, graph=ops.Graph(), use_gpu=True): @function.defun_with_attributes( attributes={ @@ -3180,6 +3183,5 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': - ops.enable_eager_execution( - config=config_pb2.ConfigProto(device_count={'CPU': 4})) + ops.enable_eager_execution() test.main()