Add an owner set property to QN objects. This assists extracting symbols that are affected when objects or containers are mutated in place. Also included is a utility method that simplifies constructing a QN from string.

PiperOrigin-RevId: 204303017
This commit is contained in:
Dan Moldovan 2018-07-12 08:22:15 -07:00 committed by TensorFlower Gardener
parent 4665633c5f
commit 34a1b6780b
2 changed files with 35 additions and 2 deletions

View File

@ -30,6 +30,7 @@ import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import parser
class Symbol(collections.namedtuple('Symbol', ['name'])):
@ -89,7 +90,8 @@ class QN(object):
if not isinstance(base, (str, StringLiteral, NumberLiteral)):
# TODO(mdan): Require Symbol instead of string.
raise ValueError(
'For simple QNs, base must be a string or a Literal object.')
'for simple QNs, base must be a string or a Literal object;'
' got instead "%s"' % type(base))
assert '.' not in base and '[' not in base and ']' not in base
self._parent = None
self.qn = (base,)
@ -112,6 +114,22 @@ class QN(object):
raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
return self._parent
@property
def owner_set(self):
"""Returns all the symbols (simple or composite) that own this QN.
In other words, if this symbol was modified, the symbols in the owner set
may also be affected.
Examples:
'a.b[c.d]' has two owners, 'a' and 'a.b'
"""
owners = set()
if self.has_attr() or self.has_subscript():
owners.add(self.parent)
owners.update(self.parent.owner_set)
return owners
@property
def support_set(self):
"""Returns the set of simple symbols that this QN relies on.
@ -122,7 +140,7 @@ class QN(object):
Examples:
'a.b' has only one support symbol, 'a'
'a[i]' has two roots, 'a' and 'i'
'a[i]' has two support symbols, 'a' and 'i'
"""
# TODO(mdan): This might be the set of Name nodes in the AST. Track those?
roots = set()
@ -231,3 +249,9 @@ class QnResolver(gast.NodeTransformer):
def resolve(node):
return QnResolver().visit(node)
def from_str(qn_str):
node = parser.parse_expression(qn_str)
node = resolve(node)
return anno.getanno(node, anno.Basic.QN)

View File

@ -30,6 +30,15 @@ from tensorflow.python.platform import test
class QNTest(test.TestCase):
def test_from_str(self):
a = QN('a')
b = QN('b')
a_dot_b = QN(a, attr='b')
a_sub_b = QN(a, subscript=b)
self.assertEqual(qual_names.from_str('a.b'), a_dot_b)
self.assertEqual(qual_names.from_str('a'), a)
self.assertEqual(qual_names.from_str('a[b]'), a_sub_b)
def test_basic(self):
a = QN('a')
self.assertEqual(a.qn, ('a',))