Since variable ids can be used as keys into Python data structures, it can be dangerous to reuse the previous function trace when the shape and dtype are the same. We thus remove the optimization and ensure we always retrace. PiperOrigin-RevId: 317680370 Change-Id: I1dfa3a626074e623b735869aa724138072e2c274
924 lines
28 KiB
Python
924 lines
28 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
import itertools
|
|
import pickle
|
|
import re
|
|
import sys
|
|
import weakref
|
|
|
|
from absl.testing import parameterized
|
|
from six.moves import range
|
|
|
|
from tensorflow.python.autograph.core import converter
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.eager import lift_to_graph
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_spec
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.module import module
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import random_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
def undecorated_function(x):
  """Returns `x` multiplied by 3.0.

  Plain module-level function (no `def_function.function` decorator); it is
  the Python function wrapped by `test_pickle`.
  """
  tripled = x * 3.
  return tripled
|
|
|
|
|
|
class _HasDecoratedMethod(object):
  """Holder whose method `f` is wrapped with `def_function.function`."""

  @def_function.function
  def f(self, x):
    # Scale the input by three.
    tripled = x * 3.
    return tripled
|
|
|
|
|
|
class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|
|
|
def testNoVariables(self):
  """A decorated function that creates no variables doubles its input."""

  @def_function.function
  def fn(x):
    return 2 * x

  doubled = fn(constant_op.constant(4.0))
  self.assertAllEqual(doubled, 8.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testFailIfVariablesAreCreatedMoreThanOnce(self):
  """Calling a function that unconditionally builds a variable must fail."""

  @def_function.function
  def fn(x):
    return variables.Variable(1.0) + x

  # Callable form of assertRaises: invokes fn(1.0) and expects ValueError.
  self.assertRaises(ValueError, fn, 1.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
  """Same failure when the created variables are retained in a Python list."""
  kept_alive = []

  @def_function.function
  def fn(x):
    # Every invocation appends a brand-new variable and keeps a strong
    # reference to it in the enclosing list.
    kept_alive.append(variables.Variable(1.0))
    return kept_alive[-1] + x

  self.assertRaises(ValueError, fn, 1.0)
|
|
|
|
def testRange(self):
  """A `range` object is accepted as an argument."""

  @def_function.function
  def f(unused_x):
    return 1.0

  outcome = f(range(5))
  self.assertAllEqual(outcome, 1.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testCorrectVariableCreation(self):
  """A variable guarded by a create-once check is reused across calls."""

  created = []

  @def_function.function
  def fn(x):
    if not created:
      created.append(variables.Variable(2.0))
    return created[0] * x

  # (input, expected output) pairs, evaluated in order.
  for value, expected in ((1.0, 2.0), (3.0, 6.0)):
    self.assertAllEqual(fn(constant_op.constant(value)), expected)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testFunctionInitializer(self):
  """A variable may take a zero-argument callable as its initial value."""

  created = []

  @def_function.function
  def fn(x):
    if not created:
      created.append(variables.Variable(lambda: 2.0))
    return created[0] * x

  product = fn(constant_op.constant(1.0))
  self.assertAllEqual(product, 2.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testFunctionMultipleVariableInitializer(self):
  """Two callable-initialized variables can be created in a single call."""

  created = []

  @def_function.function
  def fn(x):
    if not created:
      created.append(variables.Variable(lambda: 2.0))
      created.append(variables.Variable(lambda: 5.0))
    return created[0] * x, created[1] * x

  products = fn(constant_op.constant(1.0))
  self.assertAllEqual(products, [2.0, 5.0])
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testFunctionInitializationFunction(self):
  """`get_initialization_function` creates the variable but defers its init."""

  created = []

  @def_function.function
  def fn(x):
    if not created:
      created.append(variables.Variable(2.0))
    return created[0] * x

  init_fn = fn.get_initialization_function(constant_op.constant(1.0))

  # The variable exists after tracing, but has not been initialized yet.
  self.assertLen(created, 1)
  is_initialized = resource_variable_ops.var_is_initialized_op(
      created[0].handle)
  self.assertFalse(is_initialized)

  # Running the returned function performs the initialization.
  init_fn()
  self.assertEqual(created[0].numpy(), 2.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testVariableInitializerNotConstant(self):
  """The initial value may be computed from the function's argument."""

  created = []

  @def_function.function
  def fn(x):
    if not created:
      created.append(variables.Variable(2.0 * x))
    return created[0] * x

  # The variable keeps the value 2.0 * 1.0 from the first call.
  for value, expected in ((1.0, 2.0), (3.0, 6.0)):
    self.assertAllEqual(fn(constant_op.constant(value)), expected)
|
|
|
|
def testLegacyGraphModeVariables(self):
  """Inside an explicit graph, the variable needs the global initializer."""
  with ops.Graph().as_default():
    with self.test_session() as sess:
      created = []

      @def_function.function
      def fn(x):
        if not created:
          created.append(variables.Variable(2.0))
        return created[0] * x

      result = fn(3.0)

      self.evaluate(variables.global_variables_initializer())
      variable_value = sess.run(created[0])
      self.assertAllEqual(variable_value, 2.0)
      self.assertAllEqual(self.evaluate(result), 6.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
  """The initial value may be built from several ops (here 2 + 4 = 6)."""
  with ops.Graph().as_default():
    with self.test_session() as sess:
      created = []

      @def_function.function
      def fn(x):
        if not created:
          two = constant_op.constant(2.0)
          four = two * two
          two_again = math_ops.sqrt(four)
          created.append(variables.Variable(two_again + four))
        return created[0] * x

      result = fn(3.0)

      self.evaluate(variables.global_variables_initializer())
      variable_value = sess.run(created[0])
      self.assertAllEqual(variable_value, 6.0)
      self.assertAllEqual(self.evaluate(result), 18.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testLegacyGraphModeInputDependentInitializerFails(self):
  """In an explicit graph, an input-dependent initial value is unliftable."""
  with ops.Graph().as_default():
    created = []

    @def_function.function
    def fn(x):
      if not created:
        created.append(variables.Variable(2.0 * x))
      return created[0] * x

    expected_error = lift_to_graph.UnliftableError
    with self.assertRaisesRegexp(expected_error, r'transitively.* mul .* x'):
      fn(constant_op.constant(3.0))
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testMethod(self):
  """A decorated method keeps a separate variable on each instance."""

  class MyModel(object):

    def __init__(self):
      self.var = None

    @def_function.function
    def apply(self, x):
      if self.var is None:
        self.var = variables.Variable(2.0)
      return self.var * x

  first_model = MyModel()
  self.assertAllEqual(first_model.apply(3.0), 6.0)
  # Calling twice to exercise that we do not recreate variables.
  first_model.var.assign(3.0)
  self.assertAllEqual(first_model.apply(3.0), 9.0)

  # A second instance gets its own variable, unaffected by the assign above.
  second_model = MyModel()
  self.assertAllEqual(second_model.apply(3.0), 6.0)
|
|
|
|
def test_functools_partial(self):
  """A `functools.partial` with a bound positional argument can be wrapped."""
  add_one = functools.partial(lambda x, y: x + y, 1.)
  wrapped = def_function.function(add_one)
  self.assertAllClose(3., wrapped(constant_op.constant(2.)))
|
|
|
|
@test_util.disable_tfrt('Partial is not supported')
def test_functools_partial_new_default(self):
  """A keyword bound by the partial acts as an overridable default."""

  def f(x=3, y=7):
    return x + y

  bound = functools.partial(f, y=6)
  func = def_function.function(bound)

  # Uses the partial's y=6, then an explicit y=8.
  self.assertEqual(func().numpy(), 9)
  self.assertEqual(func(y=8).numpy(), 11)
|
|
|
|
@test_util.disable_tfrt('Partial is not supported')
def test_functools_partial_keywords(self):
  """Every argument may be supplied as a tensor keyword in the partial."""

  def f(x, y):
    return x + y

  bound = functools.partial(
      f, x=array_ops.zeros([1]), y=array_ops.zeros([1]))
  func = def_function.function(bound)
  self.assertAllEqual(func(), [0.0])
|
|
|
|
@test_util.disable_tfrt('Partial is not supported')
def test_functools_partial_single_positional(self):
  """One positional tensor is bound; the other argument comes at call time."""

  def f(x, y):
    return x + y

  bound = functools.partial(f, constant_op.constant(1))
  func = def_function.function(bound)
  self.assertAllEqual(func(5), 6)
|
|
|
|
@test_util.disable_tfrt('Partial is not supported')
def test_complicated_partial_with_defaults(self):
  """Partial binding one positional and one keyword keeps other defaults."""

  def identity(*args):
    return args

  def dynamic_unroll(core_fn,
                     input_sequence,
                     initial_state,
                     sequence_length=None,
                     parallel_iterations=1,
                     swap_memory=False):
    del core_fn
    # Untouched defaults survive; `swap_memory` is overridden by the partial.
    self.assertIs(None, sequence_length)
    self.assertEqual(1, parallel_iterations)
    self.assertTrue(swap_memory)
    return input_sequence, initial_state

  sequence = random_ops.random_uniform([1, 1, 1])
  state = random_ops.random_uniform([1, 1])

  bound = functools.partial(dynamic_unroll, identity, swap_memory=True)
  func = def_function.function(bound)
  func(sequence, state)
|
|
|
|
def test_unspecified_default_argument(self):
  """An input signature may omit an argument that has a Python default."""
  signature = [tensor_spec.TensorSpec((), dtypes.int32)]
  wrapped = def_function.function(
      lambda x, y=2: x + y, input_signature=signature)
  total = wrapped(constant_op.constant(1))
  self.assertEqual(3, total.numpy())
|
|
|
|
def test_concrete_function_from_signature(self):
  """Concrete functions can be fetched with or without an explicit spec."""

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  def compute(x):
    return 2. * x

  # No arguments: the declared input signature is used.
  from_signature = compute.get_concrete_function()
  self.assertAllClose(1., from_signature(constant_op.constant(0.5)))

  # Passing an explicit spec.
  spec = tensor_spec.TensorSpec(None, dtypes.float32)
  from_spec = compute.get_concrete_function(spec)
  self.assertAllClose(4., from_spec(constant_op.constant(2.)))

  signature_args, _ = from_spec.structured_input_signature
  expected_args = (tensor_spec.TensorSpec(None, dtypes.float32, name='x'),)
  self.assertEqual(signature_args, expected_args)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
@test_util.run_in_graph_and_eager_modes
def test_variable_naming(self):
  """Checks the names of variables created lazily by module methods."""

  class HasVars(module.Module):

    def __init__(self):
      self.x = None
      self.y = None
      self.z = None

    @def_function.function
    def make_x(self):
      if self.x is None:
        self.x = variables.Variable(1., name='v')

    def make_y(self):
      if self.y is None:
        self.y = variables.Variable(1., name='v')

    def make_z(self):
      if self.z is None:
        with ops.name_scope('z_scope', skip_on_eager=False):
          self.z = variables.Variable(1., name='z')

  holder = HasVars()
  # Only make_x is decorated; make_y and make_z run as plain Python methods.
  holder.make_x()
  holder.make_y()
  holder.make_z()
  self.assertEqual('v:0', holder.x.name)
  self.assertEqual('z_scope/z:0', holder.z.name)
|
|
|
|
def test_concrete_function_keyword_arguments(self):
  """Concrete functions take keywords named after the spec or the parameter."""

  @def_function.function
  def f(x):
    return x

  # A named spec: the keyword and the signature entry use the spec's name.
  named = f.get_concrete_function(
      tensor_spec.TensorSpec(None, dtypes.float32, 'y'))
  named(y=constant_op.constant(3.0))
  signature_args, _ = named.structured_input_signature
  self.assertEqual('y', signature_args[0].name)

  # An unnamed spec: the Python parameter name is used instead.
  unnamed = f.get_concrete_function(
      tensor_spec.TensorSpec(None, dtypes.float32))
  unnamed(x=constant_op.constant(3.0))
  signature_args, _ = unnamed.structured_input_signature
  self.assertEqual('x', signature_args[0].name)

  @def_function.function
  def g(x):
    return x[0]

  # A named spec nested in a list alongside a plain Python value.
  nested = g.get_concrete_function(
      [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2])
  nested(z=constant_op.constant(3.0))
  signature_args, _ = nested.structured_input_signature
  self.assertEqual('z', signature_args[0][0].name)
|
|
|
|
def test_error_inner_capture(self):
  """Appending loop tensors to an outer list raises InaccessibleTensorError."""

  @def_function.function
  def f(inputs):
    num_steps, _ = inputs.shape[:2]
    outputs = []
    for t in math_ops.range(num_steps):
      outputs.append(inputs[t])
    return outputs

  expected_message = 'defined in another function or code block'
  with self.assertRaisesRegexp(errors.InaccessibleTensorError,
                               expected_message):
    f(array_ops.zeros(shape=(8, 42, 3)))
|
|
|
|
@test_util.disable_tfrt('Control flow is not supported')
def testRuntimeErrorNotSticky(self):
  """A failed call does not prevent a later call from succeeding."""

  @def_function.function
  def fail(i):
    control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])

  passing_input = constant_op.constant(0)
  fail(passing_input)  # OK

  with self.assertRaises(errors.InvalidArgumentError):
    fail(constant_op.constant(1))  # InvalidArgument: "ick"

  fail(constant_op.constant(0))  # OK
|
|
|
|
def testUnderscoreName(self):
  """A parameter named `_` is usable inside the decorated function."""

  @def_function.function
  def f(_):
    return _ + _

  doubled = f(constant_op.constant(1.0))
  self.assertAllEqual(2.0, doubled)
|
|
|
|
def test_serialization_signature_cache(self):
  """Each call signature yields one concrete function for serialization."""

  @def_function.function
  def f(x, y):
    return x, y

  # Two calls that differ in both shape and dtype.
  f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
  f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))

  expected_signatures = {
      (tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'),
       tensor_spec.TensorSpec([1], dtypes.float32, name='y')),
      (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'),
       tensor_spec.TensorSpec([1], dtypes.int32, name='y')),
  }

  seen_signatures = set()
  for concrete in f._list_all_concrete_functions_for_serialization():
    positional, keyword = concrete.structured_input_signature
    seen_signatures.add(positional)
    self.assertEqual({}, keyword)

  self.assertEqual(seen_signatures, expected_signatures)
|
|
|
|
@test_util.assert_no_garbage_created
def testFunctionReferenceCycles(self):
  """Deleting the only reference to a called function frees it immediately."""
  fn = def_function.function(lambda x: 2. * x)
  fn(constant_op.constant(4.0))
  weak_fn = weakref.ref(fn)
  del fn
  # The weak reference must now be dead, i.e. the object was deleted by
  # reference counting alone. That only holds when the function object does
  # not take part in a reference cycle.
  self.assertIsNone(weak_fn())
|
|
|
|
@test_util.assert_no_garbage_created
def testMethodReferenceCycles(self):
  """Deleting the owning object also frees its decorated method."""
  has_decorated_method = _HasDecoratedMethod()
  has_decorated_method.f(constant_op.constant(5.))
  weak_fn = weakref.ref(has_decorated_method.f)
  del has_decorated_method
  # The weak reference must now be dead, i.e. the object was deleted by
  # reference counting alone. That only holds when the function object does
  # not take part in a reference cycle.
  self.assertIsNone(weak_fn())
|
|
|
|
@test_util.assert_no_new_pyobjects_executing_eagerly
def testErrorMessageWhenGraphTensorIsPassedToEager(self):
  """Using a function-graph tensor under `init_scope` gives a clear error."""

  @def_function.function
  def failing_function():
    a = constant_op.constant(1.)

    with ops.init_scope():
      _ = a + a

  # DOTALL lets `.*` span the line breaks in the error message.
  message_pattern = re.compile(
      'An op outside of the function.*passed.*Const', re.DOTALL)
  with self.assertRaisesRegexp(TypeError, message_pattern):
    failing_function()
|
|
|
|
def testNonUniqueNamesGetConcreteFunction(self):
  """Flattened tuple elements are addressable as x, x_1, x_2 or by position."""

  @def_function.function
  def non_unique_arg_names(x, **kwargs):
    a, b, c = x
    d = kwargs['d']
    return a + b + c + d

  concrete = non_unique_arg_names.get_concrete_function(
      (tensor_spec.TensorSpec(None, dtypes.float32),
       tensor_spec.TensorSpec(None, dtypes.float32),
       tensor_spec.TensorSpec(None, dtypes.float32)),
      d=tensor_spec.TensorSpec(None, dtypes.float32))

  # Keyword call using the uniquified names.
  by_keyword = concrete(x=constant_op.constant(1.),
                        x_1=constant_op.constant(2.),
                        x_2=constant_op.constant(3.),
                        d=constant_op.constant(4.))
  self.assertAllClose(10., by_keyword)

  # Purely positional call with the same values.
  by_position = concrete(constant_op.constant(1.),
                         constant_op.constant(2.),
                         constant_op.constant(3.),
                         constant_op.constant(4.))
  self.assertAllClose(10., by_position)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testVariableCreatorScope(self):
  """An active variable creator scope sees variables made inside the call."""
  created_variables = []
  captured_variables = []

  @def_function.function
  def f():
    if not created_variables:
      created_variables.append(variables.Variable(1.))
    return created_variables[0] + 1.

  def capture_creator(next_creator, **kwargs):
    # Delegate creation, then record what was produced.
    new_variable = next_creator(**kwargs)
    captured_variables.append(new_variable)
    return new_variable

  with variable_scope.variable_creator_scope(capture_creator):
    f()
  self.assertEqual(created_variables, captured_variables)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testVarAlreadyInitializedNoClobbering(self):
  """A value assigned under `init_scope` is not overwritten afterwards."""
  holder = []

  @def_function.function
  def add_var(x):
    if not holder:
      v = variables.Variable([1., 2.])
      holder.append(v)
      already_initialized = variables.Variable(3.)
      with ops.init_scope():
        already_initialized.assign(10.)
      holder.append(already_initialized)
    return holder[0] + holder[1] + x

  add_var.get_concrete_function(constant_op.constant(2.))
  # [1, 2] + 10 + 2: the second variable holds 10., not its initial 3.
  total = add_var(constant_op.constant(2.))
  self.assertAllClose([13., 14.], total)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testSameVariableTwice(self):
  """The same variable may be passed for both arguments."""
  v = variables.Variable(1.0)

  @def_function.function
  def add(a, b):
    return a + b

  total = add(v, v)
  self.assertAllEqual(total, 2.0)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testVariableUpdate(self):
  """Each distinct variable argument is doubled and causes a new trace."""
  v1 = variables.Variable(1.0)
  v2 = variables.Variable(2.0)
  v3 = variables.Variable(4, dtype=dtypes.int32)

  trace_count = [0]

  @def_function.function
  def double_variable(x):
    # Runs as Python only while tracing, so this counts traces.
    trace_count[0] += 1
    x.assign_add(x.read_value())

  self.assertEqual(trace_count[0], 0)

  # (variable, expected value after doubling), processed in order.
  cases = ((v1, 2.0), (v2, 4.0), (v3, 8))
  for expected_traces, (variable, doubled) in enumerate(cases, start=1):
    double_variable(variable)
    self.assertEqual(trace_count[0], expected_traces)
    self.assertEqual(self.evaluate(variable), doubled)
|
|
|
|
def testShapeCache(self):
  """Equal specs return the very same concrete function object."""

  @def_function.function
  def func(x):
    return 2 * x

  first = func.get_concrete_function(
      tensor_spec.TensorSpec([None], dtypes.int32))
  second = func.get_concrete_function(
      tensor_spec.TensorSpec([None], dtypes.int32))

  self.assertIs(first, second)
|
|
|
|
@test_util.disable_tfrt('Nested function is not supported')
def testInitializationInNestedCall(self):
  """Variables created in an inner function work when called via an outer."""
  holder = []

  @def_function.function
  def add_var(x):
    if not holder:
      v = variables.Variable([1., 2.])
      holder.append(v)
      already_initialized = variables.Variable(3.)
      with ops.init_scope():
        already_initialized.assign(10.)
      holder.append(already_initialized)
    return holder[0] + holder[1] + x

  @def_function.function
  def wrapper(x):
    return add_var(x)

  # [1, 2] + 10 + 2
  first = wrapper(constant_op.constant(2.))
  self.assertAllClose([13., 14.], first)

  # A later assignment is visible on the next call: [1, 2] + 11 + 2
  holder[1].assign(11.)
  second = wrapper(constant_op.constant(2.))
  self.assertAllClose([14., 15.], second)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
@test_util.run_gpu_only
def testDeviceAnnotationRespected(self):
  """A variable created under `ops.device('CPU:0')` reports a CPU device."""
  created = []

  @def_function.function()
  def create_variable():
    with ops.init_scope():
      initial_value = random_ops.random_uniform(
          (2, 2), maxval=1000000, dtype=dtypes.int64)

    if not created:
      with ops.device('CPU:0'):
        created.append(resource_variable_ops.ResourceVariable(initial_value))

    return created[0].read_value()

  create_variable()
  variable_device = created[0].device
  self.assertRegexpMatches(variable_device, 'CPU')
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
@test_util.run_gpu_only
def testDeviceAnnotationForInitializerRespected(self):
  """A device scope at the call site applies to variable and initializer."""
  created = []
  initial_value = []

  def initial_value_fn():
    # Keep the tensor so its device can be inspected after the call.
    initial_value.append(random_ops.random_uniform((2, 3)))
    return initial_value[0]

  @def_function.function()
  def create_variable():
    with ops.init_scope():
      if not created:
        created.append(variables.Variable(initial_value_fn))

  with ops.device('CPU:0'):
    create_variable()

  for placed in (created[0], initial_value[0]):
    self.assertRegexpMatches(placed.device, 'CPU')
|
|
|
|
def testDecorate(self):
  """`_decorate` wraps the Python function before the first trace."""
  func = def_function.function(lambda: 1)

  def decorator(f):
    # Adds one to whatever the wrapped callable returns.
    return lambda: 1 + f()

  func._decorate(decorator)
  decorated_result = func()
  self.assertEqual(decorated_result.numpy(), 2)
|
|
|
|
@parameterized.parameters(*itertools.product(
    (None, (tensor_spec.TensorSpec([]),)),  # input_signature
    (True, False),  # autograph
    (None, converter.Feature.ALL),  # autograph_options
    (None, 'foo.bar'),  # implements
    (None, True, False),  # relax_shapes
    (True, False),  # compile
    (True, False),  # override_function
))
def testClone(self, input_signature, autograph, autograph_options, implements,
              relax_shapes, compile_, override_function):
  """`_clone` carries over every option and installs the given function."""
  original_py_function = lambda x: x

  # The parameterized value is deliberately overridden: compilation is always
  # off here, whatever the parameter says.
  compile_ = False
  func = def_function.function(
      func=original_py_function,
      input_signature=input_signature,
      autograph=autograph,
      experimental_implements=implements,
      experimental_autograph_options=autograph_options,
      experimental_relax_shapes=relax_shapes,
      experimental_compile=compile_)

  cloned_py_function = (
      (lambda x: x + 1) if override_function else original_py_function)

  cloned = func._clone(python_function=cloned_py_function)

  self.assertEqual(cloned_py_function, cloned._python_function)
  self.assertEqual(func._name, cloned._name)
  self.assertEqual(input_signature, cloned._input_signature)
  self.assertEqual(autograph, cloned._autograph)
  self.assertEqual(implements, cloned._implements)
  self.assertEqual(autograph_options, cloned._experimental_autograph_options)
  self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
  self.assertEqual(compile_, cloned._experimental_compile)

  # This test does not run with XLA JIT support linked in so we can only check
  # the output of the function if compile is disabled.
  if not compile_:
    x = array_ops.zeros([])
    from_clone = self.evaluate(cloned(x))
    from_python = self.evaluate(cloned_py_function(x))
    self.assertEqual(from_clone, from_python)
|
|
|
|
@test_util.disable_tfrt('Variable argument is not supported')
def testLiftPlaceholderInitializedVariable(self):
  """A variable initialized from a placeholder can be fed at init time."""
  with ops.Graph().as_default():
    created = []

    @def_function.function
    def use_variable():
      if not created:
        initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32)
        v = variables.Variable(initial_value)
        created.append(v)
      return created[0] + 1.

    var_plus_one = use_variable()
    with self.session() as session:
      init_op = created[0].initializer
      # The initializer's second input is fed in place of the placeholder.
      feeds = {init_op.inputs[1]: 2.}
      session.run(init_op, feed_dict=feeds)
      self.assertEqual(3., session.run(var_plus_one))
|
|
|
|
def testDecorate_rejectedAfterTrace(self):
  """`_decorate` raises ValueError once the function has been called."""
  func = def_function.function(lambda: 1)
  first_result = func()
  self.assertEqual(first_result.numpy(), 1)

  msg = 'Functions cannot be decorated after they have been traced.'
  with self.assertRaisesRegexp(ValueError, msg):
    func._decorate(lambda f: f)
|
|
|
|
def testGetConcreteFunctionGraphLifetime(self):
  """The concrete function's graph stays usable after `func` is deleted."""

  @def_function.function
  def func():
    pass

  # Deliberately keep only the graph: holding the concrete function in a
  # local would change which objects stay alive.
  graph = func.get_concrete_function().graph
  del func

  # If the graph is deleted, then an exception is raised on reading `captures`
  self.assertEmpty(graph.captures)
|
|
|
|
@parameterized.parameters(*itertools.product(
    (None, (tensor_spec.TensorSpec([]),)),  # input_signature
    (True, False),  # autograph
    (None, converter.Feature.ALL),  # autograph_options
    (None, 'foo.bar'),  # implements
    (None, True, False),  # relax_shapes
))
def test_pickle(self, input_signature, autograph, autograph_options,
                implements, relax_shapes):
  """@function objects can be pickled and unpickled."""
  original_py_function = undecorated_function

  func = def_function.function(
      func=original_py_function,
      input_signature=input_signature,
      autograph=autograph,
      experimental_implements=implements,
      experimental_autograph_options=autograph_options,
      experimental_relax_shapes=relax_shapes,
  )

  # Round-trip through pickle.
  serialized = pickle.dumps(func)
  cloned = pickle.loads(serialized)

  # All construction options must survive the round trip.
  self.assertEqual(func._name, cloned._name)
  self.assertEqual(input_signature, cloned._input_signature)
  self.assertEqual(autograph, cloned._autograph)
  self.assertEqual(implements, cloned._implements)
  self.assertEqual(autograph_options, cloned._experimental_autograph_options)
  self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)

  # The restored function computes the same value as the original.
  x = array_ops.ones([])
  from_clone = self.evaluate(cloned(x))
  from_original = self.evaluate(func(x))
  self.assertEqual(from_clone, from_original)
|
|
|
|
def test_frequent_retracing_warning(self):
  """Exactly one warning is logged, and only on the fifth distinct call."""
  if sys.version_info[0] < 3:
    self.skipTest('self.assertLogs() call is not available in Python 2.')

  @def_function.function
  def f(x):
    return x

  with self.assertLogs(level='WARN') as logs:
    # Four calls with different Python ints: no warning yet.
    for value in (1, 2, 3, 4):
      f(value)
    self.assertEmpty(logs.output)
    f(5)

  self.assertLen(logs.output, 1)
  self.assertIn('Tracing is expensive', logs.output[0])
|
|
|
|
def test_frequent_retracing_warning_lambda(self):
  """The retracing warning is also logged for a wrapped lambda."""
  if sys.version_info[0] < 3:
    self.skipTest('self.assertLogs() call is not available in Python 2.')

  f = def_function.function(lambda x: x)

  with self.assertLogs(level='WARN') as logs:
    for value in (1, 2, 3, 4, 5):
      f(value)

  self.assertLen(logs.output, 1)
  self.assertIn('Tracing is expensive', logs.output[0])
|
|
|
|
def test_frequent_retracing_warning_method(self):
  """The retracing warning is also logged for a bound decorated method."""
  if sys.version_info[0] < 3:
    self.skipTest('self.assertLogs() call is not available in Python 2.')

  class Foo(object):

    @def_function.function
    def f(self, x):
      return x

  f = Foo().f

  with self.assertLogs(level='WARN') as logs:
    for value in (1, 2, 3, 4, 5):
      f(value)

  self.assertLen(logs.output, 1)
  self.assertIn('Tracing is expensive', logs.output[0])
|
|
|
|
def test_frequent_retracing_warning_two_independent_tf_functions(self):
  """Two functions with four and five distinct calls log a single warning."""
  if sys.version_info[0] < 3:
    self.skipTest('self.assertLogs() call is not available in Python 2.')

  @def_function.function
  def f(x):
    return x

  @def_function.function
  def g(x):
    return x

  with self.assertLogs(level='WARN') as logs:
    # `f` gets four distinct calls, `g` gets five.
    for value in (1, 2, 3, 4):
      f(value)
    for value in (1, 2, 3, 4, 5):
      g(value)

  self.assertLen(logs.output, 1)
  self.assertIn('Tracing is expensive', logs.output[0])
|
|
|
|
@test_util.disable_tfrt('Nested function is not supported')
def test_frequent_retracing_warning_nested(self):
  """No warning until one outer function gets its fifth distinct call."""
  if sys.version_info[0] < 3:
    self.skipTest('self.assertLogs() call is not available in Python 2.')

  @def_function.function
  def inner(x):
    return x + 1

  @def_function.function
  def outer1(x):
    return inner(x) * 2

  @def_function.function
  def outer2(x):
    return inner(x) * 3

  with self.assertLogs(level='WARN') as logs:
    for value in (1, 2, 3, 4):
      inner(value)

    for value in (5, 6, 7, 8):
      outer1(value)

    for value in (9, 10, 11, 12):
      outer2(value)

    # Twelve calls so far, four per function: nothing logged.
    self.assertEmpty(logs.output)

    outer2(13)

    self.assertLen(logs.output, 1)
    self.assertIn('Tracing is expensive', logs.output[0])
|
|
|
|
def test_frequent_retracing_warning_on_reinstantiation(self):
  """Re-creating the decorated function in a loop also triggers the warning."""
  if sys.version_info[0] < 3:
    self.skipTest('self.assertLogs() call is not available in Python 2.')

  with self.assertLogs(level='WARN') as logs:
    for i in range(5):

      # A fresh decorated function is defined on every iteration.
      @def_function.function
      def f(x):
        return x

      f(i)

      # Nothing may be logged during the first four iterations.
      if i < 4:
        self.assertEmpty(logs.output)

  self.assertLen(logs.output, 1)
  self.assertIn('Tracing is expensive', logs.output[0])
|
|
|
|
|
|
if __name__ == '__main__':
  # Turn on eager execution before handing control to the test runner.
  ops.enable_eager_execution()
  test.main()
|