import re

# A backslash followed only by whitespace up to the terminating newline, i.e.
# an explicit line continuation at the end of a notebook source line.
_LINE_CONTINUATION_RE = re.compile(r"\\\s*\n$")


def skip_magic(code_line, magic_list):
  """Checks whether a line is jupyter "magic" rather than Python code.

  Args:
    code_line: A line of notebook cell source.
    magic_list: An iterable of prefixes that mark jupyter "magic" lines,
      e.g. `['%', '!', '?']`.

  Returns:
    True if `code_line` starts with any prefix in `magic_list`, else False.

  >>> skip_magic('!ls -laF', ['%', '!', '?'])
  True
  >>> skip_magic('import os', ['%', '!', '?'])
  False
  """
  # `str.startswith` accepts a tuple of prefixes; an empty tuple yields False,
  # matching the behavior of looping over an empty list.
  return code_line.startswith(tuple(magic_list))


def check_line_split(code_line):
  r"""Checks if a line was split with `\`.

  A line counts as split when a backslash is followed only by whitespace and
  the line ends with a newline.

  Args:
    code_line: A line of notebook cell source.

  Returns:
    True if the line was split with `\`, else False.

  >>> check_line_split("!gcloud ml-engine models create ${MODEL} \\\n")
  True
  >>> check_line_split("x = 1\n")
  False
  """
  # Return a real bool (not a match object) so the result matches the
  # documented contract; truthiness for existing callers is unchanged.
  return _LINE_CONTINUATION_RE.search(code_line) is not None