Fix for Python 4: replace unsafe six.PY3 with PY2
commit 9db7340145
parent dc4278b527
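Rationale, with a minimal illustrative sketch (the read_text helpers below are hypothetical and not TensorFlow code): six defines its version flags as exact major-version checks, so six.PY3 is True only on Python 3. On a hypothetical Python 4 it would be False, and any code gated on it would silently fall back to its Python 2 branch. Gating on six.PY2 instead treats Python 2 as the one legacy case and every later version as modern Python, which is the inversion this commit applies throughout.

import sys

# This is effectively how six computes its flags:
PY2 = sys.version_info[0] == 2   # True only on Python 2
PY3 = sys.version_info[0] == 3   # True only on Python 3 -- False on a future Python 4


def read_text_unsafe(data):
  # Gated on PY3: a hypothetical Python 4 would take the legacy branch.
  if PY3:
    return data.decode("utf-8") if isinstance(data, bytes) else data
  return data


def read_text_future_proof(data):
  # Gated on PY2: only genuine Python 2 takes the legacy branch.
  if PY2:
    return data
  return data.decode("utf-8") if isinstance(data, bytes) else data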
@@ -24,7 +24,7 @@ import warnings
 from absl import logging
 import six
-from six import PY3
+from six import PY2

 from google.protobuf import text_format as _text_format
 from google.protobuf.message import DecodeError
@@ -727,10 +727,10 @@ class TFLiteConverter(TFLiteConverterBase):
           print("Ignore 'tcmalloc: large alloc' warnings.")

           if not isinstance(file_content, str):
-            if PY3:
-              file_content = six.ensure_text(file_content, "utf-8")
-            else:
+            if PY2:
               file_content = six.ensure_binary(file_content, "utf-8")
+            else:
+              file_content = six.ensure_text(file_content, "utf-8")
           graph_def = _graph_pb2.GraphDef()
           _text_format.Merge(file_content, graph_def)
         except (_text_format.ParseError, DecodeError):
@@ -21,7 +21,7 @@ from __future__ import print_function
 import os

 import numpy as np
-from six import PY3
+from six import PY2

 from google.protobuf import text_format as _text_format
 from google.protobuf.message import DecodeError
@@ -209,10 +209,10 @@ def evaluate_frozen_graph(filename, input_arrays, output_arrays):
     graph_def.ParseFromString(file_content)
   except (_text_format.ParseError, DecodeError):
     if not isinstance(file_content, str):
-      if PY3:
-        file_content = file_content.decode("utf-8")
-      else:
+      if PY2:
         file_content = file_content.encode("utf-8")
+      else:
+        file_content = file_content.decode("utf-8")
     _text_format.Merge(file_content, graph_def)

   graph = ops.Graph()
@@ -539,7 +539,7 @@ def converted_call(f,
     if logging.has_verbosity(2):
       logging.log(2, 'Defaults of %s : %s', converted_f,
                   converted_f.__defaults__)
-      if six.PY3:
+      if not six.PY2:
         logging.log(2, 'KW defaults of %s : %s',
                     converted_f, converted_f.__kwdefaults__)

@@ -303,7 +303,7 @@ def _tf_py_func_print(objects, kwargs):

   def print_wrapper(*vals):
     vals = tuple(v.numpy() if tensor_util.is_tensor(v) else v for v in vals)
-    if six.PY3:
+    if not six.PY2:
       # TensorFlow doesn't seem to generate Unicode when passing strings to
       # py_func. This causes the print to add a "b'" wrapper to the output,
       # which is probably never what you want.
@@ -2920,8 +2920,8 @@ class TensorFlowTestCase(googletest.TestCase):
     else:
       self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)

-  # Fix Python 3 compatibility issues
-  if six.PY3:
+  # Fix Python 3+ compatibility issues
+  if not six.PY2:
     # pylint: disable=invalid-name

     # Silence a deprecation warning
@@ -283,15 +283,15 @@ def get_file(fname,


 def _makedirs_exist_ok(datadir):
-  if six.PY3:
-    os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg
-  else:
+  if six.PY2:
     # Python 2 doesn't have the exist_ok arg, so we try-except here.
     try:
       os.makedirs(datadir)
     except OSError as e:
       if e.errno != errno.EEXIST:
         raise
+  else:
+    os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


 def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
@@ -1512,7 +1512,8 @@ def _range_tensor_conversion_function(value, dtype=None, name=None,
   del as_ref
   return range(value.start, value.stop, value.step, dtype=dtype, name=name)

-if six.PY3:
+if not six.PY2:
   ops.register_tensor_conversion_function(builtins.range,
                                           _range_tensor_conversion_function)

@@ -721,8 +721,8 @@ def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):


 # Cache the possibly expensive opt_einsum.contract_path call using lru_cache
-# from the Python3 standard library.
-if six.PY3:
+# from the Python3+ standard library.
+if not six.PY2:
   _get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
       _get_opt_einsum_contract_path)

@@ -436,7 +436,7 @@ class EinsumTest(test.TestCase):
       # with the same input args (as input_1 and input_2 above), and if
       # those tests run before this test, then the call_count for the method
       # mock_contract_path will not increment.
-      if six.PY3:
+      if not six.PY2:
         special_math_ops._get_opt_einsum_contract_path.cache_clear()

       self.assertEqual(mock_contract_path.call_count, 0)
@@ -445,15 +445,15 @@ class EinsumTest(test.TestCase):
      # The same input results in no extra call if we're caching the
      # opt_einsum.contract_path call. We only cache in Python3.
      self._check(*input_1)
-      self.assertEqual(mock_contract_path.call_count, 1 if six.PY3 else 2)
+      self.assertEqual(mock_contract_path.call_count, 2 if six.PY2 else 1)
      # New input results in another call to opt_einsum.
      self._check(*input_2)
-      self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 3)
+      self.assertEqual(mock_contract_path.call_count, 3 if six.PY2 else 2)
      # No more extra calls as the inputs should be cached.
      self._check(*input_1)
      self._check(*input_2)
      self._check(*input_1)
-      self.assertEqual(mock_contract_path.call_count, 2 if six.PY3 else 6)
+      self.assertEqual(mock_contract_path.call_count, 6 if six.PY2 else 2)

   @test_util.disable_xla('b/131919749')
   def test_long_cases_with_repeated_labels(self):
@@ -33,11 +33,11 @@ from tensorflow.python import _tf_stack
 # when a thread is joined, so reusing the key does not introduce a correctness
 # issue. Moreover, get_ident is faster than storing and retrieving a unique
 # key in a thread local store.
-if six.PY3:
-  _get_thread_key = threading.get_ident
-else:
+if six.PY2:
   import thread  # pylint: disable=g-import-not-at-top
   _get_thread_key = thread.get_ident
+else:
+  _get_thread_key = threading.get_ident


 _source_mapper_stacks = collections.defaultdict(list)
@@ -57,7 +57,7 @@ OLD_DIVISION = [
 def check_file(path, old_division):
   futures = set()
   count = 0
-  for line in open(path, encoding='utf-8') if six.PY3 else open(path):
+  for line in open(path) if six.PY2 else open(path, encoding='utf-8'):
     count += 1
     m = FUTURES_PATTERN.match(line)
     if not m: