Fix for Python 4: replace unsafe six.PY3 with PY2

This commit is contained in:
Hugo 2020-01-15 17:26:38 +02:00
parent dc4278b527
commit 9db7340145
11 changed files with 27 additions and 26 deletions

View File

@ -24,7 +24,7 @@ import warnings
from absl import logging
import six
from six import PY3
from six import PY2
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
@ -727,10 +727,10 @@ class TFLiteConverter(TFLiteConverterBase):
print("Ignore 'tcmalloc: large alloc' warnings.")
if not isinstance(file_content, str):
if PY3:
file_content = six.ensure_text(file_content, "utf-8")
else:
if PY2:
file_content = six.ensure_binary(file_content, "utf-8")
else:
file_content = six.ensure_text(file_content, "utf-8")
graph_def = _graph_pb2.GraphDef()
_text_format.Merge(file_content, graph_def)
except (_text_format.ParseError, DecodeError):

View File

@ -21,7 +21,7 @@ from __future__ import print_function
import os
import numpy as np
from six import PY3
from six import PY2
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
@ -209,10 +209,10 @@ def evaluate_frozen_graph(filename, input_arrays, output_arrays):
graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
if not isinstance(file_content, str):
if PY3:
file_content = file_content.decode("utf-8")
else:
if PY2:
file_content = file_content.encode("utf-8")
else:
file_content = file_content.decode("utf-8")
_text_format.Merge(file_content, graph_def)
graph = ops.Graph()

View File

@ -539,7 +539,7 @@ def converted_call(f,
if logging.has_verbosity(2):
logging.log(2, 'Defaults of %s : %s', converted_f,
converted_f.__defaults__)
if six.PY3:
if not six.PY2:
logging.log(2, 'KW defaults of %s : %s',
converted_f, converted_f.__kwdefaults__)

View File

@ -303,7 +303,7 @@ def _tf_py_func_print(objects, kwargs):
def print_wrapper(*vals):
vals = tuple(v.numpy() if tensor_util.is_tensor(v) else v for v in vals)
if six.PY3:
if not six.PY2:
# TensorFlow doesn't seem to generate Unicode when passing strings to
# py_func. This causes the print to add a "b'" wrapper to the output,
# which is probably never what you want.

View File

@ -2920,8 +2920,8 @@ class TensorFlowTestCase(googletest.TestCase):
else:
self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
# Fix Python 3 compatibility issues
if six.PY3:
# Fix Python 3+ compatibility issues
if not six.PY2:
# pylint: disable=invalid-name
# Silence a deprecation warning

View File

@ -283,15 +283,15 @@ def get_file(fname,
def _makedirs_exist_ok(datadir):
if six.PY3:
os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg
else:
if six.PY2:
# Python 2 doesn't have the exist_ok arg, so we try-except here.
try:
os.makedirs(datadir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
else:
os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):

View File

@ -1512,7 +1512,8 @@ def _range_tensor_conversion_function(value, dtype=None, name=None,
del as_ref
return range(value.start, value.stop, value.step, dtype=dtype, name=name)
if six.PY3:
if not six.PY2:
ops.register_tensor_conversion_function(builtins.range,
_range_tensor_conversion_function)

View File

@ -721,8 +721,8 @@ def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
# Cache the possibly expensive opt_einsum.contract_path call using lru_cache
# from the Python3 standard library.
if six.PY3:
# from the Python3+ standard library.
if not six.PY2:
_get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
_get_opt_einsum_contract_path)

View File

@ -436,7 +436,7 @@ class EinsumTest(test.TestCase):
# with the same input args (as input_1 and input_2 above), and if
# those tests run before this test, then the call_count for the method
# mock_contract_path will not increment.
if six.PY3:
if not six.PY2:
special_math_ops._get_opt_einsum_contract_path.cache_clear()
self.assertEqual(mock_contract_path.call_count, 0)
@ -445,15 +445,15 @@ class EinsumTest(test.TestCase):
# The same input results in no extra call if we're caching the
# opt_einsum.contract_path call. We only cache in Python3.
self._check(*input_1)
self.assertEqual(mock_contract_path.call_count, 1 if six.PY3 else 2)
self.assertEqual(mock_contract_path.call_count, 2 if six.PY2 else 1)
# New input results in another call to opt_einsum.
self._check(*input_2)
self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 3)
self.assertEqual(mock_contract_path.call_count, 3 if six.PY2 else 2)
# No more extra calls as the inputs should be cached.
self._check(*input_1)
self._check(*input_2)
self._check(*input_1)
self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 6)
self.assertEqual(mock_contract_path.call_count, 6 if six.PY2 else 2)
@test_util.disable_xla('b/131919749')
def test_long_cases_with_repeated_labels(self):

View File

@ -33,11 +33,11 @@ from tensorflow.python import _tf_stack
# when a thread is joined, so reusing the key does not introduce a correctness
# issue. Moreover, get_ident is faster than storing and retrieving a unique
# key in a thread local store.
if six.PY3:
_get_thread_key = threading.get_ident
else:
if six.PY2:
import thread # pylint: disable=g-import-not-at-top
_get_thread_key = thread.get_ident
else:
_get_thread_key = threading.get_ident
_source_mapper_stacks = collections.defaultdict(list)

View File

@ -57,7 +57,7 @@ OLD_DIVISION = [
def check_file(path, old_division):
futures = set()
count = 0
for line in open(path, encoding='utf-8') if six.PY3 else open(path):
for line in open(path) if six.PY2 else open(path, encoding='utf-8'):
count += 1
m = FUTURES_PATTERN.match(line)
if not m: