Upgrade to gast 0.4.

With thanks to @bhack for contributions in #46039.

PiperOrigin-RevId: 350377724
Change-Id: If1dbc77b9c2c7d640720f847292586aca39da80d
This commit is contained in:
Dan Moldovan 2021-01-06 10:13:38 -08:00 committed by TensorFlower Gardener
parent 9b3f838c23
commit e2395993d0
15 changed files with 41 additions and 43 deletions

View File

@ -1298,15 +1298,15 @@ class TFRGen(transformer.CodeGenerator):
# TODO(fengliuai): Here we hardcode the node.slice here to get the index
# type. Use the visit method once the type inference is done.
# slice_val, slice_ty = self.visit(node.slice)
if isinstance(node.slice, ast.Index):
if isinstance(node.slice.value, ast.Constant):
s = node.slice
if not isinstance(s, (ast.Tuple, ast.Slice)):
if isinstance(s, ast.Constant):
# TODO(fengliuai): promote to an assignment
idx_val = self._ssa_name('cst')
self._emit_with_loc(
'\n{} = constant {} : index'.format(idx_val,
node.slice.value.value), node)
'\n{} = constant {} : index'.format(idx_val, s.value), node)
else:
idx_val, _ = self.visit(node.slice.value)
idx_val, _ = self.visit(s)
else:
raise NotImplementedError('non-index slice not supported.')

View File

@ -36,14 +36,15 @@ class SliceTransformer(converter.Base):
def _process_single_assignment(self, target, value):
if not isinstance(target, gast.Subscript):
return None
if not isinstance(target.slice, gast.Index):
s = target.slice
if isinstance(s, (gast.Tuple, gast.Slice)):
return None
template = """
target = ag__.set_item(target, key, item)
"""
return templates.replace(
template, target=target.value, key=target.slice.value, item=value)
template, target=target.value, key=target.slice, item=value)
def visit_Assign(self, node):
node = self.generic_visit(node)
@ -57,7 +58,8 @@ class SliceTransformer(converter.Base):
def visit_Subscript(self, node):
node = self.generic_visit(node)
if not isinstance(node.slice, gast.Index):
s = node.slice
if isinstance(s, (gast.Tuple, gast.Slice)):
return node
if not isinstance(node.ctx, gast.Load):
@ -78,7 +80,7 @@ class SliceTransformer(converter.Base):
opts=ag__.GetItemOpts(element_dtype=dtype))
"""
return templates.replace_as_expression(
template, target=node.value, key=node.slice.value, dtype=dtype)
template, target=node.value, key=s, dtype=dtype)
def transform(node, ctx):

View File

@ -274,7 +274,7 @@ def apply_to_single_assignments(targets, values, apply_fn):
value_el = values.elts[i]
else:
idx = parser.parse_expression(str(i))
value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
value_el = gast.Subscript(values, idx, ctx=gast.Load())
apply_to_single_assignments(target_el, value_el, apply_fn)
else:
apply_fn(target, values)

View File

@ -208,7 +208,7 @@ class QN(object):
if self.has_subscript():
return gast.Subscript(
value=self.parent.ast(),
slice=gast.Index(self.qn[-1].ast()),
slice=self.qn[-1].ast(),
ctx=CallerMustSetThis)
if self.has_attr():
return gast.Attribute(
@ -247,16 +247,16 @@ class QnResolver(gast.NodeTransformer):
# TODO(mdan): This may no longer apply if we overload getitem.
node = self.generic_visit(node)
s = node.slice
if not isinstance(s, gast.Index):
if isinstance(s, (gast.Tuple, gast.Slice)):
# TODO(mdan): Support range and multi-dimensional indices.
# Continuing silently because some demos use these.
return node
if isinstance(s.value, gast.Constant):
subscript = QN(Literal(s.value.value))
if isinstance(s, gast.Constant) and s.value != Ellipsis:
subscript = QN(Literal(s.value))
else:
# The index may be an expression, case in which a name doesn't make sense.
if anno.hasanno(node.slice.value, anno.Basic.QN):
subscript = anno.getanno(node.slice.value, anno.Basic.QN)
if anno.hasanno(s, anno.Basic.QN):
subscript = anno.getanno(s, anno.Basic.QN)
else:
return node
if anno.hasanno(node.value, anno.Basic.QN):

View File

@ -66,7 +66,7 @@ class QNTest(test.TestCase):
self.assertEqual(str(a_sub_b), 'a[b]')
self.assertEqual(a_sub_b.ssf(), 'a_sub_b')
self.assertEqual(a_sub_b.ast().value.id, 'a')
self.assertEqual(a_sub_b.ast().slice.value.id, 'b')
self.assertEqual(a_sub_b.ast().slice.id, 'b')
self.assertTrue(a_sub_b.is_composite())
self.assertTrue(a_sub_b.has_subscript())
self.assertEqual(a_sub_b.parent.qn, ('a',))
@ -81,9 +81,9 @@ class QNTest(test.TestCase):
self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c')
self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a')
self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b')
self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c')
self.assertEqual(b_sub_c.ast().slice.value.id, 'c')
self.assertEqual(a_sub_b_sub_c.ast().slice.value.id, 'b')
self.assertEqual(a_sub_b_sub_c.ast().slice.slice.id, 'c')
self.assertEqual(b_sub_c.ast().slice.id, 'c')
self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',))
with self.assertRaises(ValueError):
QN('a', 'b')
@ -157,12 +157,12 @@ class QNTest(test.TestCase):
self.assertNotEqual(a_sub_str_b, a_sub_b)
self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b))
self.assertEqual(a_sub_str_b.ast().slice.value.value, 'b')
self.assertEqual(a_sub_str_b.ast().slice.value, 'b')
self.assertEqual(str(a_sub_str_b), "a['b']")
a_sub_three = QN(a, subscript=QN(qual_names.Literal(3)))
self.assertEqual(a_sub_three.ast().slice.value.value, 3)
self.assertEqual(str(a_sub_three), "a[3]")
self.assertEqual(a_sub_three.ast().slice.value, 3)
self.assertEqual(str(a_sub_three), 'a[3]')
def test_support_set(self):
a = QN('a')

View File

@ -444,9 +444,6 @@ class StmtInferrer(gast.NodeVisitor):
def visit_Expr(self, node):
return self.visit(node.value)
def visit_Index(self, node):
return self.visit(node.value)
def visit_Assign(self, node):
self.rtype = self.visit(node.value)

View File

@ -669,7 +669,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
self.assertTypes(fn_body[0].value, str)
self.assertTypes(fn_body[0].value.value, list)
self.assertTypes(fn_body[0].value.slice.value, int)
self.assertTypes(fn_body[0].value.slice, int)
def test_tuple_unpacking(self):

View File

@ -221,7 +221,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
function_call_arg = node.body[0].targets[0].value.args[0]
self.assertIsInstance(function_call_arg.ctx, gast.Load)
self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)
self.assertIsInstance(function_call_arg.slice.ctx, gast.Load)
def test_replace_call_keyword(self):
template = """

View File

@ -412,7 +412,7 @@ class Base(NodeStateTracker, gast.NodeTransformer):
if isinstance(values, (gast.Tuple, gast.List)):
value_el = values.elts[i]
else:
value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
value_el = gast.Subscript(values, i, ctx=gast.Store())
self.apply_to_single_assignments(target_el, value_el, apply_fn)
else:
# TODO(mdan): Look into allowing to rewrite the AST here.

View File

@ -127,21 +127,20 @@ class _SubscriptUseTracker(transformer.Base):
def visit_Subscript(self, node):
"""Visits nodes with subscript in the AST."""
s = node.slice
if anno.hasanno(node, anno.Basic.QN):
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Load):
self.reads.add(qn)
elif not isinstance(node.slice, gast.Index):
if anno.hasanno(node, anno.Basic.QN):
self.complex_reads.add(anno.getanno(node, anno.Basic.QN))
elif anno.hasanno(node.value, anno.Basic.QN):
elif isinstance(s, (gast.Tuple, gast.Slice)):
if anno.hasanno(node.value, anno.Basic.QN):
self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN))
value_qn = anno.getanno(node.value, anno.Basic.QN, None)
if value_qn in self.exclude:
node.value = self.generic_visit(node.value)
else:
node.value = self.visit(node.value)
node.slice = self.visit(node.slice)
node.slice = self.visit(s)
return node

View File

@ -469,7 +469,7 @@ install_tensorflow_pip() {
# Install the gast package in the virtualenv. Installing it in user system
# packages does not appear to port it over when creating a virtualenv.
${PIP_BIN_PATH} install --upgrade "gast==0.3.3" || \
${PIP_BIN_PATH} install --upgrade "gast==0.4.0" || \
die "Error: gast install, upgrade FAILED"
}

View File

@ -137,7 +137,7 @@ function install_ubuntu_16_pip_deps {
"${PIP_CMD}" install --user 'wheel ~= 0.35'
"${PIP_CMD}" install --user 'wrapt ~= 1.12.1'
# We need to pin gast dependency exactly
"${PIP_CMD}" install --user 'gast == 0.3.3'
"${PIP_CMD}" install --user 'gast == 0.4.0'
# Finally, install tensorboard and estimator
# Note that here we want the latest version that matches (b/156523241)
"${PIP_CMD}" install --user --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a'
@ -199,7 +199,7 @@ function install_macos_pip_deps {
${PIP_CMD} install $USER_FLAG 'wheel ~= 0.35'
${PIP_CMD} install $USER_FLAG 'wrapt ~= 1.12.1'
# We need to pin gast dependency exactly
${PIP_CMD} install $USER_FLAG 'gast == 0.3.3'
${PIP_CMD} install $USER_FLAG 'gast == 0.4.0'
# Finally, install tensorboard and estimator
# Note that here we want the latest version that matches (b/156523241)
${PIP_CMD} install $USER_FLAG --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a'

View File

@ -43,7 +43,7 @@ SET PATH=%PATH%;C:\%PYTHON_DIRECTORY%
%PY_EXE% -m pip install "wheel ~= 0.35"
%PY_EXE% -m pip install "wrapt ~= 1.12.1"
@REM We need to pin gast dependency exactly
%PY_EXE% -m pip install "gast == 0.3.3"
%PY_EXE% -m pip install "gast == 0.4.0"
@REM Finally, install tensorboard and estimator
@REM Note that here we want the latest version that matches (b/156523241)
%PY_EXE% -m pip install --upgrade --force-reinstall "tb-nightly ~= 2.4.0.a"

View File

@ -91,7 +91,7 @@ REQUIRED_PACKAGES = [
'wrapt ~= 1.12.1',
# These packages need to be pinned exactly as newer versions are
# incompatible with the rest of the ecosystem
'gast == 0.3.3',
'gast == 0.4.0',
# TensorFlow ecosystem packages that TF exposes API for
# These need to be in sync with the existing TF version
# They are updated during the release process

View File

@ -450,12 +450,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "gast_archive",
build_file = clean_dep("//third_party:gast.BUILD"),
sha256 = "b881ef288a49aa81440d2c5eb8aeefd4c2bb8993d5f50edae7413a85bfdb3b57",
strip_prefix = "gast-0.3.3",
sha256 = "40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1",
strip_prefix = "gast-0.4.0",
system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
urls = [
"http://mirror.tensorflow.org/files.pythonhosted.org/packages/12/59/eaa15ab9710a20e22225efd042cd2d6a0b559a0656d5baba9641a2a4a921/gast-0.3.3.tar.gz",
"https://files.pythonhosted.org/packages/12/59/eaa15ab9710a20e22225efd042cd2d6a0b559a0656d5baba9641a2a4a921/gast-0.3.3.tar.gz",
"http://mirror.tensorflow.org/files.pythonhosted.org/packages/12/59/eaa15ab9710a20e22225efd042cd2d6a0b559a0656d5baba9641a2a4a921/gast-0.4.0.tar.gz",
"https://files.pythonhosted.org/packages/83/4a/07c7e59cef23fb147454663c3271c21da68ba2ab141427c20548ae5a8a4d/gast-0.4.0.tar.gz",
],
)