Basic templating code.
PiperOrigin-RevId: 180964100
This commit is contained in:
parent
620c838312
commit
3a3feb207d
@ -22,6 +22,7 @@ py_library(
|
|||||||
"compiler.py",
|
"compiler.py",
|
||||||
"parser.py",
|
"parser.py",
|
||||||
"pretty_printer.py",
|
"pretty_printer.py",
|
||||||
|
"templates.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
@ -78,3 +79,16 @@ py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "templates_test",
|
||||||
|
srcs = ["templates_test.py"],
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":pyct",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
112
tensorflow/contrib/py2tf/pyct/templates.py
Normal file
112
tensorflow/contrib/py2tf/pyct/templates.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""AST conversion templates.
|
||||||
|
|
||||||
|
Adapted from Tangent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import ast
|
||||||
|
|
||||||
|
import gast
|
||||||
|
|
||||||
|
from tensorflow.contrib.py2tf.pyct import parser
|
||||||
|
|
||||||
|
|
||||||
|
class ReplaceTransformer(gast.NodeTransformer):
|
||||||
|
"""Replace AST nodes."""
|
||||||
|
|
||||||
|
def __init__(self, replacements):
|
||||||
|
"""Create a new ReplaceTransformer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
replacements: A mapping from placeholder names to (lists of) AST nodes
|
||||||
|
that these placeholders will be replaced by.
|
||||||
|
"""
|
||||||
|
self.replacements = replacements
|
||||||
|
|
||||||
|
# TODO(mdan): Make a more detailed pass and clean up if needed.
|
||||||
|
|
||||||
|
def visit_Expr(self, node):
|
||||||
|
if (isinstance(node.value, gast.Name) and
|
||||||
|
node.value.id in self.replacements):
|
||||||
|
return self.visit(node.value)
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
if node.name in self.replacements:
|
||||||
|
repl = self.replacements[node.name]
|
||||||
|
if not isinstance(repl, (gast.Name, ast.Name)):
|
||||||
|
raise ValueError(
|
||||||
|
'A function name can only be replaced by a Name node. Found: %s',
|
||||||
|
repl)
|
||||||
|
node.name = repl.id
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Name(self, node):
|
||||||
|
# Note: The caller is reposnsible with making sure the replacement
|
||||||
|
# Name nodes have the proper ctx set up.
|
||||||
|
# TODO(mdan): Is it possible to always infer the proper context here?
|
||||||
|
if node.id in self.replacements:
|
||||||
|
# TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
|
||||||
|
new_nodes = self.replacements[node.id]
|
||||||
|
if isinstance(new_nodes, gast.AST):
|
||||||
|
new_nodes = [new_nodes]
|
||||||
|
if len(new_nodes) == 1:
|
||||||
|
new_nodes, = new_nodes
|
||||||
|
return new_nodes
|
||||||
|
else:
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def replace(template, **replacements):
|
||||||
|
"""Replace placeholders in a Python template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: A function to be used as a template. Any placeholder is expected
|
||||||
|
to also be a function argument.
|
||||||
|
**replacements: A mapping from placeholder names to (lists of) AST nodes
|
||||||
|
that these placeholders will be replaced by.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
body: An AST node or list of AST nodes with the replacements made. If the
|
||||||
|
template was a function, a list will be returned. If the template was a
|
||||||
|
node, the same node will be returned. If the template was a string, an
|
||||||
|
AST node will be returned (a `Module` node in the case of a multi-line
|
||||||
|
string, an `Expr` node otherwise).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a function is used as a template and an incorrect set of
|
||||||
|
replacements was passed.
|
||||||
|
"""
|
||||||
|
tree = parser.parse_object(template).body[0]
|
||||||
|
placeholders = set(arg.id for arg in tree.args.args)
|
||||||
|
tree.args.args = []
|
||||||
|
if tree.args.vararg:
|
||||||
|
placeholders.add(tree.args.vararg)
|
||||||
|
tree.args.vararg = None
|
||||||
|
if set(replacements.keys()) != placeholders:
|
||||||
|
raise ValueError(
|
||||||
|
'too many or few replacements. replacements: %s; placeholders: %s' %
|
||||||
|
(replacements.keys(), placeholders))
|
||||||
|
|
||||||
|
# Perform the replacement, stripping the function into which the template was
|
||||||
|
# wrapped.
|
||||||
|
return ReplaceTransformer(replacements).visit(tree).body
|
77
tensorflow/contrib/py2tf/pyct/templates_test.py
Normal file
77
tensorflow/contrib/py2tf/pyct/templates_test.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for templates module."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import gast
|
||||||
|
|
||||||
|
from tensorflow.contrib.py2tf.pyct import compiler
|
||||||
|
from tensorflow.contrib.py2tf.pyct import templates
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class TemplatesTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_replace_variable(self):
|
||||||
|
def template(a): # pylint:disable=unused-argument
|
||||||
|
def test_fn(a): # pylint:disable=unused-variable
|
||||||
|
a += 1
|
||||||
|
a = 2 * a + 1
|
||||||
|
return b # pylint:disable=undefined-variable
|
||||||
|
|
||||||
|
node = templates.replace(
|
||||||
|
template, a=gast.Name('b', gast.Load(), None))[0]
|
||||||
|
result = compiler.ast_to_object(node)
|
||||||
|
self.assertEquals(7, result.test_fn(2))
|
||||||
|
|
||||||
|
def test_replace_function_name(self):
|
||||||
|
def template(fname): # pylint:disable=unused-argument
|
||||||
|
def fname(a): # pylint:disable=function-redefined
|
||||||
|
a += 1
|
||||||
|
a = 2 * a + 1
|
||||||
|
return a
|
||||||
|
|
||||||
|
node = templates.replace(
|
||||||
|
template, fname=gast.Name('test_fn', gast.Load(), None))[0]
|
||||||
|
result = compiler.ast_to_object(node)
|
||||||
|
self.assertEquals(7, result.test_fn(2))
|
||||||
|
|
||||||
|
def test_code_block(self):
|
||||||
|
def template(block): # pylint:disable=unused-argument
|
||||||
|
def test_fn(a): # pylint:disable=unused-variable
|
||||||
|
block # pylint:disable=pointless-statement
|
||||||
|
return a
|
||||||
|
|
||||||
|
node = templates.replace(
|
||||||
|
template,
|
||||||
|
block=[
|
||||||
|
gast.Assign(
|
||||||
|
[
|
||||||
|
gast.Name('a', gast.Store(), None)
|
||||||
|
],
|
||||||
|
gast.BinOp(
|
||||||
|
gast.Name('a', gast.Load(), None),
|
||||||
|
gast.Add(),
|
||||||
|
gast.Num(1))),
|
||||||
|
] * 2)[0]
|
||||||
|
result = compiler.ast_to_object(node)
|
||||||
|
self.assertEquals(3, result.test_fn(1))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
x
Reference in New Issue
Block a user