Migrate function tests to use tf.config APIs
Using the tf.config APIs allow the test to run in v2 mode. Additionally fixed various lint issues. PiperOrigin-RevId: 257002615
This commit is contained in:
parent
48753b6b1e
commit
6e9c8371fd
@ -19,11 +19,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -49,6 +49,17 @@ _COS_DERIVATIVES = [math_ops.cos,
|
|||||||
|
|
||||||
class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(FunctionGradientsTest, self).setUp()
|
||||||
|
cpus = config.list_physical_devices('CPU')
|
||||||
|
# Set 4 virtual CPUs
|
||||||
|
config.set_virtual_device_configuration(cpus[0], [
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration()
|
||||||
|
])
|
||||||
|
|
||||||
def testGraphModeWithGradients(self):
|
def testGraphModeWithGradients(self):
|
||||||
v = resource_variable_ops.ResourceVariable(1.0, name='v')
|
v = resource_variable_ops.ResourceVariable(1.0, name='v')
|
||||||
|
|
||||||
@ -215,7 +226,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(-math_ops.sin(x), gg)
|
self.assertAllClose(-math_ops.sin(x), gg)
|
||||||
|
|
||||||
def testSymGradGatherNd(self):
|
def testSymGradGatherNd(self):
|
||||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -897,6 +908,5 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution(
|
ops.enable_eager_execution()
|
||||||
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
|
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
from multiprocessing.pool import ThreadPool
|
import multiprocessing.pool
|
||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
@ -31,10 +31,12 @@ from tensorflow.core.protobuf import config_pb2
|
|||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.autograph.core import ag_ctx
|
from tensorflow.python.autograph.core import ag_ctx
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -51,7 +53,6 @@ from tensorflow.python.keras.engine import training as keras_training
|
|||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.optimizer_v2 import adam
|
from tensorflow.python.keras.optimizer_v2 import adam
|
||||||
from tensorflow.python.layers import convolutional
|
from tensorflow.python.layers import convolutional
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import clip_ops
|
from tensorflow.python.ops import clip_ops
|
||||||
@ -118,6 +119,17 @@ def _example_indexed_slices_without_dense_shape():
|
|||||||
|
|
||||||
class FunctionTest(test.TestCase, parameterized.TestCase):
|
class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(FunctionTest, self).setUp()
|
||||||
|
cpus = config.list_physical_devices('CPU')
|
||||||
|
# Set 4 virtual CPUs
|
||||||
|
config.set_virtual_device_configuration(cpus[0], [
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration()
|
||||||
|
])
|
||||||
|
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
matmul = def_function.function(math_ops.matmul)
|
matmul = def_function.function(math_ops.matmul)
|
||||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
@ -410,7 +422,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def stateless(x):
|
def stateless(x):
|
||||||
return math_ops.multiply(2.0, x)
|
return math_ops.multiply(2.0, x)
|
||||||
|
|
||||||
pool = ThreadPool()
|
pool = multiprocessing.pool.ThreadPool()
|
||||||
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
|
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
|
||||||
outputs = [float(out) for out in pool.map(stateless, inputs)]
|
outputs = [float(out) for out in pool.map(stateless, inputs)]
|
||||||
expected = [float(2.0 * x) for x in inputs]
|
expected = [float(2.0 * x) for x in inputs]
|
||||||
@ -423,12 +435,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
del x
|
del x
|
||||||
return math_ops.multiply(2.0, 2.0)
|
return math_ops.multiply(2.0, 2.0)
|
||||||
|
|
||||||
pool = ThreadPool()
|
pool = multiprocessing.pool.ThreadPool()
|
||||||
# `pool.map` below instantiates 100 functions, one for each object.
|
# `pool.map` below instantiates 100 functions, one for each object.
|
||||||
outputs = [
|
objects = [object() for _ in range(100)]
|
||||||
float(out)
|
outputs = [float(out) for out in pool.map(stateless, objects)]
|
||||||
for out in pool.map(stateless, [object() for _ in range(100)])
|
|
||||||
]
|
|
||||||
expected = [4.0] * 100
|
expected = [4.0] * 100
|
||||||
self.assertSequenceEqual(outputs, expected)
|
self.assertSequenceEqual(outputs, expected)
|
||||||
|
|
||||||
@ -440,7 +450,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def stateful(x):
|
def stateful(x):
|
||||||
v.assign(x)
|
v.assign(x)
|
||||||
|
|
||||||
pool = ThreadPool()
|
pool = multiprocessing.pool.ThreadPool()
|
||||||
inputs = [constant_op.constant(0.0)] * 100
|
inputs = [constant_op.constant(0.0)] * 100
|
||||||
pool.map(stateful, inputs)
|
pool.map(stateful, inputs)
|
||||||
self.assertEqual(float(v.read_value()), 0.0)
|
self.assertEqual(float(v.read_value()), 0.0)
|
||||||
@ -454,7 +464,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
del x
|
del x
|
||||||
return v.assign(0.0)
|
return v.assign(0.0)
|
||||||
|
|
||||||
pool = ThreadPool()
|
pool = multiprocessing.pool.ThreadPool()
|
||||||
# `pool.map` below instantiates 100 functions, one for each object.
|
# `pool.map` below instantiates 100 functions, one for each object.
|
||||||
pool.map(stateful, [object() for _ in range(100)])
|
pool.map(stateful, [object() for _ in range(100)])
|
||||||
self.assertEqual(float(v.read_value()), 0.0)
|
self.assertEqual(float(v.read_value()), 0.0)
|
||||||
@ -998,7 +1008,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
for (input_component, output_component) in zip(input_flat, output_flat):
|
for (input_component, output_component) in zip(input_flat, output_flat):
|
||||||
self.assertAllEqual(input_component, output_component)
|
self.assertAllEqual(input_component, output_component)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
def testFunctionOnDevice(self):
|
def testFunctionOnDevice(self):
|
||||||
x = constant_op.constant([1.]).gpu()
|
x = constant_op.constant([1.]).gpu()
|
||||||
@ -1250,11 +1259,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
# `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
|
# `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
|
||||||
del model.call
|
del model.call
|
||||||
|
|
||||||
# Note: The ConfigProto below unfortunately only configures graph
|
@test_util.run_in_graph_and_eager_modes
|
||||||
# construction. Eager's configuration is controlled in `__main__`.
|
|
||||||
@test_util.run_in_graph_and_eager_modes(
|
|
||||||
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testDeviceAnnotationsRespected(self):
|
def testDeviceAnnotationsRespected(self):
|
||||||
|
|
||||||
def multi_device_fn():
|
def multi_device_fn():
|
||||||
@ -1291,9 +1296,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
|
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
|
||||||
self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
|
self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes(
|
@test_util.run_in_graph_and_eager_modes
|
||||||
config=config_pb2.ConfigProto(device_count={'CPU': 2}))
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testCallingGraphFunctionOnDifferentDevice(self):
|
def testCallingGraphFunctionOnDifferentDevice(self):
|
||||||
|
|
||||||
def func():
|
def func():
|
||||||
@ -1441,7 +1444,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
|
self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
|
||||||
relaxed_shapes = (
|
relaxed_shapes = (
|
||||||
list(defined._function_cache.arg_relaxed_shapes.values())[0])
|
list(defined._function_cache.arg_relaxed_shapes.values())[0])
|
||||||
self.assertEqual(len(relaxed_shapes), 1)
|
self.assertLen(relaxed_shapes, 1)
|
||||||
relaxed_shape = relaxed_shapes[0]
|
relaxed_shape = relaxed_shapes[0]
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertEqual(relaxed_shape.rank, 1)
|
self.assertEqual(relaxed_shape.rank, 1)
|
||||||
@ -1672,7 +1675,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
defined(array_ops.ones([2, 1]))
|
defined(array_ops.ones([2, 1]))
|
||||||
|
|
||||||
# Wrong number of arguments.
|
# Wrong number of arguments.
|
||||||
with self.assertRaisesRegexp(TypeError, 'Received 2 argument\(s\)'):
|
with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
|
||||||
defined(array_ops.ones([2]), array_ops.ones([2]))
|
defined(array_ops.ones([2]), array_ops.ones([2]))
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
'Structure of Python function inputs.*'):
|
'Structure of Python function inputs.*'):
|
||||||
@ -2203,10 +2206,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
rewrites.min_graph_nodes = -1
|
rewrites.min_graph_nodes = -1
|
||||||
graph_options = config_pb2.GraphOptions(
|
graph_options = config_pb2.GraphOptions(
|
||||||
rewrite_options=rewrites, build_cost_model=1)
|
rewrite_options=rewrites, build_cost_model=1)
|
||||||
config = config_pb2.ConfigProto(graph_options=graph_options)
|
config_proto = config_pb2.ConfigProto(graph_options=graph_options)
|
||||||
|
|
||||||
with context.graph_mode(), self.cached_session(
|
with context.graph_mode(), self.cached_session(
|
||||||
config=config, graph=ops.Graph(), use_gpu=True):
|
config=config_proto, graph=ops.Graph(), use_gpu=True):
|
||||||
|
|
||||||
@function.defun_with_attributes(
|
@function.defun_with_attributes(
|
||||||
attributes={
|
attributes={
|
||||||
@ -3180,6 +3183,5 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution(
|
ops.enable_eager_execution()
|
||||||
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
|
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user