Upgrade to gast 0.4.
With thanks to @bhack for contributions in #46039. PiperOrigin-RevId: 350377724 Change-Id: If1dbc77b9c2c7d640720f847292586aca39da80d
This commit is contained in:
parent
9b3f838c23
commit
e2395993d0
@ -1298,15 +1298,15 @@ class TFRGen(transformer.CodeGenerator):
|
|||||||
# TODO(fengliuai): Here we hardcode the node.slice here to get the index
|
# TODO(fengliuai): Here we hardcode the node.slice here to get the index
|
||||||
# type. Use the visit method once the type inference is done.
|
# type. Use the visit method once the type inference is done.
|
||||||
# slice_val, slice_ty = self.visit(node.slice)
|
# slice_val, slice_ty = self.visit(node.slice)
|
||||||
if isinstance(node.slice, ast.Index):
|
s = node.slice
|
||||||
if isinstance(node.slice.value, ast.Constant):
|
if not isinstance(s, (ast.Tuple, ast.Slice)):
|
||||||
|
if isinstance(s, ast.Constant):
|
||||||
# TODO(fengliuai): promote to an assignment
|
# TODO(fengliuai): promote to an assignment
|
||||||
idx_val = self._ssa_name('cst')
|
idx_val = self._ssa_name('cst')
|
||||||
self._emit_with_loc(
|
self._emit_with_loc(
|
||||||
'\n{} = constant {} : index'.format(idx_val,
|
'\n{} = constant {} : index'.format(idx_val, s.value), node)
|
||||||
node.slice.value.value), node)
|
|
||||||
else:
|
else:
|
||||||
idx_val, _ = self.visit(node.slice.value)
|
idx_val, _ = self.visit(s)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('non-index slice not supported.')
|
raise NotImplementedError('non-index slice not supported.')
|
||||||
|
|
||||||
|
@ -36,14 +36,15 @@ class SliceTransformer(converter.Base):
|
|||||||
def _process_single_assignment(self, target, value):
|
def _process_single_assignment(self, target, value):
|
||||||
if not isinstance(target, gast.Subscript):
|
if not isinstance(target, gast.Subscript):
|
||||||
return None
|
return None
|
||||||
if not isinstance(target.slice, gast.Index):
|
s = target.slice
|
||||||
|
if isinstance(s, (gast.Tuple, gast.Slice)):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
template = """
|
template = """
|
||||||
target = ag__.set_item(target, key, item)
|
target = ag__.set_item(target, key, item)
|
||||||
"""
|
"""
|
||||||
return templates.replace(
|
return templates.replace(
|
||||||
template, target=target.value, key=target.slice.value, item=value)
|
template, target=target.value, key=target.slice, item=value)
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
@ -57,7 +58,8 @@ class SliceTransformer(converter.Base):
|
|||||||
|
|
||||||
def visit_Subscript(self, node):
|
def visit_Subscript(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
if not isinstance(node.slice, gast.Index):
|
s = node.slice
|
||||||
|
if isinstance(s, (gast.Tuple, gast.Slice)):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
if not isinstance(node.ctx, gast.Load):
|
if not isinstance(node.ctx, gast.Load):
|
||||||
@ -78,7 +80,7 @@ class SliceTransformer(converter.Base):
|
|||||||
opts=ag__.GetItemOpts(element_dtype=dtype))
|
opts=ag__.GetItemOpts(element_dtype=dtype))
|
||||||
"""
|
"""
|
||||||
return templates.replace_as_expression(
|
return templates.replace_as_expression(
|
||||||
template, target=node.value, key=node.slice.value, dtype=dtype)
|
template, target=node.value, key=s, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def transform(node, ctx):
|
def transform(node, ctx):
|
||||||
|
@ -274,7 +274,7 @@ def apply_to_single_assignments(targets, values, apply_fn):
|
|||||||
value_el = values.elts[i]
|
value_el = values.elts[i]
|
||||||
else:
|
else:
|
||||||
idx = parser.parse_expression(str(i))
|
idx = parser.parse_expression(str(i))
|
||||||
value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
|
value_el = gast.Subscript(values, idx, ctx=gast.Load())
|
||||||
apply_to_single_assignments(target_el, value_el, apply_fn)
|
apply_to_single_assignments(target_el, value_el, apply_fn)
|
||||||
else:
|
else:
|
||||||
apply_fn(target, values)
|
apply_fn(target, values)
|
||||||
|
@ -208,7 +208,7 @@ class QN(object):
|
|||||||
if self.has_subscript():
|
if self.has_subscript():
|
||||||
return gast.Subscript(
|
return gast.Subscript(
|
||||||
value=self.parent.ast(),
|
value=self.parent.ast(),
|
||||||
slice=gast.Index(self.qn[-1].ast()),
|
slice=self.qn[-1].ast(),
|
||||||
ctx=CallerMustSetThis)
|
ctx=CallerMustSetThis)
|
||||||
if self.has_attr():
|
if self.has_attr():
|
||||||
return gast.Attribute(
|
return gast.Attribute(
|
||||||
@ -247,16 +247,16 @@ class QnResolver(gast.NodeTransformer):
|
|||||||
# TODO(mdan): This may no longer apply if we overload getitem.
|
# TODO(mdan): This may no longer apply if we overload getitem.
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
s = node.slice
|
s = node.slice
|
||||||
if not isinstance(s, gast.Index):
|
if isinstance(s, (gast.Tuple, gast.Slice)):
|
||||||
# TODO(mdan): Support range and multi-dimensional indices.
|
# TODO(mdan): Support range and multi-dimensional indices.
|
||||||
# Continuing silently because some demos use these.
|
# Continuing silently because some demos use these.
|
||||||
return node
|
return node
|
||||||
if isinstance(s.value, gast.Constant):
|
if isinstance(s, gast.Constant) and s.value != Ellipsis:
|
||||||
subscript = QN(Literal(s.value.value))
|
subscript = QN(Literal(s.value))
|
||||||
else:
|
else:
|
||||||
# The index may be an expression, case in which a name doesn't make sense.
|
# The index may be an expression, case in which a name doesn't make sense.
|
||||||
if anno.hasanno(node.slice.value, anno.Basic.QN):
|
if anno.hasanno(s, anno.Basic.QN):
|
||||||
subscript = anno.getanno(node.slice.value, anno.Basic.QN)
|
subscript = anno.getanno(s, anno.Basic.QN)
|
||||||
else:
|
else:
|
||||||
return node
|
return node
|
||||||
if anno.hasanno(node.value, anno.Basic.QN):
|
if anno.hasanno(node.value, anno.Basic.QN):
|
||||||
|
@ -66,7 +66,7 @@ class QNTest(test.TestCase):
|
|||||||
self.assertEqual(str(a_sub_b), 'a[b]')
|
self.assertEqual(str(a_sub_b), 'a[b]')
|
||||||
self.assertEqual(a_sub_b.ssf(), 'a_sub_b')
|
self.assertEqual(a_sub_b.ssf(), 'a_sub_b')
|
||||||
self.assertEqual(a_sub_b.ast().value.id, 'a')
|
self.assertEqual(a_sub_b.ast().value.id, 'a')
|
||||||
self.assertEqual(a_sub_b.ast().slice.value.id, 'b')
|
self.assertEqual(a_sub_b.ast().slice.id, 'b')
|
||||||
self.assertTrue(a_sub_b.is_composite())
|
self.assertTrue(a_sub_b.is_composite())
|
||||||
self.assertTrue(a_sub_b.has_subscript())
|
self.assertTrue(a_sub_b.has_subscript())
|
||||||
self.assertEqual(a_sub_b.parent.qn, ('a',))
|
self.assertEqual(a_sub_b.parent.qn, ('a',))
|
||||||
@ -81,9 +81,9 @@ class QNTest(test.TestCase):
|
|||||||
self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
|
self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
|
||||||
self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c')
|
self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c')
|
||||||
self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a')
|
self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a')
|
||||||
self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b')
|
self.assertEqual(a_sub_b_sub_c.ast().slice.value.id, 'b')
|
||||||
self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c')
|
self.assertEqual(a_sub_b_sub_c.ast().slice.slice.id, 'c')
|
||||||
self.assertEqual(b_sub_c.ast().slice.value.id, 'c')
|
self.assertEqual(b_sub_c.ast().slice.id, 'c')
|
||||||
self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',))
|
self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',))
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
QN('a', 'b')
|
QN('a', 'b')
|
||||||
@ -157,12 +157,12 @@ class QNTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertNotEqual(a_sub_str_b, a_sub_b)
|
self.assertNotEqual(a_sub_str_b, a_sub_b)
|
||||||
self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b))
|
self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b))
|
||||||
self.assertEqual(a_sub_str_b.ast().slice.value.value, 'b')
|
self.assertEqual(a_sub_str_b.ast().slice.value, 'b')
|
||||||
self.assertEqual(str(a_sub_str_b), "a['b']")
|
self.assertEqual(str(a_sub_str_b), "a['b']")
|
||||||
|
|
||||||
a_sub_three = QN(a, subscript=QN(qual_names.Literal(3)))
|
a_sub_three = QN(a, subscript=QN(qual_names.Literal(3)))
|
||||||
self.assertEqual(a_sub_three.ast().slice.value.value, 3)
|
self.assertEqual(a_sub_three.ast().slice.value, 3)
|
||||||
self.assertEqual(str(a_sub_three), "a[3]")
|
self.assertEqual(str(a_sub_three), 'a[3]')
|
||||||
|
|
||||||
def test_support_set(self):
|
def test_support_set(self):
|
||||||
a = QN('a')
|
a = QN('a')
|
||||||
|
@ -444,9 +444,6 @@ class StmtInferrer(gast.NodeVisitor):
|
|||||||
def visit_Expr(self, node):
|
def visit_Expr(self, node):
|
||||||
return self.visit(node.value)
|
return self.visit(node.value)
|
||||||
|
|
||||||
def visit_Index(self, node):
|
|
||||||
return self.visit(node.value)
|
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
self.rtype = self.visit(node.value)
|
self.rtype = self.visit(node.value)
|
||||||
|
|
||||||
|
@ -669,7 +669,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertTypes(fn_body[0].value, str)
|
self.assertTypes(fn_body[0].value, str)
|
||||||
self.assertTypes(fn_body[0].value.value, list)
|
self.assertTypes(fn_body[0].value.value, list)
|
||||||
self.assertTypes(fn_body[0].value.slice.value, int)
|
self.assertTypes(fn_body[0].value.slice, int)
|
||||||
|
|
||||||
def test_tuple_unpacking(self):
|
def test_tuple_unpacking(self):
|
||||||
|
|
||||||
|
@ -221,7 +221,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase):
|
|||||||
template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
|
template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
|
||||||
function_call_arg = node.body[0].targets[0].value.args[0]
|
function_call_arg = node.body[0].targets[0].value.args[0]
|
||||||
self.assertIsInstance(function_call_arg.ctx, gast.Load)
|
self.assertIsInstance(function_call_arg.ctx, gast.Load)
|
||||||
self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)
|
self.assertIsInstance(function_call_arg.slice.ctx, gast.Load)
|
||||||
|
|
||||||
def test_replace_call_keyword(self):
|
def test_replace_call_keyword(self):
|
||||||
template = """
|
template = """
|
||||||
|
@ -412,7 +412,7 @@ class Base(NodeStateTracker, gast.NodeTransformer):
|
|||||||
if isinstance(values, (gast.Tuple, gast.List)):
|
if isinstance(values, (gast.Tuple, gast.List)):
|
||||||
value_el = values.elts[i]
|
value_el = values.elts[i]
|
||||||
else:
|
else:
|
||||||
value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
|
value_el = gast.Subscript(values, i, ctx=gast.Store())
|
||||||
self.apply_to_single_assignments(target_el, value_el, apply_fn)
|
self.apply_to_single_assignments(target_el, value_el, apply_fn)
|
||||||
else:
|
else:
|
||||||
# TODO(mdan): Look into allowing to rewrite the AST here.
|
# TODO(mdan): Look into allowing to rewrite the AST here.
|
||||||
|
@ -127,21 +127,20 @@ class _SubscriptUseTracker(transformer.Base):
|
|||||||
|
|
||||||
def visit_Subscript(self, node):
|
def visit_Subscript(self, node):
|
||||||
"""Visits nodes with subscript in the AST."""
|
"""Visits nodes with subscript in the AST."""
|
||||||
|
s = node.slice
|
||||||
if anno.hasanno(node, anno.Basic.QN):
|
if anno.hasanno(node, anno.Basic.QN):
|
||||||
qn = anno.getanno(node, anno.Basic.QN)
|
qn = anno.getanno(node, anno.Basic.QN)
|
||||||
if isinstance(node.ctx, gast.Load):
|
if isinstance(node.ctx, gast.Load):
|
||||||
self.reads.add(qn)
|
self.reads.add(qn)
|
||||||
elif not isinstance(node.slice, gast.Index):
|
elif isinstance(s, (gast.Tuple, gast.Slice)):
|
||||||
if anno.hasanno(node, anno.Basic.QN):
|
if anno.hasanno(node.value, anno.Basic.QN):
|
||||||
self.complex_reads.add(anno.getanno(node, anno.Basic.QN))
|
|
||||||
elif anno.hasanno(node.value, anno.Basic.QN):
|
|
||||||
self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN))
|
self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN))
|
||||||
value_qn = anno.getanno(node.value, anno.Basic.QN, None)
|
value_qn = anno.getanno(node.value, anno.Basic.QN, None)
|
||||||
if value_qn in self.exclude:
|
if value_qn in self.exclude:
|
||||||
node.value = self.generic_visit(node.value)
|
node.value = self.generic_visit(node.value)
|
||||||
else:
|
else:
|
||||||
node.value = self.visit(node.value)
|
node.value = self.visit(node.value)
|
||||||
node.slice = self.visit(node.slice)
|
node.slice = self.visit(s)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
@ -469,7 +469,7 @@ install_tensorflow_pip() {
|
|||||||
|
|
||||||
# Install the gast package in the virtualenv. Installing it in user system
|
# Install the gast package in the virtualenv. Installing it in user system
|
||||||
# packages does not appear to port it over when creating a virtualenv.
|
# packages does not appear to port it over when creating a virtualenv.
|
||||||
${PIP_BIN_PATH} install --upgrade "gast==0.3.3" || \
|
${PIP_BIN_PATH} install --upgrade "gast==0.4.0" || \
|
||||||
die "Error: gast install, upgrade FAILED"
|
die "Error: gast install, upgrade FAILED"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -137,7 +137,7 @@ function install_ubuntu_16_pip_deps {
|
|||||||
"${PIP_CMD}" install --user 'wheel ~= 0.35'
|
"${PIP_CMD}" install --user 'wheel ~= 0.35'
|
||||||
"${PIP_CMD}" install --user 'wrapt ~= 1.12.1'
|
"${PIP_CMD}" install --user 'wrapt ~= 1.12.1'
|
||||||
# We need to pin gast dependency exactly
|
# We need to pin gast dependency exactly
|
||||||
"${PIP_CMD}" install --user 'gast == 0.3.3'
|
"${PIP_CMD}" install --user 'gast == 0.4.0'
|
||||||
# Finally, install tensorboard and estimator
|
# Finally, install tensorboard and estimator
|
||||||
# Note that here we want the latest version that matches (b/156523241)
|
# Note that here we want the latest version that matches (b/156523241)
|
||||||
"${PIP_CMD}" install --user --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a'
|
"${PIP_CMD}" install --user --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a'
|
||||||
@ -199,7 +199,7 @@ function install_macos_pip_deps {
|
|||||||
${PIP_CMD} install $USER_FLAG 'wheel ~= 0.35'
|
${PIP_CMD} install $USER_FLAG 'wheel ~= 0.35'
|
||||||
${PIP_CMD} install $USER_FLAG 'wrapt ~= 1.12.1'
|
${PIP_CMD} install $USER_FLAG 'wrapt ~= 1.12.1'
|
||||||
# We need to pin gast dependency exactly
|
# We need to pin gast dependency exactly
|
||||||
${PIP_CMD} install $USER_FLAG 'gast == 0.3.3'
|
${PIP_CMD} install $USER_FLAG 'gast == 0.4.0'
|
||||||
# Finally, install tensorboard and estimator
|
# Finally, install tensorboard and estimator
|
||||||
# Note that here we want the latest version that matches (b/156523241)
|
# Note that here we want the latest version that matches (b/156523241)
|
||||||
${PIP_CMD} install $USER_FLAG --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a'
|
${PIP_CMD} install $USER_FLAG --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a'
|
||||||
|
@ -43,7 +43,7 @@ SET PATH=%PATH%;C:\%PYTHON_DIRECTORY%
|
|||||||
%PY_EXE% -m pip install "wheel ~= 0.35"
|
%PY_EXE% -m pip install "wheel ~= 0.35"
|
||||||
%PY_EXE% -m pip install "wrapt ~= 1.12.1"
|
%PY_EXE% -m pip install "wrapt ~= 1.12.1"
|
||||||
@REM We need to pin gast dependency exactly
|
@REM We need to pin gast dependency exactly
|
||||||
%PY_EXE% -m pip install "gast == 0.3.3"
|
%PY_EXE% -m pip install "gast == 0.4.0"
|
||||||
@REM Finally, install tensorboard and estimator
|
@REM Finally, install tensorboard and estimator
|
||||||
@REM Note that here we want the latest version that matches (b/156523241)
|
@REM Note that here we want the latest version that matches (b/156523241)
|
||||||
%PY_EXE% -m pip install --upgrade --force-reinstall "tb-nightly ~= 2.4.0.a"
|
%PY_EXE% -m pip install --upgrade --force-reinstall "tb-nightly ~= 2.4.0.a"
|
||||||
|
@ -91,7 +91,7 @@ REQUIRED_PACKAGES = [
|
|||||||
'wrapt ~= 1.12.1',
|
'wrapt ~= 1.12.1',
|
||||||
# These packages need to be pinned exactly as newer versions are
|
# These packages need to be pinned exactly as newer versions are
|
||||||
# incompatible with the rest of the ecosystem
|
# incompatible with the rest of the ecosystem
|
||||||
'gast == 0.3.3',
|
'gast == 0.4.0',
|
||||||
# TensorFlow ecosystem packages that TF exposes API for
|
# TensorFlow ecosystem packages that TF exposes API for
|
||||||
# These need to be in sync with the existing TF version
|
# These need to be in sync with the existing TF version
|
||||||
# They are updated during the release process
|
# They are updated during the release process
|
||||||
|
@ -450,12 +450,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "gast_archive",
|
name = "gast_archive",
|
||||||
build_file = clean_dep("//third_party:gast.BUILD"),
|
build_file = clean_dep("//third_party:gast.BUILD"),
|
||||||
sha256 = "b881ef288a49aa81440d2c5eb8aeefd4c2bb8993d5f50edae7413a85bfdb3b57",
|
sha256 = "40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1",
|
||||||
strip_prefix = "gast-0.3.3",
|
strip_prefix = "gast-0.4.0",
|
||||||
system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
|
system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
|
||||||
urls = [
|
urls = [
|
||||||
"http://mirror.tensorflow.org/files.pythonhosted.org/packages/12/59/eaa15ab9710a20e22225efd042cd2d6a0b559a0656d5baba9641a2a4a921/gast-0.3.3.tar.gz",
|
"http://mirror.tensorflow.org/files.pythonhosted.org/packages/12/59/eaa15ab9710a20e22225efd042cd2d6a0b559a0656d5baba9641a2a4a921/gast-0.4.0.tar.gz",
|
||||||
"https://files.pythonhosted.org/packages/12/59/eaa15ab9710a20e22225efd042cd2d6a0b559a0656d5baba9641a2a4a921/gast-0.3.3.tar.gz",
|
"https://files.pythonhosted.org/packages/83/4a/07c7e59cef23fb147454663c3271c21da68ba2ab141427c20548ae5a8a4d/gast-0.4.0.tar.gz",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user