Add more tests for partial functions in tf.function.
PiperOrigin-RevId: 239377543
This commit is contained in:
parent
73cb55b411
commit
ec1befe9d8
@ -246,6 +246,18 @@ class DefFunctionTest(test.TestCase):
|
||||
def_function.function(functools.partial(lambda x, y: x + y, 1.))(
|
||||
constant_op.constant(2.)))
|
||||
|
||||
def test_functools_partial_single_keyword(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
func = def_function.function(
|
||||
functools.partial(f, x=constant_op.constant(1)))
|
||||
|
||||
# This is a limitation of functools.partial.
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, 'got multiple values for'):
|
||||
func(5)
|
||||
|
||||
def test_functools_partial_keywords(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
@ -254,6 +266,14 @@ class DefFunctionTest(test.TestCase):
|
||||
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
|
||||
self.assertAllEqual(func(), [0.0])
|
||||
|
||||
def test_functools_partial_single_positional(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
func = def_function.function(
|
||||
functools.partial(f, constant_op.constant(1)))
|
||||
self.assertAllEqual(func(5), 6)
|
||||
|
||||
def test_unspecified_default_argument(self):
|
||||
wrapped = def_function.function(
|
||||
lambda x, y=2: x + y,
|
||||
|
||||
@ -680,19 +680,22 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
fn2 = lambda: array_ops.ones([10]) * 2
|
||||
|
||||
def fn3(x=2):
|
||||
def fn3(x=3):
|
||||
return array_ops.ones([10]) * x
|
||||
fn3 = functools.partial(fn3, x=3)
|
||||
fn4 = functools.partial(fn3, x=4)
|
||||
fn5 = functools.partial(fn3, 5)
|
||||
|
||||
return gen_functional_ops.case(val, [], [dtypes.float32],
|
||||
[function.defun(f).get_concrete_function()
|
||||
for f in (fn1, fn2, fn3)])
|
||||
for f in (fn1, fn2, fn3, fn4, fn5)])
|
||||
|
||||
ones = array_ops.ones([10])
|
||||
self.assertAllEqual([ones], test_function(0))
|
||||
self.assertAllEqual([ones * 2], test_function(1))
|
||||
self.assertAllEqual([ones * 3], test_function(2))
|
||||
self.assertAllEqual([ones * 3], test_function(22)) # default branch
|
||||
self.assertAllEqual([ones * 4], test_function(3))
|
||||
self.assertAllEqual([ones * 5], test_function(4))
|
||||
self.assertAllEqual([ones * 5], test_function(22)) # default branch
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testVariableInLoopInFunction(self):
|
||||
|
||||
@ -1140,22 +1140,69 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(4, root.f(constant_op.constant(3)).numpy())
|
||||
|
||||
def test_partial(self, cycles):
|
||||
# TODO(vbardiovsky): Figure out the story for FunctionSpec vs partial vs
|
||||
# input_signature.
|
||||
# TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
|
||||
self.skipTest("Partial does not work for serialization.")
|
||||
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
func = def_function.function(
|
||||
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
|
||||
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.ones([1])))
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = func
|
||||
self.assertAllEqual(root.f(), [0.0])
|
||||
self.assertAllEqual(root.f(), [1.0])
|
||||
|
||||
root = self.cycle(root, cycles)
|
||||
self.assertAllEqual(root.f(), [0.0])
|
||||
self.assertAllEqual(root.f(), [1.0])
|
||||
|
||||
def test_partial_with_non_tensor_defaults(self, cycles):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
func = def_function.function(functools.partial(f, y=5))
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = func
|
||||
self.assertAllEqual(root.f(1), 6)
|
||||
|
||||
root = self.cycle(root, cycles)
|
||||
self.assertAllEqual(root.f(1), 6)
|
||||
|
||||
def test_partial_with_positional(self, cycles):
|
||||
# TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
|
||||
self.skipTest("Partial does not work for serialization.")
|
||||
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
func = def_function.function(functools.partial(f, constant_op.constant(5)))
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = func
|
||||
self.assertAllEqual(root.f(1), 6)
|
||||
|
||||
root = self.cycle(root, cycles)
|
||||
self.assertAllEqual(root.f(1), 6)
|
||||
|
||||
def test_partial_with_passed_fn_as_default(self, cycles):
|
||||
# TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
|
||||
self.skipTest("Partial does not work for serialization.")
|
||||
|
||||
def f(x, y):
|
||||
return x(3) + y
|
||||
|
||||
def my_func(a):
|
||||
return 2 * a
|
||||
|
||||
func = def_function.function(functools.partial(f, my_func))
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = func
|
||||
self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
|
||||
|
||||
root = self.cycle(root, cycles)
|
||||
self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
|
||||
|
||||
def test_convert_to_input_signature(self, cycles):
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user