Migrate function tests to use tf.config APIs
Using the tf.config APIs allow the test to run in v2 mode. Additionally fixed various lint issues. PiperOrigin-RevId: 257002615
This commit is contained in:
parent
48753b6b1e
commit
6e9c8371fd
@ -19,11 +19,11 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -49,6 +49,17 @@ _COS_DERIVATIVES = [math_ops.cos,
|
||||
|
||||
class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(FunctionGradientsTest, self).setUp()
|
||||
cpus = config.list_physical_devices('CPU')
|
||||
# Set 4 virtual CPUs
|
||||
config.set_virtual_device_configuration(cpus[0], [
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration()
|
||||
])
|
||||
|
||||
def testGraphModeWithGradients(self):
|
||||
v = resource_variable_ops.ResourceVariable(1.0, name='v')
|
||||
|
||||
@ -215,7 +226,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(-math_ops.sin(x), gg)
|
||||
|
||||
def testSymGradGatherNd(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
@ -897,6 +908,5 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution(
|
||||
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
import functools
|
||||
import itertools
|
||||
from multiprocessing.pool import ThreadPool
|
||||
import multiprocessing.pool
|
||||
import sys
|
||||
import weakref
|
||||
|
||||
@ -31,10 +31,12 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -51,7 +53,6 @@ from tensorflow.python.keras.engine import training as keras_training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.layers import convolutional
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
@ -118,6 +119,17 @@ def _example_indexed_slices_without_dense_shape():
|
||||
|
||||
class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(FunctionTest, self).setUp()
|
||||
cpus = config.list_physical_devices('CPU')
|
||||
# Set 4 virtual CPUs
|
||||
config.set_virtual_device_configuration(cpus[0], [
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration()
|
||||
])
|
||||
|
||||
def testBasic(self):
|
||||
matmul = def_function.function(math_ops.matmul)
|
||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||
@ -410,7 +422,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
def stateless(x):
|
||||
return math_ops.multiply(2.0, x)
|
||||
|
||||
pool = ThreadPool()
|
||||
pool = multiprocessing.pool.ThreadPool()
|
||||
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
|
||||
outputs = [float(out) for out in pool.map(stateless, inputs)]
|
||||
expected = [float(2.0 * x) for x in inputs]
|
||||
@ -423,12 +435,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
del x
|
||||
return math_ops.multiply(2.0, 2.0)
|
||||
|
||||
pool = ThreadPool()
|
||||
pool = multiprocessing.pool.ThreadPool()
|
||||
# `pool.map` below instantiates 100 functions, one for each object.
|
||||
outputs = [
|
||||
float(out)
|
||||
for out in pool.map(stateless, [object() for _ in range(100)])
|
||||
]
|
||||
objects = [object() for _ in range(100)]
|
||||
outputs = [float(out) for out in pool.map(stateless, objects)]
|
||||
expected = [4.0] * 100
|
||||
self.assertSequenceEqual(outputs, expected)
|
||||
|
||||
@ -440,7 +450,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
def stateful(x):
|
||||
v.assign(x)
|
||||
|
||||
pool = ThreadPool()
|
||||
pool = multiprocessing.pool.ThreadPool()
|
||||
inputs = [constant_op.constant(0.0)] * 100
|
||||
pool.map(stateful, inputs)
|
||||
self.assertEqual(float(v.read_value()), 0.0)
|
||||
@ -454,7 +464,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
del x
|
||||
return v.assign(0.0)
|
||||
|
||||
pool = ThreadPool()
|
||||
pool = multiprocessing.pool.ThreadPool()
|
||||
# `pool.map` below instantiates 100 functions, one for each object.
|
||||
pool.map(stateful, [object() for _ in range(100)])
|
||||
self.assertEqual(float(v.read_value()), 0.0)
|
||||
@ -998,7 +1008,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
for (input_component, output_component) in zip(input_flat, output_flat):
|
||||
self.assertAllEqual(input_component, output_component)
|
||||
|
||||
|
||||
@test_util.run_gpu_only
|
||||
def testFunctionOnDevice(self):
|
||||
x = constant_op.constant([1.]).gpu()
|
||||
@ -1250,11 +1259,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
# `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
|
||||
del model.call
|
||||
|
||||
# Note: The ConfigProto below unfortunately only configures graph
|
||||
# construction. Eager's configuration is controlled in `__main__`.
|
||||
@test_util.run_in_graph_and_eager_modes(
|
||||
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDeviceAnnotationsRespected(self):
|
||||
|
||||
def multi_device_fn():
|
||||
@ -1291,9 +1296,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
|
||||
self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(
|
||||
config=config_pb2.ConfigProto(device_count={'CPU': 2}))
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCallingGraphFunctionOnDifferentDevice(self):
|
||||
|
||||
def func():
|
||||
@ -1441,7 +1444,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
|
||||
relaxed_shapes = (
|
||||
list(defined._function_cache.arg_relaxed_shapes.values())[0])
|
||||
self.assertEqual(len(relaxed_shapes), 1)
|
||||
self.assertLen(relaxed_shapes, 1)
|
||||
relaxed_shape = relaxed_shapes[0]
|
||||
# pylint: enable=protected-access
|
||||
self.assertEqual(relaxed_shape.rank, 1)
|
||||
@ -1672,7 +1675,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
defined(array_ops.ones([2, 1]))
|
||||
|
||||
# Wrong number of arguments.
|
||||
with self.assertRaisesRegexp(TypeError, 'Received 2 argument\(s\)'):
|
||||
with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
|
||||
defined(array_ops.ones([2]), array_ops.ones([2]))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Structure of Python function inputs.*'):
|
||||
@ -2203,10 +2206,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
rewrites.min_graph_nodes = -1
|
||||
graph_options = config_pb2.GraphOptions(
|
||||
rewrite_options=rewrites, build_cost_model=1)
|
||||
config = config_pb2.ConfigProto(graph_options=graph_options)
|
||||
config_proto = config_pb2.ConfigProto(graph_options=graph_options)
|
||||
|
||||
with context.graph_mode(), self.cached_session(
|
||||
config=config, graph=ops.Graph(), use_gpu=True):
|
||||
config=config_proto, graph=ops.Graph(), use_gpu=True):
|
||||
|
||||
@function.defun_with_attributes(
|
||||
attributes={
|
||||
@ -3180,6 +3183,5 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution(
|
||||
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user