Fix for Python 4: replace unsafe six.PY3 with PY2
This commit is contained in:
parent
dc4278b527
commit
9db7340145
@ -24,7 +24,7 @@ import warnings
|
||||
|
||||
from absl import logging
|
||||
import six
|
||||
from six import PY3
|
||||
from six import PY2
|
||||
|
||||
from google.protobuf import text_format as _text_format
|
||||
from google.protobuf.message import DecodeError
|
||||
@ -727,10 +727,10 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
print("Ignore 'tcmalloc: large alloc' warnings.")
|
||||
|
||||
if not isinstance(file_content, str):
|
||||
if PY3:
|
||||
file_content = six.ensure_text(file_content, "utf-8")
|
||||
else:
|
||||
if PY2:
|
||||
file_content = six.ensure_binary(file_content, "utf-8")
|
||||
else:
|
||||
file_content = six.ensure_text(file_content, "utf-8")
|
||||
graph_def = _graph_pb2.GraphDef()
|
||||
_text_format.Merge(file_content, graph_def)
|
||||
except (_text_format.ParseError, DecodeError):
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from six import PY3
|
||||
from six import PY2
|
||||
|
||||
from google.protobuf import text_format as _text_format
|
||||
from google.protobuf.message import DecodeError
|
||||
@ -209,10 +209,10 @@ def evaluate_frozen_graph(filename, input_arrays, output_arrays):
|
||||
graph_def.ParseFromString(file_content)
|
||||
except (_text_format.ParseError, DecodeError):
|
||||
if not isinstance(file_content, str):
|
||||
if PY3:
|
||||
file_content = file_content.decode("utf-8")
|
||||
else:
|
||||
if PY2:
|
||||
file_content = file_content.encode("utf-8")
|
||||
else:
|
||||
file_content = file_content.decode("utf-8")
|
||||
_text_format.Merge(file_content, graph_def)
|
||||
|
||||
graph = ops.Graph()
|
||||
|
@ -539,7 +539,7 @@ def converted_call(f,
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Defaults of %s : %s', converted_f,
|
||||
converted_f.__defaults__)
|
||||
if six.PY3:
|
||||
if not six.PY2:
|
||||
logging.log(2, 'KW defaults of %s : %s',
|
||||
converted_f, converted_f.__kwdefaults__)
|
||||
|
||||
|
@ -303,7 +303,7 @@ def _tf_py_func_print(objects, kwargs):
|
||||
|
||||
def print_wrapper(*vals):
|
||||
vals = tuple(v.numpy() if tensor_util.is_tensor(v) else v for v in vals)
|
||||
if six.PY3:
|
||||
if not six.PY2:
|
||||
# TensorFlow doesn't seem to generate Unicode when passing strings to
|
||||
# py_func. This causes the print to add a "b'" wrapper to the output,
|
||||
# which is probably never what you want.
|
||||
|
@ -2920,8 +2920,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
else:
|
||||
self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
|
||||
|
||||
# Fix Python 3 compatibility issues
|
||||
if six.PY3:
|
||||
# Fix Python 3+ compatibility issues
|
||||
if not six.PY2:
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# Silence a deprecation warning
|
||||
|
@ -283,15 +283,15 @@ def get_file(fname,
|
||||
|
||||
|
||||
def _makedirs_exist_ok(datadir):
|
||||
if six.PY3:
|
||||
os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg
|
||||
else:
|
||||
if six.PY2:
|
||||
# Python 2 doesn't have the exist_ok arg, so we try-except here.
|
||||
try:
|
||||
os.makedirs(datadir)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
else:
|
||||
os.makedirs(datadir, exist_ok=True) # pylint: disable=unexpected-keyword-arg
|
||||
|
||||
|
||||
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
|
||||
|
@ -1512,7 +1512,8 @@ def _range_tensor_conversion_function(value, dtype=None, name=None,
|
||||
del as_ref
|
||||
return range(value.start, value.stop, value.step, dtype=dtype, name=name)
|
||||
|
||||
if six.PY3:
|
||||
|
||||
if not six.PY2:
|
||||
ops.register_tensor_conversion_function(builtins.range,
|
||||
_range_tensor_conversion_function)
|
||||
|
||||
|
@ -721,8 +721,8 @@ def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
|
||||
|
||||
|
||||
# Cache the possibly expensive opt_einsum.contract_path call using lru_cache
|
||||
# from the Python3 standard library.
|
||||
if six.PY3:
|
||||
# from the Python3+ standard library.
|
||||
if not six.PY2:
|
||||
_get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
|
||||
_get_opt_einsum_contract_path)
|
||||
|
||||
|
@ -436,7 +436,7 @@ class EinsumTest(test.TestCase):
|
||||
# with the same input args (as input_1 and input_2 above), and if
|
||||
# those tests run before this test, then the call_count for the method
|
||||
# mock_contract_path will not increment.
|
||||
if six.PY3:
|
||||
if not six.PY2:
|
||||
special_math_ops._get_opt_einsum_contract_path.cache_clear()
|
||||
|
||||
self.assertEqual(mock_contract_path.call_count, 0)
|
||||
@ -445,15 +445,15 @@ class EinsumTest(test.TestCase):
|
||||
# The same input results in no extra call if we're caching the
|
||||
# opt_einsum.contract_path call. We only cache in Python3.
|
||||
self._check(*input_1)
|
||||
self.assertEqual(mock_contract_path.call_count, 1 if six.PY3 else 2)
|
||||
self.assertEqual(mock_contract_path.call_count, 2 if six.PY2 else 1)
|
||||
# New input results in another call to opt_einsum.
|
||||
self._check(*input_2)
|
||||
self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 3)
|
||||
self.assertEqual(mock_contract_path.call_count, 3 if six.PY2 else 2)
|
||||
# No more extra calls as the inputs should be cached.
|
||||
self._check(*input_1)
|
||||
self._check(*input_2)
|
||||
self._check(*input_1)
|
||||
self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 6)
|
||||
self.assertEqual(mock_contract_path.call_count, 6 if six.PY2 else 2)
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_long_cases_with_repeated_labels(self):
|
||||
|
@ -33,11 +33,11 @@ from tensorflow.python import _tf_stack
|
||||
# when a thread is joined, so reusing the key does not introduce a correctness
|
||||
# issue. Moreover, get_ident is faster than storing and retrieving a unique
|
||||
# key in a thread local store.
|
||||
if six.PY3:
|
||||
_get_thread_key = threading.get_ident
|
||||
else:
|
||||
if six.PY2:
|
||||
import thread # pylint: disable=g-import-not-at-top
|
||||
_get_thread_key = thread.get_ident
|
||||
else:
|
||||
_get_thread_key = threading.get_ident
|
||||
|
||||
|
||||
_source_mapper_stacks = collections.defaultdict(list)
|
||||
|
@ -57,7 +57,7 @@ OLD_DIVISION = [
|
||||
def check_file(path, old_division):
|
||||
futures = set()
|
||||
count = 0
|
||||
for line in open(path, encoding='utf-8') if six.PY3 else open(path):
|
||||
for line in open(path) if six.PY2 else open(path, encoding='utf-8'):
|
||||
count += 1
|
||||
m = FUTURES_PATTERN.match(line)
|
||||
if not m:
|
||||
|
Loading…
Reference in New Issue
Block a user