Add an owner set property to QN objects. This assists extracting symbols that are affected when objects or containers are mutated in place. Also included is a utility method that simplifies constructing a QN from string.
PiperOrigin-RevId: 204303017
This commit is contained in:
parent
4665633c5f
commit
34a1b6780b
@ -30,6 +30,7 @@ import collections
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.autograph.pyct import anno
|
||||
from tensorflow.contrib.autograph.pyct import parser
|
||||
|
||||
|
||||
class Symbol(collections.namedtuple('Symbol', ['name'])):
|
||||
@ -89,7 +90,8 @@ class QN(object):
|
||||
if not isinstance(base, (str, StringLiteral, NumberLiteral)):
|
||||
# TODO(mdan): Require Symbol instead of string.
|
||||
raise ValueError(
|
||||
'For simple QNs, base must be a string or a Literal object.')
|
||||
'for simple QNs, base must be a string or a Literal object;'
|
||||
' got instead "%s"' % type(base))
|
||||
assert '.' not in base and '[' not in base and ']' not in base
|
||||
self._parent = None
|
||||
self.qn = (base,)
|
||||
@ -112,6 +114,22 @@ class QN(object):
|
||||
raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
|
||||
return self._parent
|
||||
|
||||
@property
|
||||
def owner_set(self):
|
||||
"""Returns all the symbols (simple or composite) that own this QN.
|
||||
|
||||
In other words, if this symbol was modified, the symbols in the owner set
|
||||
may also be affected.
|
||||
|
||||
Examples:
|
||||
'a.b[c.d]' has two owners, 'a' and 'a.b'
|
||||
"""
|
||||
owners = set()
|
||||
if self.has_attr() or self.has_subscript():
|
||||
owners.add(self.parent)
|
||||
owners.update(self.parent.owner_set)
|
||||
return owners
|
||||
|
||||
@property
|
||||
def support_set(self):
|
||||
"""Returns the set of simple symbols that this QN relies on.
|
||||
@ -122,7 +140,7 @@ class QN(object):
|
||||
|
||||
Examples:
|
||||
'a.b' has only one support symbol, 'a'
|
||||
'a[i]' has two roots, 'a' and 'i'
|
||||
'a[i]' has two support symbols, 'a' and 'i'
|
||||
"""
|
||||
# TODO(mdan): This might be the set of Name nodes in the AST. Track those?
|
||||
roots = set()
|
||||
@ -231,3 +249,9 @@ class QnResolver(gast.NodeTransformer):
|
||||
|
||||
def resolve(node):
|
||||
return QnResolver().visit(node)
|
||||
|
||||
|
||||
def from_str(qn_str):
|
||||
node = parser.parse_expression(qn_str)
|
||||
node = resolve(node)
|
||||
return anno.getanno(node, anno.Basic.QN)
|
||||
|
@ -30,6 +30,15 @@ from tensorflow.python.platform import test
|
||||
|
||||
class QNTest(test.TestCase):
|
||||
|
||||
def test_from_str(self):
|
||||
a = QN('a')
|
||||
b = QN('b')
|
||||
a_dot_b = QN(a, attr='b')
|
||||
a_sub_b = QN(a, subscript=b)
|
||||
self.assertEqual(qual_names.from_str('a.b'), a_dot_b)
|
||||
self.assertEqual(qual_names.from_str('a'), a)
|
||||
self.assertEqual(qual_names.from_str('a[b]'), a_sub_b)
|
||||
|
||||
def test_basic(self):
|
||||
a = QN('a')
|
||||
self.assertEqual(a.qn, ('a',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user