diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py index e80bdc47b82..70ed82dd009 100644 --- a/tensorflow/tools/compatibility/ast_edits.py +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -1032,10 +1032,25 @@ class ASTCodeUpgrader(object): output_directory = os.path.dirname(output_path) if not os.path.isdir(output_directory): os.makedirs(output_directory) + + if os.path.islink(input_path): + link_target = os.readlink(input_path) + link_target_output = os.path.join( + output_root_directory, os.path.relpath(link_target, root_directory)) + if (link_target, link_target_output) in files_to_process: + # Create a link to the new location of the target file + os.symlink(link_target_output, output_path) + else: + report += "Copying symlink %s without modifying its target %s" % ( + input_path, link_target) + os.symlink(link_target, output_path) + continue + file_count += 1 _, l_report, l_errors = self.process_file(input_path, output_path) tree_errors[input_path] = l_errors report += l_report + for input_path, output_path in files_to_copy: output_directory = os.path.dirname(output_path) if not os.path.isdir(output_directory): @@ -1059,6 +1074,9 @@ class ASTCodeUpgrader(object): report += ("=" * 80) + "\n" for path in files_to_process: + if os.path.islink(path): + report += "Skipping symlink %s.\n" % path + continue file_count += 1 _, l_report, l_errors = self.process_file(path, path) tree_errors[path] = l_errors diff --git a/tensorflow/tools/compatibility/ast_edits_test.py b/tensorflow/tools/compatibility/ast_edits_test.py index 0bc87d17d53..d6a366d7220 100644 --- a/tensorflow/tools/compatibility/ast_edits_test.py +++ b/tensorflow/tools/compatibility/ast_edits_test.py @@ -45,6 +45,7 @@ from __future__ import division from __future__ import print_function import ast +import os import six from tensorflow.python.framework import test_util @@ -605,6 +606,89 @@ def t(): _, new_text = self._upgrade(RenameImports(), text) self.assertEqual(expected_text, new_text) + def testUpgradeInplaceWithSymlink(self): + upgrade_dir = os.path.join(self.get_temp_dir(), "foo") + os.mkdir(upgrade_dir) + file_a = os.path.join(upgrade_dir, "a.py") + file_b = os.path.join(upgrade_dir, "b.py") + + with open(file_a, "a") as f: + f.write("import foo as f") + os.symlink(file_a, file_b) + + upgrader = ast_edits.ASTCodeUpgrader(RenameImports()) + upgrader.process_tree_inplace(upgrade_dir) + + self.assertTrue(os.path.islink(file_b)) + self.assertEqual(file_a, os.readlink(file_b)) + with open(file_a, "r") as f: + self.assertEqual("import bar as f", f.read()) + + def testUpgradeInPlaceWithSymlinkInDifferentDir(self): + upgrade_dir = os.path.join(self.get_temp_dir(), "foo") + other_dir = os.path.join(self.get_temp_dir(), "bar") + os.mkdir(upgrade_dir) + os.mkdir(other_dir) + file_c = os.path.join(other_dir, "c.py") + file_d = os.path.join(upgrade_dir, "d.py") + + with open(file_c, "a") as f: + f.write("import foo as f") + os.symlink(file_c, file_d) + + upgrader = ast_edits.ASTCodeUpgrader(RenameImports()) + upgrader.process_tree_inplace(upgrade_dir) + + self.assertTrue(os.path.islink(file_d)) + self.assertEqual(file_c, os.readlink(file_d)) + # File pointed to by symlink is in a different directory. + # Therefore, it should not be upgraded. + with open(file_c, "r") as f: + self.assertEqual("import foo as f", f.read()) + + def testUpgradeCopyWithSymlink(self): + upgrade_dir = os.path.join(self.get_temp_dir(), "foo") + output_dir = os.path.join(self.get_temp_dir(), "bar") + os.mkdir(upgrade_dir) + file_a = os.path.join(upgrade_dir, "a.py") + file_b = os.path.join(upgrade_dir, "b.py") + + with open(file_a, "a") as f: + f.write("import foo as f") + os.symlink(file_a, file_b) + + upgrader = ast_edits.ASTCodeUpgrader(RenameImports()) + upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True) + + new_file_a = os.path.join(output_dir, "a.py") + new_file_b = os.path.join(output_dir, "b.py") + self.assertTrue(os.path.islink(new_file_b)) + self.assertEqual(new_file_a, os.readlink(new_file_b)) + with open(new_file_a, "r") as f: + self.assertEqual("import bar as f", f.read()) + + def testUpgradeCopyWithSymlinkInDifferentDir(self): + upgrade_dir = os.path.join(self.get_temp_dir(), "foo") + other_dir = os.path.join(self.get_temp_dir(), "bar") + output_dir = os.path.join(self.get_temp_dir(), "baz") + os.mkdir(upgrade_dir) + os.mkdir(other_dir) + file_a = os.path.join(other_dir, "a.py") + file_b = os.path.join(upgrade_dir, "b.py") + + with open(file_a, "a") as f: + f.write("import foo as f") + os.symlink(file_a, file_b) + + upgrader = ast_edits.ASTCodeUpgrader(RenameImports()) + upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True) + + new_file_b = os.path.join(output_dir, "b.py") + self.assertTrue(os.path.islink(new_file_b)) + self.assertEqual(file_a, os.readlink(new_file_b)) + with open(file_a, "r") as f: + self.assertEqual("import foo as f", f.read()) + if __name__ == "__main__": test_lib.main()