Add more tests for partial functions in tf.function.

PiperOrigin-RevId: 239377543
This commit is contained in:
Vojtech Bardiovsky 2019-03-20 05:17:49 -07:00 committed by TensorFlower Gardener
parent 73cb55b411
commit ec1befe9d8
3 changed files with 79 additions and 9 deletions

View File

@ -246,6 +246,18 @@ class DefFunctionTest(test.TestCase):
def_function.function(functools.partial(lambda x, y: x + y, 1.))(
constant_op.constant(2.)))
def test_functools_partial_single_keyword(self):
def f(x, y):
return x + y
func = def_function.function(
functools.partial(f, x=constant_op.constant(1)))
# This is a limitation of functools.partial.
with self.assertRaisesRegexp(
TypeError, 'got multiple values for'):
func(5)
def test_functools_partial_keywords(self):
def f(x, y):
return x + y
@ -254,6 +266,14 @@ class DefFunctionTest(test.TestCase):
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
self.assertAllEqual(func(), [0.0])
def test_functools_partial_single_positional(self):
def f(x, y):
return x + y
func = def_function.function(
functools.partial(f, constant_op.constant(1)))
self.assertAllEqual(func(5), 6)
def test_unspecified_default_argument(self):
wrapped = def_function.function(
lambda x, y=2: x + y,

View File

@ -680,19 +680,22 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
fn2 = lambda: array_ops.ones([10]) * 2
def fn3(x=2):
def fn3(x=3):
return array_ops.ones([10]) * x
fn3 = functools.partial(fn3, x=3)
fn4 = functools.partial(fn3, x=4)
fn5 = functools.partial(fn3, 5)
return gen_functional_ops.case(val, [], [dtypes.float32],
[function.defun(f).get_concrete_function()
for f in (fn1, fn2, fn3)])
for f in (fn1, fn2, fn3, fn4, fn5)])
ones = array_ops.ones([10])
self.assertAllEqual([ones], test_function(0))
self.assertAllEqual([ones * 2], test_function(1))
self.assertAllEqual([ones * 3], test_function(2))
self.assertAllEqual([ones * 3], test_function(22)) # default branch
self.assertAllEqual([ones * 4], test_function(3))
self.assertAllEqual([ones * 5], test_function(4))
self.assertAllEqual([ones * 5], test_function(22)) # default branch
@test_util.enable_control_flow_v2
def testVariableInLoopInFunction(self):

View File

@ -1140,22 +1140,69 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(4, root.f(constant_op.constant(3)).numpy())
def test_partial(self, cycles):
# TODO(vbardiovsky): Figure out the story for FunctionSpec vs partial vs
# input_signature.
# TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
self.skipTest("Partial does not work for serialization.")
def f(x, y):
return x + y
func = def_function.function(
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.ones([1])))
root = tracking.AutoTrackable()
root.f = func
self.assertAllEqual(root.f(), [0.0])
self.assertAllEqual(root.f(), [1.0])
root = self.cycle(root, cycles)
self.assertAllEqual(root.f(), [0.0])
self.assertAllEqual(root.f(), [1.0])
def test_partial_with_non_tensor_defaults(self, cycles):
def f(x, y):
return x + y
func = def_function.function(functools.partial(f, y=5))
root = tracking.AutoTrackable()
root.f = func
self.assertAllEqual(root.f(1), 6)
root = self.cycle(root, cycles)
self.assertAllEqual(root.f(1), 6)
def test_partial_with_positional(self, cycles):
# TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
self.skipTest("Partial does not work for serialization.")
def f(x, y):
return x + y
func = def_function.function(functools.partial(f, constant_op.constant(5)))
root = tracking.AutoTrackable()
root.f = func
self.assertAllEqual(root.f(1), 6)
root = self.cycle(root, cycles)
self.assertAllEqual(root.f(1), 6)
def test_partial_with_passed_fn_as_default(self, cycles):
# TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
self.skipTest("Partial does not work for serialization.")
def f(x, y):
return x(3) + y
def my_func(a):
return 2 * a
func = def_function.function(functools.partial(f, my_func))
root = tracking.AutoTrackable()
root.f = func
self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
root = self.cycle(root, cycles)
self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
def test_convert_to_input_signature(self, cycles):