Basic templating code.

PiperOrigin-RevId: 180964100
This commit is contained in:
A. Unique TensorFlower 2018-01-05 12:58:00 -08:00 committed by TensorFlower Gardener
parent 620c838312
commit 3a3feb207d
3 changed files with 203 additions and 0 deletions

View File

@ -22,6 +22,7 @@ py_library(
"compiler.py",
"parser.py",
"pretty_printer.py",
"templates.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
@ -78,3 +79,16 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "templates_test",
srcs = ["templates_test.py"],
tags = [
"manual",
"notap",
],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
],
)

View File

@ -0,0 +1,112 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST conversion templates.
Adapted from Tangent.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import gast
from tensorflow.contrib.py2tf.pyct import parser
class ReplaceTransformer(gast.NodeTransformer):
"""Replace AST nodes."""
def __init__(self, replacements):
"""Create a new ReplaceTransformer.
Args:
replacements: A mapping from placeholder names to (lists of) AST nodes
that these placeholders will be replaced by.
"""
self.replacements = replacements
# TODO(mdan): Make a more detailed pass and clean up if needed.
def visit_Expr(self, node):
if (isinstance(node.value, gast.Name) and
node.value.id in self.replacements):
return self.visit(node.value)
self.generic_visit(node)
return node
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
if node.name in self.replacements:
repl = self.replacements[node.name]
if not isinstance(repl, (gast.Name, ast.Name)):
raise ValueError(
'A function name can only be replaced by a Name node. Found: %s',
repl)
node.name = repl.id
return node
def visit_Name(self, node):
# Note: The caller is reposnsible with making sure the replacement
# Name nodes have the proper ctx set up.
# TODO(mdan): Is it possible to always infer the proper context here?
if node.id in self.replacements:
# TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
new_nodes = self.replacements[node.id]
if isinstance(new_nodes, gast.AST):
new_nodes = [new_nodes]
if len(new_nodes) == 1:
new_nodes, = new_nodes
return new_nodes
else:
return node
def replace(template, **replacements):
"""Replace placeholders in a Python template.
Args:
template: A function to be used as a template. Any placeholder is expected
to also be a function argument.
**replacements: A mapping from placeholder names to (lists of) AST nodes
that these placeholders will be replaced by.
Returns:
body: An AST node or list of AST nodes with the replacements made. If the
template was a function, a list will be returned. If the template was a
node, the same node will be returned. If the template was a string, an
AST node will be returned (a `Module` node in the case of a multi-line
string, an `Expr` node otherwise).
Raises:
ValueError: If a function is used as a template and an incorrect set of
replacements was passed.
"""
tree = parser.parse_object(template).body[0]
placeholders = set(arg.id for arg in tree.args.args)
tree.args.args = []
if tree.args.vararg:
placeholders.add(tree.args.vararg)
tree.args.vararg = None
if set(replacements.keys()) != placeholders:
raise ValueError(
'too many or few replacements. replacements: %s; placeholders: %s' %
(replacements.keys(), placeholders))
# Perform the replacement, stripping the function into which the template was
# wrapped.
return ReplaceTransformer(replacements).visit(tree).body

View File

@ -0,0 +1,77 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for templates module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast
from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.contrib.py2tf.pyct import templates
from tensorflow.python.platform import test
class TemplatesTest(test.TestCase):
def test_replace_variable(self):
def template(a): # pylint:disable=unused-argument
def test_fn(a): # pylint:disable=unused-variable
a += 1
a = 2 * a + 1
return b # pylint:disable=undefined-variable
node = templates.replace(
template, a=gast.Name('b', gast.Load(), None))[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_replace_function_name(self):
def template(fname): # pylint:disable=unused-argument
def fname(a): # pylint:disable=function-redefined
a += 1
a = 2 * a + 1
return a
node = templates.replace(
template, fname=gast.Name('test_fn', gast.Load(), None))[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_code_block(self):
def template(block): # pylint:disable=unused-argument
def test_fn(a): # pylint:disable=unused-variable
block # pylint:disable=pointless-statement
return a
node = templates.replace(
template,
block=[
gast.Assign(
[
gast.Name('a', gast.Store(), None)
],
gast.BinOp(
gast.Name('a', gast.Load(), None),
gast.Add(),
gast.Num(1))),
] * 2)[0]
result = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))
if __name__ == '__main__':
test.main()