ICM PY3 Migration - //tensorflow/lite [1]

PiperOrigin-RevId: 274090960
This commit is contained in:
Hye Soo Yang 2019-10-10 19:49:22 -07:00 committed by TensorFlower Gardener
parent e8fc11fb63
commit c81c00b6ca
24 changed files with 129 additions and 46 deletions

View File

@ -7,7 +7,7 @@ py_binary(
name = "label_image", name = "label_image",
srcs = ["label_image.py"], srcs = ["label_image.py"],
main = "label_image.py", main = "label_image.py",
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/lite/python:lite", "//tensorflow/lite/python:lite",

View File

@ -37,7 +37,7 @@ py_test(
name = "unidirectional_sequence_lstm_test", name = "unidirectional_sequence_lstm_test",
size = "large", size = "large",
srcs = ["unidirectional_sequence_lstm_test.py"], srcs = ["unidirectional_sequence_lstm_test.py"],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_oss", "no_oss",
@ -60,7 +60,7 @@ py_test(
name = "unidirectional_sequence_rnn_test", name = "unidirectional_sequence_rnn_test",
size = "large", size = "large",
srcs = ["unidirectional_sequence_rnn_test.py"], srcs = ["unidirectional_sequence_rnn_test.py"],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_oss", "no_oss",
@ -83,7 +83,7 @@ py_test(
name = "bidirectional_sequence_lstm_test", name = "bidirectional_sequence_lstm_test",
size = "large", size = "large",
srcs = ["bidirectional_sequence_lstm_test.py"], srcs = ["bidirectional_sequence_lstm_test.py"],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_oss", "no_oss",
@ -106,7 +106,7 @@ py_test(
name = "bidirectional_sequence_rnn_test", name = "bidirectional_sequence_rnn_test",
size = "large", size = "large",
srcs = ["bidirectional_sequence_rnn_test.py"], srcs = ["bidirectional_sequence_rnn_test.py"],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_oss", "no_oss",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tempfile import tempfile
import numpy as np import numpy as np
from six.moves import range
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tempfile import tempfile
import numpy as np import numpy as np
from six.moves import range
import tensorflow as tf import tensorflow as tf
from tensorflow import flags from tensorflow import flags

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -153,7 +154,8 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
"reuse": self._reuse, "reuse": self._reuse,
} }
base_config = super(TfLiteRNNCell, self).get_config() base_config = super(TfLiteRNNCell, self).get_config()
return dict(itertools.chain(base_config.items(), config.items())) return dict(
itertools.chain(list(base_config.items()), list(config.items())))
@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"]) @tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tempfile import tempfile
import numpy as np import numpy as np
from six.moves import range
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tempfile import tempfile
import numpy as np import numpy as np
from six.moves import range
import tensorflow as tf import tensorflow as tf
from tensorflow import flags from tensorflow import flags

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,7 @@ from __future__ import print_function
import argparse import argparse
import glob import glob
import os import os
import six
def rename_example_subfolder_files(library_dir): def rename_example_subfolder_files(library_dir):
@ -50,8 +52,9 @@ def move_person_data(library_dir):
with open(new_person_data_path, 'r') as source_file: with open(new_person_data_path, 'r') as source_file:
file_contents = source_file.read() file_contents = source_file.read()
file_contents = file_contents.replace( file_contents = file_contents.replace(
'#include "tensorflow/lite/experimental/micro/examples/' + six.ensure_str(
'person_detection/person_detect_model_data.h"', '#include "tensorflow/lite/experimental/micro/examples/' +
'person_detection/person_detect_model_data.h"'),
'#include "person_detect_model_data.h"') '#include "person_detect_model_data.h"')
with open(new_person_data_path, 'w') as source_file: with open(new_person_data_path, 'w') as source_file:
source_file.write(file_contents) source_file.write(file_contents)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -21,11 +22,12 @@ from __future__ import print_function
import argparse import argparse
import os.path import os.path
import re import re
import six
def sanitize_xml(unsanitized): def sanitize_xml(unsanitized):
"""Uses a whitelist to avoid generating bad XML.""" """Uses a whitelist to avoid generating bad XML."""
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', unsanitized) return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
def main(unused_args, flags): def main(unused_args, flags):
@ -33,11 +35,12 @@ def main(unused_args, flags):
with open(flags.input_template, 'r') as input_template_file: with open(flags.input_template, 'r') as input_template_file:
template_file_text = input_template_file.read() template_file_text = input_template_file.read()
template_file_text = re.sub(r'%{EXECUTABLE}%', flags.executable, template_file_text = re.sub(r'%{EXECUTABLE}%',
six.ensure_str(flags.executable),
template_file_text) template_file_text)
srcs_list = flags.srcs.split(' ') srcs_list = six.ensure_str(flags.srcs).split(' ')
hdrs_list = flags.hdrs.split(' ') hdrs_list = six.ensure_str(flags.hdrs).split(' ')
all_srcs_list = srcs_list + hdrs_list all_srcs_list = srcs_list + hdrs_list
all_srcs_list.sort() all_srcs_list.sort()
@ -64,9 +67,10 @@ def main(unused_args, flags):
replace_srcs += ' <FileType>' + ext_index + '</FileType>\n' replace_srcs += ' <FileType>' + ext_index + '</FileType>\n'
replace_srcs += ' <FilePath>' + clean_src + '</FilePath>\n' replace_srcs += ' <FilePath>' + clean_src + '</FilePath>\n'
replace_srcs += ' </File>\n' replace_srcs += ' </File>\n'
template_file_text = re.sub(r'%{SRCS}%', replace_srcs, template_file_text) template_file_text = re.sub(r'%{SRCS}%', replace_srcs,
six.ensure_str(template_file_text))
include_paths = re.sub(' ', ';', flags.include_paths) include_paths = re.sub(' ', ';', six.ensure_str(flags.include_paths))
template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths, template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths,
template_file_text) template_file_text)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,7 @@ from __future__ import print_function
import argparse import argparse
import re import re
import sys import sys
import six
def replace_includes(line, supplied_headers_list): def replace_includes(line, supplied_headers_list):
@ -29,10 +31,11 @@ def replace_includes(line, supplied_headers_list):
if include_match: if include_match:
path = include_match.group(2) path = include_match.group(2)
for supplied_header in supplied_headers_list: for supplied_header in supplied_headers_list:
if supplied_header.endswith(path): if six.ensure_str(supplied_header).endswith(path):
path = supplied_header path = supplied_header
break break
line = include_match.group(1) + path + include_match.group(3) line = include_match.group(1) + six.ensure_str(path) + include_match.group(
3)
return line return line
@ -70,8 +73,8 @@ def replace_example_includes(line, _):
# their default locations into the top-level 'examples' folder in the Arduino # their default locations into the top-level 'examples' folder in the Arduino
# library, we have to update any include references to match. # library, we have to update any include references to match.
dir_path = 'tensorflow/lite/experimental/micro/examples/' dir_path = 'tensorflow/lite/experimental/micro/examples/'
include_match = re.match(r'(.*#include.*")' + dir_path + r'([^/]+)/(.*")', include_match = re.match(
line) r'(.*#include.*")' + six.ensure_str(dir_path) + r'([^/]+)/(.*")', line)
if include_match: if include_match:
flattened_name = re.sub(r'/', '_', include_match.group(3)) flattened_name = re.sub(r'/', '_', include_match.group(3))
line = include_match.group(1) + flattened_name line = include_match.group(1) + flattened_name
@ -82,7 +85,7 @@ def main(unused_args, flags):
"""Transforms the input source file to work when exported to Arduino.""" """Transforms the input source file to work when exported to Arduino."""
input_file_lines = sys.stdin.read().split('\n') input_file_lines = sys.stdin.read().split('\n')
supplied_headers_list = flags.third_party_headers.split(' ') supplied_headers_list = six.ensure_str(flags.third_party_headers).split(' ')
output_lines = [] output_lines = []
for line in input_file_lines: for line in input_file_lines:

View File

@ -19,7 +19,7 @@ py_library(
py_test( py_test(
name = "ops_util_test", name = "ops_util_test",
srcs = ["ops_util_test.py"], srcs = ["ops_util_test.py"],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":ops_util", ":ops_util",

View File

@ -15,6 +15,7 @@ py_library(
deps = [ deps = [
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", "//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
"//third_party/py/numpy", "//third_party/py/numpy",
"@six_archive//:six",
], ],
) )
@ -25,6 +26,7 @@ py_test(
"//tensorflow/lite/python/testdata:interpreter_test_data", "//tensorflow/lite/python/testdata:interpreter_test_data",
"//tensorflow/lite/python/testdata:test_delegate.so", "//tensorflow/lite/python/testdata:test_delegate.so",
], ],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_windows", "no_windows",
@ -38,6 +40,7 @@ py_test(
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//third_party/py/numpy", "//third_party/py/numpy",
"@six_archive//:six",
], ],
) )
@ -63,16 +66,20 @@ py_test(
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//third_party/py/numpy", "//third_party/py/numpy",
"@six_archive//:six",
], ],
) )
py_binary( py_binary(
name = "tflite_convert", name = "tflite_convert",
srcs = ["tflite_convert.py"], srcs = ["tflite_convert.py"],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":tflite_convert_main_lib"], deps = [
":tflite_convert_main_lib",
"@six_archive//:six",
],
) )
py_library( py_library(
@ -80,7 +87,10 @@ py_library(
srcs = ["tflite_convert.py"], srcs = ["tflite_convert.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":tflite_convert_lib"], deps = [
":tflite_convert_lib",
"@six_archive//:six",
],
) )
py_library( py_library(
@ -90,6 +100,7 @@ py_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":lite", ":lite",
"@six_archive//:six",
], ],
) )
@ -97,6 +108,7 @@ py_test(
name = "tflite_convert_test", name = "tflite_convert_test",
srcs = ["tflite_convert_test.py"], srcs = ["tflite_convert_test.py"],
data = [":tflite_convert"], data = [":tflite_convert"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_oss", "no_oss",
@ -129,6 +141,7 @@ py_library(
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/saved_model:constants", "//tensorflow/python/saved_model:constants",
"//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:loader",
"@six_archive//:six",
], ],
) )
@ -136,6 +149,7 @@ py_test(
name = "lite_test", name = "lite_test",
srcs = ["lite_test.py"], srcs = ["lite_test.py"],
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"], data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
python_version = "PY3",
shard_count = 4, shard_count = 4,
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
@ -145,12 +159,14 @@ py_test(
":lite", ":lite",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"@six_archive//:six",
], ],
) )
py_test( py_test(
name = "lite_v2_test", name = "lite_v2_test",
srcs = ["lite_v2_test.py"], srcs = ["lite_v2_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_windows", "no_windows",
@ -159,12 +175,14 @@ py_test(
":lite", ":lite",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"@six_archive//:six",
], ],
) )
py_test( py_test(
name = "lite_flex_test", name = "lite_flex_test",
srcs = ["lite_flex_test.py"], srcs = ["lite_flex_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
# TODO(b/111881877): Enable in oss after resolving op registry issues. # TODO(b/111881877): Enable in oss after resolving op registry issues.
@ -181,6 +199,7 @@ py_test(
py_test( py_test(
name = "lite_mlir_test", name = "lite_mlir_test",
srcs = ["lite_mlir_test.py"], srcs = ["lite_mlir_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_windows", "no_windows",
@ -189,6 +208,7 @@ py_test(
":lite", ":lite",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"@six_archive//:six",
], ],
) )
@ -204,12 +224,14 @@ py_library(
":op_hint", ":op_hint",
"//tensorflow/python:tf_optimizer", "//tensorflow/python:tf_optimizer",
"//tensorflow/python/eager:wrap_function", "//tensorflow/python/eager:wrap_function",
"@six_archive//:six",
], ],
) )
py_test( py_test(
name = "util_test", name = "util_test",
srcs = ["util_test.py"], srcs = ["util_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_windows", "no_windows",
@ -218,6 +240,7 @@ py_test(
":util", ":util",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"@six_archive//:six",
], ],
) )
@ -226,6 +249,7 @@ py_library(
srcs = [ srcs = [
"wrap_toco.py", "wrap_toco.py",
], ],
srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util", "//tensorflow/python:util",
@ -256,6 +280,7 @@ py_library(
"//tensorflow/lite/toco/python:toco_from_protos", "//tensorflow/lite/toco/python:toco_from_protos",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"@six_archive//:six",
], ],
) )
@ -275,6 +300,7 @@ py_library(
py_test( py_test(
name = "convert_test", name = "convert_test",
srcs = ["convert_test.py"], srcs = ["convert_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":convert", ":convert",
@ -306,6 +332,7 @@ py_library(
py_test( py_test(
name = "convert_saved_model_test", name = "convert_saved_model_test",
srcs = ["convert_saved_model_test.py"], srcs = ["convert_saved_model_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
"no_windows", "no_windows",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -19,12 +20,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import enum # pylint: disable=g-bad-import-order import enum # pylint: disable=g-bad-import-order
import os as _os import os as _os
import platform as _platform import platform as _platform
import subprocess as _subprocess import subprocess as _subprocess
import tempfile as _tempfile import tempfile as _tempfile
import six
from six.moves import map
from tensorflow.lite.python import lite_constants from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco from tensorflow.lite.python import wrap_toco
@ -55,7 +58,7 @@ def _try_convert_to_unicode(output):
if isinstance(output, bytes): if isinstance(output, bytes):
try: try:
return output.decode() return six.ensure_text(output)
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
return output return output
@ -151,7 +154,7 @@ def toco_convert_protos(model_flags_str,
fp_model.write(model_flags_str) fp_model.write(model_flags_str)
fp_toco.write(toco_flags_str) fp_toco.write(toco_flags_str)
fp_input.write(input_data_str) fp_input.write(six.ensure_binary(input_data_str))
debug_info_str = debug_info_str if debug_info_str else "" debug_info_str = debug_info_str if debug_info_str else ""
# if debug_info_str contains a "string value", then the call to # if debug_info_str contains a "string value", then the call to
# fp_debug.write(debug_info_str) will fail with the following error # fp_debug.write(debug_info_str) will fail with the following error
@ -347,7 +350,7 @@ def build_toco_convert_protos(input_tensors,
shape = input_tensor.shape shape = input_tensor.shape
else: else:
shape = input_shapes[idx] shape = input_shapes[idx]
input_array.shape.dims.extend(map(int, shape)) input_array.shape.dims.extend(list(map(int, shape)))
for output_tensor in output_tensors: for output_tensor in output_tensors:
model.output_arrays.append(util.get_tensor_name(output_tensor)) model.output_arrays.append(util.get_tensor_name(output_tensor))
@ -400,7 +403,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
input_array.mean_value, input_array.std_value = kwargs[ input_array.mean_value, input_array.std_value = kwargs[
"quantized_input_stats"][idx] "quantized_input_stats"][idx]
input_array.name = name input_array.name = name
input_array.shape.dims.extend(map(int, shape)) input_array.shape.dims.extend(list(map(int, shape)))
for name in output_arrays: for name in output_arrays:
model_flags.output_arrays.append(name) model_flags.output_arrays.append(name)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -20,7 +21,10 @@ from __future__ import print_function
import ctypes import ctypes
import platform import platform
import sys import sys
import numpy as np import numpy as np
import six
from six.moves import range
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
if not __file__.endswith('tflite_runtime/interpreter.py'): if not __file__.endswith('tflite_runtime/interpreter.py'):
@ -107,7 +111,7 @@ class Delegate(object):
self.message = '' self.message = ''
def report(self, x): def report(self, x):
self.message += x if isinstance(x, str) else x.decode('utf-8') self.message += x if isinstance(x, str) else six.ensure_text(x, 'utf-8')
capture = ErrorMessageCapture() capture = ErrorMessageCapture()
error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report) error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -335,7 +336,8 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
if sys.platform == 'darwin': return if sys.platform == 'darwin': return
destructions = [] destructions = []
def register_destruction(x): def register_destruction(x):
destructions.append(x if isinstance(x, str) else x.decode('utf-8')) destructions.append(
x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
return 0 return 0
# Make a wrapper for the callback so we can send this to ctypes # Make a wrapper for the callback so we can send this to ctypes
delegate = interpreter_wrapper.load_delegate(self._delegate_file) delegate = interpreter_wrapper.load_delegate(self._delegate_file)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -18,8 +19,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import warnings
import enum import enum
import warnings
import six
from six import PY3 from six import PY3
from google.protobuf import text_format as _text_format from google.protobuf import text_format as _text_format
@ -679,9 +682,9 @@ class TFLiteConverter(TFLiteConverterBase):
if not isinstance(file_content, str): if not isinstance(file_content, str):
if PY3: if PY3:
file_content = file_content.decode("utf-8") file_content = six.ensure_text(file_content, "utf-8")
else: else:
file_content = file_content.encode("utf-8") file_content = six.ensure_binary(file_content, "utf-8")
graph_def = _graph_pb2.GraphDef() graph_def = _graph_pb2.GraphDef()
_text_format.Merge(file_content, graph_def) _text_format.Merge(file_content, graph_def)
except (_text_format.ParseError, DecodeError): except (_text_format.ParseError, DecodeError):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from six.moves import zip
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants from tensorflow.lite.python import lite_constants

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -20,8 +21,11 @@ from __future__ import print_function
import os import os
import tempfile import tempfile
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import six
from six.moves import range
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants from tensorflow.lite.python import lite_constants
@ -1132,7 +1136,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
# Check the add node in the inlined function is included. # Check the add node in the inlined function is included.
func = sess.graph.as_graph_def().library.function[0].signature.name func = sess.graph.as_graph_def().library.function[0].signature.name
self.assertIn(('add@' + func), converter._debug_info.traces) self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)
class FromFrozenGraphFile(test_util.TensorFlowTestCase): class FromFrozenGraphFile(test_util.TensorFlowTestCase):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function
import os import os
import numpy as np import numpy as np
from six.moves import range
from six.moves import zip
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
from tensorflow.lite.python.interpreter import Interpreter from tensorflow.lite.python.interpreter import Interpreter

View File

@ -57,7 +57,7 @@ py_test(
":test_data", ":test_data",
"//tensorflow/lite:testdata/multi_add.bin", "//tensorflow/lite:testdata/multi_add.bin",
], ],
python_version = "PY2", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = ["no_oss"], tags = ["no_oss"],
deps = [ deps = [
@ -68,5 +68,6 @@ py_test(
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//third_party/py/numpy", "//third_party/py/numpy",
"@six_archive//:six",
], ],
) )

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -18,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
from six.moves import range
from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.optimize import calibrator as _calibrator from tensorflow.lite.python.optimize import calibrator as _calibrator

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -22,6 +23,9 @@ import argparse
import os import os
import sys import sys
import six
from six.moves import zip
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants from tensorflow.lite.python import lite_constants
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
@ -32,13 +36,13 @@ from tensorflow.python.platform import app
def _parse_array(values, type_fn=str): def _parse_array(values, type_fn=str):
if values is not None: if values is not None:
return [type_fn(val) for val in values.split(",") if val] return [type_fn(val) for val in six.ensure_str(values).split(",") if val]
return None return None
def _parse_set(values): def _parse_set(values):
if values is not None: if values is not None:
return set([item for item in values.split(",") if item]) return set([item for item in six.ensure_str(values).split(",") if item])
return None return None
@ -81,9 +85,9 @@ def _get_toco_converter(flags):
if flags.input_shapes: if flags.input_shapes:
input_shapes_list = [ input_shapes_list = [
_parse_array(shape, type_fn=int) _parse_array(shape, type_fn=int)
for shape in flags.input_shapes.split(":") for shape in six.ensure_str(flags.input_shapes).split(":")
] ]
input_shapes = dict(zip(input_arrays, input_shapes_list)) input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
output_arrays = _parse_array(flags.output_arrays) output_arrays = _parse_array(flags.output_arrays)
converter_kwargs = { converter_kwargs = {
@ -152,7 +156,7 @@ def _convert_tf1_model(flags):
"--std_dev_values and --mean_values with multiple input " "--std_dev_values and --mean_values with multiple input "
"tensors in order to map between names and " "tensors in order to map between names and "
"values.".format(",".join(input_arrays))) "values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
not None): not None):
converter.default_ranges_stats = (flags.default_ranges_min, converter.default_ranges_stats = (flags.default_ranges_min,
@ -171,7 +175,7 @@ def _convert_tf1_model(flags):
if flags.target_ops: if flags.target_ops:
ops_set_options = lite.OpsSet.get_options() ops_set_options = lite.OpsSet.get_options()
converter.target_spec.supported_ops = set() converter.target_spec.supported_ops = set()
for option in flags.target_ops.split(","): for option in six.ensure_str(flags.target_ops).split(","):
if option not in ops_set_options: if option not in ops_set_options:
raise ValueError("Invalid value for --target_ops. Options: " raise ValueError("Invalid value for --target_ops. Options: "
"{0}".format(",".join(ops_set_options))) "{0}".format(",".join(ops_set_options)))
@ -201,7 +205,7 @@ def _convert_tf1_model(flags):
# Convert model. # Convert model.
output_data = converter.convert() output_data = converter.convert()
with open(flags.output_file, "wb") as f: with open(flags.output_file, "wb") as f:
f.write(output_data) f.write(six.ensure_binary(output_data))
def _convert_tf2_model(flags): def _convert_tf2_model(flags):
@ -226,7 +230,7 @@ def _convert_tf2_model(flags):
# Convert the model. # Convert the model.
tflite_model = converter.convert() tflite_model = converter.convert()
with open(flags.output_file, "wb") as f: with open(flags.output_file, "wb") as f:
f.write(tflite_model) f.write(six.ensure_binary(tflite_model))
def _check_tf1_flags(flags, unparsed): def _check_tf1_flags(flags, unparsed):
@ -245,7 +249,7 @@ def _check_tf1_flags(flags, unparsed):
# Check unparsed flags for common mistakes based on previous TOCO. # Check unparsed flags for common mistakes based on previous TOCO.
def _get_message_unparsed(flag, orig_flag, new_flag): def _get_message_unparsed(flag, orig_flag, new_flag):
if flag.startswith(orig_flag): if six.ensure_str(flag).startswith(orig_flag):
return "\n Use {0} instead of {1}".format(new_flag, orig_flag) return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
return "" return ""

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,9 @@ from __future__ import print_function
import sys import sys
import six
from six.moves import range
from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
@ -74,7 +78,7 @@ def get_tensor_name(tensor):
Returns: Returns:
str str
""" """
parts = tensor.name.split(":") parts = six.ensure_str(tensor.name).split(":")
if len(parts) > 2: if len(parts) > 2:
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
len(parts) - 1)) len(parts) - 1))
@ -277,7 +281,8 @@ def is_frozen_graph(sess):
Bool. Bool.
""" """
for op in sess.graph.get_operations(): for op in sess.graph.get_operations():
if op.type.startswith("Variable") or op.type.endswith("VariableOp"): if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
op.type).endswith("VariableOp"):
return False return False
return True return True

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -18,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import range
from tensorflow.lite.python import lite_constants from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util from tensorflow.lite.python import util
from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.lite.toco import types_pb2 as _types_pb2