Enable support for default argument values.

PiperOrigin-RevId: 218907219
This commit is contained in:
Dan Moldovan 2018-10-26 13:44:36 -07:00 committed by TensorFlower Gardener
parent f9486f4879
commit 62363a4319
4 changed files with 35 additions and 0 deletions
tensorflow/python/autograph/impl

View File

@ -318,6 +318,9 @@ def to_graph(e,
compiled_module.__dict__[key] = val
compiled = getattr(compiled_module, name)
if tf_inspect.isfunction(e):
compiled.__defaults__ = e.__defaults__
# Need this so the source_mapping attribute is available for the context
# manager to access for runtime errors.
#

View File

@ -309,6 +309,21 @@ class ApiTest(test.TestCase):
x = compiled_fn(constant_op.constant([4, 8]), 4)
self.assertListEqual([1, 2], sess.run(x).tolist())
def test_to_graph_with_defaults(self):
foo = 4
def test_fn(x, s=foo):
while tf.reduce_sum(x) > s:
x //= 2
return x
compiled_fn = api.to_graph(test_fn)
with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]))
self.assertListEqual([1, 2], sess.run(x).tolist())
def test_to_code_basic(self):
def test_fn(x, s):

View File

@ -24,6 +24,7 @@ import gast
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.converters import arg_defaults
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import builtin_functions
@ -342,6 +343,7 @@ def node_to_graph(node, context, rewrite_errors=True):
if context.program.options.uses(converter.Feature.DECORATORS):
node = converter.apply_(node, context, decorators)
node = converter.apply_(node, context, arg_defaults)
node = converter.apply_(node, context, directives)
node = converter.apply_(node, context, break_statements)
node = converter.apply_(node, context, asserts)

View File

@ -24,6 +24,7 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.framework import constant_op
from tensorflow.python.keras.engine import training
@ -65,6 +66,20 @@ class ConversionTest(test.TestCase):
self.assertEqual('tf__f', name)
self.assertIs(ns['b'], b)
def test_entity_to_graph_function_with_defaults(self):
b = 2
c = 1
def f(a, d=c + 1):
return a + b + d
program_ctx = self._simple_program_ctx()
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertEqual(
compiler.ast_to_source(fn_node.args.defaults[0]).strip(), 'None')
def test_entity_to_graph_call_tree(self):
def g(a):