STT-tensorflow/tensorflow/python/autograph/pyct/loader_test.py
A. Unique TensorFlower 23d51b67d0 For long (multi-day) training runs, temporary files created at the
start of the run might be cleaned up by the OS before the process
finishes, so the atexit handlers must not assume that those files still exist.

PiperOrigin-RevId: 306364079
Change-Id: Ic456f05027d45d6b7699de2e1028da1e31dc0f8a
2020-04-13 20:17:39 -07:00

118 lines
3.5 KiB
Python

# coding=utf-8
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for loader module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import textwrap
import gast
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
class LoaderTest(test.TestCase):
  """Tests for `loader.load_ast` and `loader.load_source`."""

  def test_parse_load_identity(self):
    """Checks that parsing and re-loading a function round-trips its source.

    `test_fn` is written exactly the way the unparser emits code (for example
    with a parenthesized binary operation), and it contains no comments or
    blank lines. This is so the regenerated source is textually identical to
    the original once indentation widths are reconciled.
    """

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = (x + 1)
      return b

    node, _ = parser.parse_entity(test_fn, future_features=())
    module, _, _ = loader.load_ast(node)

    # astunparse uses fixed 4-space indenting, while this file uses 2-space
    # indenting; map each 4-space run back to 2 spaces before comparing.
    self.assertEqual(
        textwrap.dedent(tf_inspect.getsource(test_fn)),
        tf_inspect.getsource(module.test_fn).replace('    ', '  '))

  def test_load_ast(self):
    """Checks that a hand-built gast AST is unparsed, written out and loaded."""
    node = gast.FunctionDef(
        name='f',
        args=gast.arguments(
            args=[
                gast.Name(
                    'a', ctx=gast.Param(), annotation=None, type_comment=None)
            ],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[]),
        body=[
            gast.Return(
                gast.BinOp(
                    op=gast.Add(),
                    left=gast.Name(
                        'a',
                        ctx=gast.Load(),
                        annotation=None,
                        type_comment=None),
                    right=gast.Constant(1, kind=None)))
        ],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    # Computed once and shared by both comparisons below. The body line keeps
    # the 4-space indent that the unparser emits.
    expected_source = textwrap.dedent("""
      # coding=utf-8
      def f(a):
          return (a + 1)
    """).strip()

    # The returned source text, the loaded module's behavior and the file
    # backing the module must all agree.
    self.assertEqual(expected_source, source.strip())
    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected_source, temp_output.read().strip())

  def test_load_source(self):
    """Checks that non-ASCII source text survives writing and loading."""
    test_source = textwrap.dedent(u"""
      # coding=utf-8
      def f(a):
        '日本語 Δθₜ ← Δθₜ₋₁ + ∇Q(sₜ, aₜ)(rₜ + γₜ₊₁ max Q(⋅))'
        return a + 1
    """)
    module, _ = loader.load_source(test_source, delete_on_exit=True)
    self.assertEqual(module.f(1), 2)
    self.assertEqual(
        module.f.__doc__, '日本語 Δθₜ ← Δθₜ₋₁ + ∇Q(sₜ, aₜ)(rₜ + γₜ₊₁ max Q(⋅))')

  def test_cleanup(self):
    """Checks that a temp file removed early does not break exit cleanup."""
    _, filename = loader.load_source('', delete_on_exit=True)
    # Fail with a clear assertion, rather than an OSError from unlink below,
    # if load_source did not leave the file in place.
    self.assertTrue(os.path.exists(filename))
    # Clean up the file before loader.py tries to remove it, to check that the
    # latter can deal with that situation. With delete_on_exit=True that
    # removal is deferred until exit (presumably via an atexit handler), so a
    # regression would surface at interpreter shutdown rather than here.
    os.unlink(filename)
if __name__ == '__main__':
  # Executed only when this file is run directly; hands control to the
  # TensorFlow test entry point.
  test.main()