From 9db73401456bde08e033cef0f97c818bdcc2ace0 Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 15 Jan 2020 17:26:38 +0200 Subject: [PATCH] Fix for Python 4: replace unsafe six.PY3 with PY2 --- tensorflow/lite/python/lite.py | 8 ++++---- .../lite/testing/model_coverage/model_coverage_lib.py | 8 ++++---- tensorflow/python/autograph/impl/api.py | 2 +- tensorflow/python/autograph/operators/py_builtins.py | 2 +- tensorflow/python/framework/test_util.py | 4 ++-- tensorflow/python/keras/utils/data_utils.py | 6 +++--- tensorflow/python/ops/math_ops.py | 3 ++- tensorflow/python/ops/special_math_ops.py | 4 ++-- tensorflow/python/ops/special_math_ops_test.py | 8 ++++---- tensorflow/python/util/tf_stack.py | 6 +++--- tensorflow/tools/test/check_futures_test.py | 2 +- 11 files changed, 27 insertions(+), 26 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 83e97f156eb..61baea19935 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -24,7 +24,7 @@ import warnings from absl import logging import six -from six import PY3 +from six import PY2 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError @@ -727,10 +727,10 @@ class TFLiteConverter(TFLiteConverterBase): print("Ignore 'tcmalloc: large alloc' warnings.") if not isinstance(file_content, str): - if PY3: - file_content = six.ensure_text(file_content, "utf-8") - else: + if PY2: file_content = six.ensure_binary(file_content, "utf-8") + else: + file_content = six.ensure_text(file_content, "utf-8") graph_def = _graph_pb2.GraphDef() _text_format.Merge(file_content, graph_def) except (_text_format.ParseError, DecodeError): diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py index 30d102c4fd9..aa448af77a0 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -21,7 +21,7 @@ from __future__ import print_function import os import numpy as np -from six import PY3 +from six import PY2 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError @@ -209,10 +209,10 @@ def evaluate_frozen_graph(filename, input_arrays, output_arrays): graph_def.ParseFromString(file_content) except (_text_format.ParseError, DecodeError): if not isinstance(file_content, str): - if PY3: - file_content = file_content.decode("utf-8") - else: + if PY2: file_content = file_content.encode("utf-8") + else: + file_content = file_content.decode("utf-8") _text_format.Merge(file_content, graph_def) graph = ops.Graph() diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 9e976b3a9ca..c65a3931da2 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -539,7 +539,7 @@ def converted_call(f, if logging.has_verbosity(2): logging.log(2, 'Defaults of %s : %s', converted_f, converted_f.__defaults__) - if six.PY3: + if not six.PY2: logging.log(2, 'KW defaults of %s : %s', converted_f, converted_f.__kwdefaults__) diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py index 7df4781524f..20565f28277 100644 --- a/tensorflow/python/autograph/operators/py_builtins.py +++ b/tensorflow/python/autograph/operators/py_builtins.py @@ -303,7 +303,7 @@ def _tf_py_func_print(objects, kwargs): def print_wrapper(*vals): vals = tuple(v.numpy() if tensor_util.is_tensor(v) else v for v in vals) - if six.PY3: + if not six.PY2: # TensorFlow doesn't seem to generate Unicode when passing strings to # py_func. This causes the print to add a "b'" wrapper to the output, # which is probably never what you want. diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 8c560e4aa8c..b45e206f9bf 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -2920,8 +2920,8 @@ class TensorFlowTestCase(googletest.TestCase): else: self._assertAllCloseRecursive(a, b, rtol, atol, path, msg) - # Fix Python 3 compatibility issues - if six.PY3: + # Fix Python 3+ compatibility issues + if not six.PY2: # pylint: disable=invalid-name # Silence a deprecation warning diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index b3494af9439..5224356e877 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -283,15 +283,15 @@ def get_file(fname, def _makedirs_exist_ok(datadir): - if six.PY3: - os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg - else: + if six.PY2: # Python 2 doesn't have the exist_ok arg, so we try-except here. try: os.makedirs(datadir) except OSError as e: if e.errno != errno.EEXIST: raise + else: + os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg def _hash_file(fpath, algorithm='sha256', chunk_size=65535): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 360bf2b91dd..e2d824e3446 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1512,7 +1512,8 @@ def _range_tensor_conversion_function(value, dtype=None, name=None, del as_ref return range(value.start, value.stop, value.step, dtype=dtype, name=name) -if six.PY3: + +if not six.PY2: ops.register_tensor_conversion_function(builtins.range, _range_tensor_conversion_function) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 686a6300bf6..6741699ed12 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -721,8 +721,8 @@ def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize): # Cache the possibly expensive opt_einsum.contract_path call using lru_cache -# from the Python3 standard library. -if six.PY3: +# from the Python3+ standard library. +if not six.PY2: _get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)( _get_opt_einsum_contract_path) diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 77136adc5b4..320c5a1f6f1 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -436,7 +436,7 @@ class EinsumTest(test.TestCase): # with the same input args (as input_1 and input_2 above), and if # those tests run before this test, then the call_count for the method # mock_contract_path will not increment. - if six.PY3: + if not six.PY2: special_math_ops._get_opt_einsum_contract_path.cache_clear() self.assertEqual(mock_contract_path.call_count, 0) @@ -445,15 +445,15 @@ class EinsumTest(test.TestCase): # The same input results in no extra call if we're caching the # opt_einsum.contract_path call. We only cache in Python3. self._check(*input_1) - self.assertEqual(mock_contract_path.call_count, 1 if six.PY3 else 2) + self.assertEqual(mock_contract_path.call_count, 2 if six.PY2 else 1) # New input results in another call to opt_einsum. self._check(*input_2) - self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 3) + self.assertEqual(mock_contract_path.call_count, 3 if six.PY2 else 2) # No more extra calls as the inputs should be cached. self._check(*input_1) self._check(*input_2) self._check(*input_1) - self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 6) + self.assertEqual(mock_contract_path.call_count, 6 if six.PY2 else 2) @test_util.disable_xla('b/131919749') def test_long_cases_with_repeated_labels(self): diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py index 0dfc03e37ce..628cd4e1854 100644 --- a/tensorflow/python/util/tf_stack.py +++ b/tensorflow/python/util/tf_stack.py @@ -33,11 +33,11 @@ from tensorflow.python import _tf_stack # when a thread is joined, so reusing the key does not introduce a correctness # issue. Moreover, get_ident is faster than storing and retrieving a unique # key in a thread local store. -if six.PY3: - _get_thread_key = threading.get_ident -else: +if six.PY2: import thread # pylint: disable=g-import-not-at-top _get_thread_key = thread.get_ident +else: + _get_thread_key = threading.get_ident _source_mapper_stacks = collections.defaultdict(list) diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py index a883ce221fc..353fb694bc8 100644 --- a/tensorflow/tools/test/check_futures_test.py +++ b/tensorflow/tools/test/check_futures_test.py @@ -57,7 +57,7 @@ OLD_DIVISION = [ def check_file(path, old_division): futures = set() count = 0 - for line in open(path, encoding='utf-8') if six.PY3 else open(path): + for line in open(path) if six.PY2 else open(path, encoding='utf-8'): count += 1 m = FUTURES_PATTERN.match(line) if not m: