Enable support for default argument values.
PiperOrigin-RevId: 218907219
This commit is contained in:
parent
f9486f4879
commit
62363a4319
tensorflow/python/autograph/impl
@ -318,6 +318,9 @@ def to_graph(e,
|
||||
compiled_module.__dict__[key] = val
|
||||
compiled = getattr(compiled_module, name)
|
||||
|
||||
if tf_inspect.isfunction(e):
|
||||
compiled.__defaults__ = e.__defaults__
|
||||
|
||||
# Need this so the source_mapping attribute is available for the context
|
||||
# manager to access for runtime errors.
|
||||
#
|
||||
|
@ -309,6 +309,21 @@ class ApiTest(test.TestCase):
|
||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
|
||||
def test_to_graph_with_defaults(self):
|
||||
|
||||
foo = 4
|
||||
|
||||
def test_fn(x, s=foo):
|
||||
while tf.reduce_sum(x) > s:
|
||||
x //= 2
|
||||
return x
|
||||
|
||||
compiled_fn = api.to_graph(test_fn)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]))
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
|
||||
def test_to_code_basic(self):
|
||||
|
||||
def test_fn(x, s):
|
||||
|
@ -24,6 +24,7 @@ import gast
|
||||
|
||||
from tensorflow.python.autograph import operators
|
||||
from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.converters import arg_defaults
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import break_statements
|
||||
from tensorflow.python.autograph.converters import builtin_functions
|
||||
@ -342,6 +343,7 @@ def node_to_graph(node, context, rewrite_errors=True):
|
||||
|
||||
if context.program.options.uses(converter.Feature.DECORATORS):
|
||||
node = converter.apply_(node, context, decorators)
|
||||
node = converter.apply_(node, context, arg_defaults)
|
||||
node = converter.apply_(node, context, directives)
|
||||
node = converter.apply_(node, context, break_statements)
|
||||
node = converter.apply_(node, context, asserts)
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.impl import api
|
||||
from tensorflow.python.autograph.pyct import compiler
|
||||
from tensorflow.python.autograph.impl import conversion
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.keras.engine import training
|
||||
@ -65,6 +66,20 @@ class ConversionTest(test.TestCase):
|
||||
self.assertEqual('tf__f', name)
|
||||
self.assertIs(ns['b'], b)
|
||||
|
||||
def test_entity_to_graph_function_with_defaults(self):
|
||||
b = 2
|
||||
c = 1
|
||||
def f(a, d=c + 1):
|
||||
return a + b + d
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||
fn_node, _ = nodes
|
||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||
self.assertEqual('tf__f', name)
|
||||
self.assertEqual(
|
||||
compiler.ast_to_source(fn_node.args.defaults[0]).strip(), 'None')
|
||||
|
||||
def test_entity_to_graph_call_tree(self):
|
||||
|
||||
def g(a):
|
||||
|
Loading…
Reference in New Issue
Block a user