Move Keras-related TF function tests to an integration test.

PiperOrigin-RevId: 304803957
Change-Id: Ia190726a64701dca4d8dfe9aabb5f7f80faf78e6
Scott Zhu 2020-04-04 10:54:03 -07:00 committed by TensorFlower Gardener
parent f4b2e3499b
commit 305a6e7251
3 changed files with 172 additions and 203 deletions
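
The substance of the change: these tests previously imported TensorFlow-internal Keras modules (tensorflow.python.keras.engine, tensorflow.python.keras.layers, tensorflow.python.keras.optimizer_v2), apparently so that the eager unit tests no longer depend on Keras internals; the relocated integration test exercises the same behavior through the public tf.* API only. A minimal sketch of the rewrite pattern — the class names here are illustrative and do not appear in the diff:

    # Before: unit test reaching into internal Keras modules.
    from tensorflow.python.keras.engine import training
    from tensorflow.python.keras.layers import core

    class InternalStyleModel(training.Model):  # illustrative name
      def __init__(self):
        super(InternalStyleModel, self).__init__()
        self.dense = core.Dense(1)

    # After: integration test using only the public API.
    import tensorflow as tf

    class PublicStyleModel(tf.keras.Model):  # illustrative name
      def __init__(self):
        super(PublicStyleModel, self).__init__()
        self.dense = tf.keras.layers.Dense(1)

The three file diffs below perform exactly this substitution: the first two delete the Keras-dependent tests from the eager unit tests, and the third re-adds them using public symbols.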

View File

@@ -26,7 +26,6 @@ from absl.testing import parameterized
from six.moves import range
from tensorflow.python.autograph.core import converter
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
@@ -35,8 +34,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -46,26 +43,6 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
class _ModelWithOptimizer(training.Model):

  def __init__(self):
    super(_ModelWithOptimizer, self).__init__()
    self.dense = core.Dense(1)
    self.optimizer = adam.AdamOptimizer(0.01)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
                       tensor_spec.TensorSpec([None], dtypes.float32)))
  def call(self, x, y):
    with backprop.GradientTape() as tape:
      loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
    trainable_variables = self.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    return {'loss': loss}

class _HasDecoratedMethod(object):
@@ -74,6 +51,7 @@ class _HasDecoratedMethod(object):
  def f(self, x):
    return x * 3.


class DefFunctionTest(test.TestCase, parameterized.TestCase):

  def testNoVariables(self):
@@ -311,12 +289,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
        input_signature=[tensor_spec.TensorSpec((), dtypes.int32)])
    self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())

  def test_optimizer(self):
    x = constant_op.constant([[3., 4.]])
    y = constant_op.constant([2.])
    model = _ModelWithOptimizer()
    model(x, y)

  def test_concrete_function_from_signature(self):

    @def_function.function(

View File

@@ -30,9 +30,7 @@ import numpy
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
@@ -52,9 +50,6 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -68,7 +63,6 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -94,28 +88,6 @@ def total_function_cache(defined):
# pylint: enable=protected-access
class MiniModel(keras_training.Model):
  """Minimal model for MNIST.

  Useful for testing and debugging on slow TPU simulators.
  """

  def __init__(self):
    super(MiniModel, self).__init__(name='')
    self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
                                 bias_initializer='ones')

  def call(self, inputs, training=True):
    return self.fc(inputs)


class DefunnedMiniModel(MiniModel):

  @function.defun
  def call(self, inputs, training=True):
    return super(DefunnedMiniModel, self).call(inputs, training=training)


def _example_indexed_slices_with_dense_shape():
  return indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]), constant_op.constant([0, 1]),
@@ -439,26 +411,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
      self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

  def testFunctionRelaxationLosesInnerDimWithKerasLayer(self):
    layer = keras.layers.Dense(1)
    fn = def_function.function(experimental_relax_shapes=True)(layer)

    with self.captureWritesToStream(sys.stderr) as printed:
      fn(array_ops.ones((3, 2)))
      self.assertNotIn('ValueError', printed.contents())

    with self.captureWritesToStream(sys.stderr) as printed:
      # Use batch size 2 to trigger a second cache miss on the shape.
      fn(array_ops.ones((2, 2)))
      self.assertNotIn('ValueError', printed.contents())

    # Shape relaxation passes TensorShape([None, None]), which causes the
    # layer's matmul to fail due to incompatible dims: what would have been a
    # graph-build-time error (the layer complaining about the inner dim
    # being 4) now surfaces at run time.
    with self.captureWritesToStream(sys.stderr) as printed:
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   r'Matrix size-incompatible'):
        fn(array_ops.ones((3, 4)))

  def testNestedShapeFunctionRelaxation(self):
    got_shape = [None]
@@ -1513,24 +1465,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
    has_device.f()
    self.assertIn('CPU', has_device.v.device)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testDefunKerasModelCall(self):
    model = MiniModel()
    model.call = function.defun(model.call)

    x = array_ops.ones([1, 2])
    y = model(x)

    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())

    self.assertAllEqual([[3.0]], self.evaluate(y))

    # Break the reference cycle between the MiniModel and the defun:
    # `MiniModel` --(through its `call` method)--> `Function`
    # `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
    del model.call

  @test_util.run_in_graph_and_eager_modes
  def testDeviceAnnotationsRespected(self):
@@ -2712,54 +2646,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
    self.assertLen(total_function_cache(defined),
                   3 if ops.Tensor._USE_EQUALITY else 5)

  def testDecoratedMethod(self):
    m = DefunnedMiniModel()
    instance_call_one = m.call(array_ops.ones([1, 2]), training=True)
    instance_call_two = m.call(
        inputs=array_ops.ones([1, 2]), training=True)
    class_call = DefunnedMiniModel.call(m, array_ops.ones([1, 2]),
                                        training=True)
    self.assertAllEqual(instance_call_one, instance_call_two)
    self.assertAllEqual(instance_call_one, class_call)

  def testDecoratedMethodUniqueFunctionPerInstance(self):
    m = DefunnedMiniModel()
    n = DefunnedMiniModel()

    class_method_one = DefunnedMiniModel.call
    class_method_two = DefunnedMiniModel.call
    m_method_one = m.call
    m_method_two = m.call
    n_method_one = n.call
    n_method_two = n.call

    self.assertEqual(class_method_one, class_method_two)
    self.assertEqual(m_method_one, m_method_two)
    self.assertEqual(n_method_one, n_method_two)
    self.assertNotEqual(m.call, n.call)

  def testDecoratedMethodInspect(self):

    class DefunnedMiniModel(object):

      @function.defun
      def call(self, inputs, training=True):
        pass

    m = DefunnedMiniModel()
    fullargspec = tf_inspect.getfullargspec(m.call)
    self.assertIn('training', fullargspec.args)

  def testDecoratedMethodGetConcreteFunction(self):
    m = DefunnedMiniModel()
    instance_call_one = m.call.get_concrete_function(
        array_ops.ones([1, 2]), training=False)
    instance_call_two = m.call.get_concrete_function(
        inputs=array_ops.ones([1, 2]), training=False)
    self.assertAllEqual(instance_call_one(array_ops.ones([1, 2])),
                        instance_call_two(array_ops.ones([1, 2])))

    # Also make sure get_concrete_function works on the class method.
    DefunnedMiniModel.call.get_concrete_function(
        m, array_ops.ones([1, 2]), training=False)
    DefunnedMiniModel.call.get_concrete_function(
        m, inputs=array_ops.ones([1, 2]), training=True)

  def testFunctionModifiesInputList(self):
    # Tests `list` methods that do in-place modification, except `list.sort`,
    # since it cannot even be defunned in the first place.
@@ -2915,21 +2813,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
    modify_same_flat(nested_input)

  def testDecoratedMethodVariableCleanup(self):
    m = DefunnedMiniModel()
    m(array_ops.ones([1, 2]))
    variable_refs = list({v.ref() for v in m.variables})
    self.assertLen(variable_refs, 2)
    del m

    # Verify that the variables are only referenced from variable_refs.
    # We expect the reference count to be 1, but `sys.getrefcount` reports one
    # higher because a temporary reference is created by the call to
    # sys.getrefcount() itself; hence check that the number returned is 2.
    # https://docs.python.org/3/library/sys.html#sys.getrefcount
    self.assertEqual(sys.getrefcount(variable_refs[0].deref()), 2)
    self.assertEqual(sys.getrefcount(variable_refs[1].deref()), 2)

  def testExecutorType(self):

    @function.defun
    def add_five(x):
@@ -3592,56 +3475,6 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
    self.assertEqual((v,), tape.watched_variables())

  def testStandardTrainingLoopInFunction(self):
    layer = core.Dense(2)
    dataset = (
        dataset_ops.DatasetV2.from_tensors(
            (array_ops.ones([784]), array_ops.ones([], dtypes.int32)))
        .map(lambda x, y: (x, y))
        .repeat(10)
        .batch(32))
    optimizer = adam.Adam()

    @def_function.function
    def train():
      for x, y in dataset:
        with backprop.GradientTape() as tape:
          out = layer(x)
          loss = math_ops.reduce_mean(
              nn_ops.sparse_softmax_cross_entropy_with_logits(
                  logits=out, labels=y))
        layer_variables = layer.trainable_variables
        gradients = tape.gradient(loss, layer_variables)
        optimizer.apply_gradients(zip(gradients, layer_variables))

    train()

  def testEarlyStoppingTrainingLoopInFunction(self):
    layer = core.Dense(2)
    dataset = (
        dataset_ops.DatasetV2.from_tensors(
            (array_ops.ones([784]), array_ops.ones([], dtypes.int32)))
        .map(lambda x, y: (x, y))
        .repeat(10)
        .batch(32))
    optimizer = adam.Adam()

    @def_function.function
    def train():
      for x, y in dataset:
        with backprop.GradientTape() as tape:
          out = layer(x)
          loss = math_ops.reduce_mean(
              nn_ops.sparse_softmax_cross_entropy_with_logits(
                  logits=out, labels=y))
        layer_variables = layer.trainable_variables
        gradients = tape.gradient(loss, layer_variables)
        optimizer.apply_gradients(zip(gradients, layer_variables))
        if optimizer.iterations > 3:
          break

    train()

  def testDeferredCapture(self):
    value = 1.0

View File

@@ -16,11 +16,145 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys

import tensorflow as tf


class MiniModel(tf.keras.Model):
  """Minimal model for MNIST.

  Useful for testing and debugging on slow TPU simulators.
  """

  def __init__(self):
    super(MiniModel, self).__init__(name='')
    self.fc = tf.keras.layers.Dense(1, name='fc', kernel_initializer='ones',
                                    bias_initializer='ones')

  def call(self, inputs, training=True):
    return self.fc(inputs)
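

# `DefunnedMiniModel` overrides `call` with a `tf.function`-wrapped version;
# the testDecoratedMethod* cases below exercise tracing of bound methods
# rather than free functions.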
class DefunnedMiniModel(MiniModel):

  @tf.function
  def call(self, inputs, training=True):
    return super(DefunnedMiniModel, self).call(inputs, training=training)


class ModelWithOptimizer(tf.keras.Model):

  def __init__(self):
    super(ModelWithOptimizer, self).__init__()
    self.dense = tf.keras.layers.Dense(1)
    self.optimizer = tf.keras.optimizers.Adam(0.01)
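
  # The explicit input_signature below pins the traced input specs, so `call`
  # is traced once for the declared [None, 2] / [None] float32 inputs rather
  # than retraced per concrete input shape.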
  @tf.function(
      input_signature=(tf.TensorSpec([None, 2], tf.float32),
                       tf.TensorSpec([None], tf.float32)))
  def call(self, x, y):
    with tf.GradientTape() as tape:
      loss = tf.math.reduce_mean((self.dense(x) - y) ** 2.)
    trainable_variables = self.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    return {'loss': loss}


class FunctionTest(tf.test.TestCase):

  def testFunctionRelaxationLosesInnerDimWithKerasLayer(self):
    layer = tf.keras.layers.Dense(1)
    fn = tf.function(experimental_relax_shapes=True)(layer)

    with self.captureWritesToStream(sys.stderr) as printed:
      fn(tf.ones((3, 2)))
      self.assertNotIn('ValueError', printed.contents())

    with self.captureWritesToStream(sys.stderr) as printed:
      # Use batch size 2 to trigger a second cache miss on the shape.
      fn(tf.ones((2, 2)))
      self.assertNotIn('ValueError', printed.contents())

    # Shape relaxation passes TensorShape([None, None]), which causes the
    # layer's matmul to fail due to incompatible dims: what would have been a
    # graph-build-time error (the layer complaining about the inner dim
    # being 4) now surfaces at run time.
    with self.captureWritesToStream(sys.stderr) as printed:
      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                   r'Matrix size-incompatible'):
        fn(tf.ones((3, 4)))

  def testDefunKerasModelCall(self):
    model = MiniModel()
    model.call = tf.function(model.call)

    x = tf.ones([1, 2])
    y = model(x)  # pylint:disable=not-callable
    self.assertAllEqual([[3.0]], self.evaluate(y))

    # Break the reference cycle between the MiniModel and the defun:
    # `MiniModel` --(through its `call` method)--> `Function`
    # `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
    del model.call
  def testDecoratedMethod(self):
    m = DefunnedMiniModel()
    instance_call_one = m.call(tf.ones([1, 2]), training=True)
    instance_call_two = m.call(
        inputs=tf.ones([1, 2]), training=True)
    class_call = DefunnedMiniModel.call(m, tf.ones([1, 2]), training=True)
    self.assertAllEqual(instance_call_one, instance_call_two)
    self.assertAllEqual(instance_call_one, class_call)

  def testDecoratedMethodUniqueFunctionPerInstance(self):
    m = DefunnedMiniModel()
    n = DefunnedMiniModel()

    class_method_one = DefunnedMiniModel.call
    class_method_two = DefunnedMiniModel.call
    m_method_one = m.call
    m_method_two = m.call
    n_method_one = n.call
    n_method_two = n.call

    self.assertEqual(class_method_one, class_method_two)
    self.assertEqual(m_method_one, m_method_two)
    self.assertEqual(n_method_one, n_method_two)
    self.assertNotEqual(m.call, n.call)

  def testDecoratedMethodGetConcreteFunction(self):
    m = DefunnedMiniModel()
    instance_call_one = m.call.get_concrete_function(
        tf.ones([1, 2]), training=False)
    instance_call_two = m.call.get_concrete_function(
        inputs=tf.ones([1, 2]), training=False)
    self.assertAllEqual(instance_call_one(tf.ones([1, 2])),
                        instance_call_two(tf.ones([1, 2])))

    # Also make sure get_concrete_function works on the class method.
    DefunnedMiniModel.call.get_concrete_function(
        m, tf.ones([1, 2]), training=False)
    DefunnedMiniModel.call.get_concrete_function(
        m, inputs=tf.ones([1, 2]), training=True)

  def testDecoratedMethodVariableCleanup(self):
    m = DefunnedMiniModel()
    m(tf.ones([1, 2]))  # pylint:disable=not-callable
    variable_refs = list({v.ref() for v in m.variables})
    self.assertLen(variable_refs, 2)
    del m

    # Verify that the variables are only referenced from variable_refs.
    # We expect the reference count to be 1, but `sys.getrefcount` reports one
    # higher because a temporary reference is created by the call to
    # sys.getrefcount() itself; hence check that the number returned is 2.
    # https://docs.python.org/3/library/sys.html#sys.getrefcount
    self.assertEqual(sys.getrefcount(variable_refs[0].deref()), 2)
    self.assertEqual(sys.getrefcount(variable_refs[1].deref()), 2)
  def testStandardTrainingLoopInFunction(self):
    layer = tf.keras.layers.Dense(2)
    dataset = (
@@ -44,6 +178,36 @@ class FunctionTest(tf.test.TestCase):
    train()

  def testEarlyStoppingTrainingLoopInFunction(self):
    layer = tf.keras.layers.Dense(2)
    dataset = (
        tf.data.Dataset.from_tensors((tf.ones([784]), tf.ones([], tf.int32)))
        .map(lambda x, y: (x, y))
        .repeat(10)
        .batch(32))
    optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def train():
      for x, y in dataset:
        with tf.GradientTape() as tape:
          out = layer(x)
          loss = tf.math.reduce_mean(
              tf.nn.sparse_softmax_cross_entropy_with_logits(
                  logits=out, labels=y))
        layer_variables = layer.trainable_variables
        gradients = tape.gradient(loss, layer_variables)
        optimizer.apply_gradients(zip(gradients, layer_variables))
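        # `optimizer.iterations` is incremented by each `apply_gradients`
        # call, so this stops the loop after four steps.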
        if optimizer.iterations > 3:
          break

    train()
  def test_optimizer(self):
    x = tf.constant([[3., 4.]])
    y = tf.constant([2.])
    model = ModelWithOptimizer()
    model(x, y)  # pylint:disable=not-callable


if __name__ == '__main__':
  tf.test.main()