ICM PY3 Migration - //tensorflow/lite [1]
PiperOrigin-RevId: 274090960
This commit is contained in:
parent
e8fc11fb63
commit
c81c00b6ca
@ -7,7 +7,7 @@ py_binary(
|
||||
name = "label_image",
|
||||
srcs = ["label_image.py"],
|
||||
main = "label_image.py",
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/lite/python:lite",
|
||||
|
@ -37,7 +37,7 @@ py_test(
|
||||
name = "unidirectional_sequence_lstm_test",
|
||||
size = "large",
|
||||
srcs = ["unidirectional_sequence_lstm_test.py"],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
@ -60,7 +60,7 @@ py_test(
|
||||
name = "unidirectional_sequence_rnn_test",
|
||||
size = "large",
|
||||
srcs = ["unidirectional_sequence_rnn_test.py"],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
@ -83,7 +83,7 @@ py_test(
|
||||
name = "bidirectional_sequence_lstm_test",
|
||||
size = "large",
|
||||
srcs = ["bidirectional_sequence_lstm_test.py"],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
@ -106,7 +106,7 @@ py_test(
|
||||
name = "bidirectional_sequence_rnn_test",
|
||||
size = "large",
|
||||
srcs = ["bidirectional_sequence_rnn_test.py"],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -17,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from six.moves import range
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -17,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from six.moves import range
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow import flags
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -153,7 +154,8 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
|
||||
"reuse": self._reuse,
|
||||
}
|
||||
base_config = super(TfLiteRNNCell, self).get_config()
|
||||
return dict(itertools.chain(base_config.items(), config.items()))
|
||||
return dict(
|
||||
itertools.chain(list(base_config.items()), list(config.items())))
|
||||
|
||||
|
||||
@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -17,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from six.moves import range
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -17,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from six.moves import range
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow import flags
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -21,6 +22,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import six
|
||||
|
||||
|
||||
def rename_example_subfolder_files(library_dir):
|
||||
@ -50,8 +52,9 @@ def move_person_data(library_dir):
|
||||
with open(new_person_data_path, 'r') as source_file:
|
||||
file_contents = source_file.read()
|
||||
file_contents = file_contents.replace(
|
||||
'#include "tensorflow/lite/experimental/micro/examples/' +
|
||||
'person_detection/person_detect_model_data.h"',
|
||||
six.ensure_str(
|
||||
'#include "tensorflow/lite/experimental/micro/examples/' +
|
||||
'person_detection/person_detect_model_data.h"'),
|
||||
'#include "person_detect_model_data.h"')
|
||||
with open(new_person_data_path, 'w') as source_file:
|
||||
source_file.write(file_contents)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -21,11 +22,12 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import os.path
|
||||
import re
|
||||
import six
|
||||
|
||||
|
||||
def sanitize_xml(unsanitized):
|
||||
"""Uses a whitelist to avoid generating bad XML."""
|
||||
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', unsanitized)
|
||||
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
|
||||
|
||||
|
||||
def main(unused_args, flags):
|
||||
@ -33,11 +35,12 @@ def main(unused_args, flags):
|
||||
with open(flags.input_template, 'r') as input_template_file:
|
||||
template_file_text = input_template_file.read()
|
||||
|
||||
template_file_text = re.sub(r'%{EXECUTABLE}%', flags.executable,
|
||||
template_file_text = re.sub(r'%{EXECUTABLE}%',
|
||||
six.ensure_str(flags.executable),
|
||||
template_file_text)
|
||||
|
||||
srcs_list = flags.srcs.split(' ')
|
||||
hdrs_list = flags.hdrs.split(' ')
|
||||
srcs_list = six.ensure_str(flags.srcs).split(' ')
|
||||
hdrs_list = six.ensure_str(flags.hdrs).split(' ')
|
||||
all_srcs_list = srcs_list + hdrs_list
|
||||
all_srcs_list.sort()
|
||||
|
||||
@ -64,9 +67,10 @@ def main(unused_args, flags):
|
||||
replace_srcs += ' <FileType>' + ext_index + '</FileType>\n'
|
||||
replace_srcs += ' <FilePath>' + clean_src + '</FilePath>\n'
|
||||
replace_srcs += ' </File>\n'
|
||||
template_file_text = re.sub(r'%{SRCS}%', replace_srcs, template_file_text)
|
||||
template_file_text = re.sub(r'%{SRCS}%', replace_srcs,
|
||||
six.ensure_str(template_file_text))
|
||||
|
||||
include_paths = re.sub(' ', ';', flags.include_paths)
|
||||
include_paths = re.sub(' ', ';', six.ensure_str(flags.include_paths))
|
||||
template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths,
|
||||
template_file_text)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -21,6 +22,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
import six
|
||||
|
||||
|
||||
def replace_includes(line, supplied_headers_list):
|
||||
@ -29,10 +31,11 @@ def replace_includes(line, supplied_headers_list):
|
||||
if include_match:
|
||||
path = include_match.group(2)
|
||||
for supplied_header in supplied_headers_list:
|
||||
if supplied_header.endswith(path):
|
||||
if six.ensure_str(supplied_header).endswith(path):
|
||||
path = supplied_header
|
||||
break
|
||||
line = include_match.group(1) + path + include_match.group(3)
|
||||
line = include_match.group(1) + six.ensure_str(path) + include_match.group(
|
||||
3)
|
||||
return line
|
||||
|
||||
|
||||
@ -70,8 +73,8 @@ def replace_example_includes(line, _):
|
||||
# their default locations into the top-level 'examples' folder in the Arduino
|
||||
# library, we have to update any include references to match.
|
||||
dir_path = 'tensorflow/lite/experimental/micro/examples/'
|
||||
include_match = re.match(r'(.*#include.*")' + dir_path + r'([^/]+)/(.*")',
|
||||
line)
|
||||
include_match = re.match(
|
||||
r'(.*#include.*")' + six.ensure_str(dir_path) + r'([^/]+)/(.*")', line)
|
||||
if include_match:
|
||||
flattened_name = re.sub(r'/', '_', include_match.group(3))
|
||||
line = include_match.group(1) + flattened_name
|
||||
@ -82,7 +85,7 @@ def main(unused_args, flags):
|
||||
"""Transforms the input source file to work when exported to Arduino."""
|
||||
input_file_lines = sys.stdin.read().split('\n')
|
||||
|
||||
supplied_headers_list = flags.third_party_headers.split(' ')
|
||||
supplied_headers_list = six.ensure_str(flags.third_party_headers).split(' ')
|
||||
|
||||
output_lines = []
|
||||
for line in input_file_lines:
|
||||
|
@ -19,7 +19,7 @@ py_library(
|
||||
py_test(
|
||||
name = "ops_util_test",
|
||||
srcs = ["ops_util_test.py"],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ops_util",
|
||||
|
@ -15,6 +15,7 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -25,6 +26,7 @@ py_test(
|
||||
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
||||
"//tensorflow/lite/python/testdata:test_delegate.so",
|
||||
],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
@ -38,6 +40,7 @@ py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -63,16 +66,20 @@ py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "tflite_convert",
|
||||
srcs = ["tflite_convert.py"],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":tflite_convert_main_lib"],
|
||||
deps = [
|
||||
":tflite_convert_main_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -80,7 +87,10 @@ py_library(
|
||||
srcs = ["tflite_convert.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":tflite_convert_lib"],
|
||||
deps = [
|
||||
":tflite_convert_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -90,6 +100,7 @@ py_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":lite",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -97,6 +108,7 @@ py_test(
|
||||
name = "tflite_convert_test",
|
||||
srcs = ["tflite_convert_test.py"],
|
||||
data = [":tflite_convert"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
@ -129,6 +141,7 @@ py_library(
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/saved_model:constants",
|
||||
"//tensorflow/python/saved_model:loader",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -136,6 +149,7 @@ py_test(
|
||||
name = "lite_test",
|
||||
srcs = ["lite_test.py"],
|
||||
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
@ -145,12 +159,14 @@ py_test(
|
||||
":lite",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "lite_v2_test",
|
||||
srcs = ["lite_v2_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
@ -159,12 +175,14 @@ py_test(
|
||||
":lite",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "lite_flex_test",
|
||||
srcs = ["lite_flex_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
# TODO(b/111881877): Enable in oss after resolving op registry issues.
|
||||
@ -181,6 +199,7 @@ py_test(
|
||||
py_test(
|
||||
name = "lite_mlir_test",
|
||||
srcs = ["lite_mlir_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
@ -189,6 +208,7 @@ py_test(
|
||||
":lite",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -204,12 +224,14 @@ py_library(
|
||||
":op_hint",
|
||||
"//tensorflow/python:tf_optimizer",
|
||||
"//tensorflow/python/eager:wrap_function",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "util_test",
|
||||
srcs = ["util_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
@ -218,6 +240,7 @@ py_test(
|
||||
":util",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -226,6 +249,7 @@ py_library(
|
||||
srcs = [
|
||||
"wrap_toco.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:util",
|
||||
@ -256,6 +280,7 @@ py_library(
|
||||
"//tensorflow/lite/toco/python:toco_from_protos",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:platform",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -275,6 +300,7 @@ py_library(
|
||||
py_test(
|
||||
name = "convert_test",
|
||||
srcs = ["convert_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":convert",
|
||||
@ -306,6 +332,7 @@ py_library(
|
||||
py_test(
|
||||
name = "convert_saved_model_test",
|
||||
srcs = ["convert_saved_model_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -19,12 +20,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
|
||||
import os as _os
|
||||
import platform as _platform
|
||||
import subprocess as _subprocess
|
||||
import tempfile as _tempfile
|
||||
|
||||
import six
|
||||
from six.moves import map
|
||||
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python import util
|
||||
from tensorflow.lite.python import wrap_toco
|
||||
@ -55,7 +58,7 @@ def _try_convert_to_unicode(output):
|
||||
|
||||
if isinstance(output, bytes):
|
||||
try:
|
||||
return output.decode()
|
||||
return six.ensure_text(output)
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
return output
|
||||
@ -151,7 +154,7 @@ def toco_convert_protos(model_flags_str,
|
||||
|
||||
fp_model.write(model_flags_str)
|
||||
fp_toco.write(toco_flags_str)
|
||||
fp_input.write(input_data_str)
|
||||
fp_input.write(six.ensure_binary(input_data_str))
|
||||
debug_info_str = debug_info_str if debug_info_str else ""
|
||||
# if debug_info_str contains a "string value", then the call to
|
||||
# fp_debug.write(debug_info_str) will fail with the following error
|
||||
@ -347,7 +350,7 @@ def build_toco_convert_protos(input_tensors,
|
||||
shape = input_tensor.shape
|
||||
else:
|
||||
shape = input_shapes[idx]
|
||||
input_array.shape.dims.extend(map(int, shape))
|
||||
input_array.shape.dims.extend(list(map(int, shape)))
|
||||
|
||||
for output_tensor in output_tensors:
|
||||
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
||||
@ -400,7 +403,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
input_array.mean_value, input_array.std_value = kwargs[
|
||||
"quantized_input_stats"][idx]
|
||||
input_array.name = name
|
||||
input_array.shape.dims.extend(map(int, shape))
|
||||
input_array.shape.dims.extend(list(map(int, shape)))
|
||||
|
||||
for name in output_arrays:
|
||||
model_flags.output_arrays.append(name)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,7 +21,10 @@ from __future__ import print_function
|
||||
import ctypes
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if not __file__.endswith('tflite_runtime/interpreter.py'):
|
||||
@ -107,7 +111,7 @@ class Delegate(object):
|
||||
self.message = ''
|
||||
|
||||
def report(self, x):
|
||||
self.message += x if isinstance(x, str) else x.decode('utf-8')
|
||||
self.message += x if isinstance(x, str) else six.ensure_text(x, 'utf-8')
|
||||
|
||||
capture = ErrorMessageCapture()
|
||||
error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -335,7 +336,8 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
|
||||
if sys.platform == 'darwin': return
|
||||
destructions = []
|
||||
def register_destruction(x):
|
||||
destructions.append(x if isinstance(x, str) else x.decode('utf-8'))
|
||||
destructions.append(
|
||||
x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
|
||||
return 0
|
||||
# Make a wrapper for the callback so we can send this to ctypes
|
||||
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -18,8 +19,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import warnings
|
||||
import enum
|
||||
import warnings
|
||||
|
||||
import six
|
||||
from six import PY3
|
||||
|
||||
from google.protobuf import text_format as _text_format
|
||||
@ -679,9 +682,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
|
||||
if not isinstance(file_content, str):
|
||||
if PY3:
|
||||
file_content = file_content.decode("utf-8")
|
||||
file_content = six.ensure_text(file_content, "utf-8")
|
||||
else:
|
||||
file_content = file_content.encode("utf-8")
|
||||
file_content = six.ensure_binary(file_content, "utf-8")
|
||||
graph_def = _graph_pb2.GraphDef()
|
||||
_text_format.Merge(file_content, graph_def)
|
||||
except (_text_format.ParseError, DecodeError):
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from six.moves import zip
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,8 +21,11 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
@ -1132,7 +1136,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
|
||||
# Check the add node in the inlined function is included.
|
||||
func = sess.graph.as_graph_def().library.function[0].signature.name
|
||||
self.assertIn(('add@' + func), converter._debug_info.traces)
|
||||
self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)
|
||||
|
||||
|
||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,6 +21,8 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from six.moves import range
|
||||
from six.moves import zip
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
|
@ -57,7 +57,7 @@ py_test(
|
||||
":test_data",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
@ -68,5 +68,6 @@ py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -18,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
from six.moves import range
|
||||
|
||||
from tensorflow.lite.python import lite_constants as constants
|
||||
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -22,6 +23,9 @@ import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import six
|
||||
from six.moves import zip
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
||||
@ -32,13 +36,13 @@ from tensorflow.python.platform import app
|
||||
|
||||
def _parse_array(values, type_fn=str):
|
||||
if values is not None:
|
||||
return [type_fn(val) for val in values.split(",") if val]
|
||||
return [type_fn(val) for val in six.ensure_str(values).split(",") if val]
|
||||
return None
|
||||
|
||||
|
||||
def _parse_set(values):
|
||||
if values is not None:
|
||||
return set([item for item in values.split(",") if item])
|
||||
return set([item for item in six.ensure_str(values).split(",") if item])
|
||||
return None
|
||||
|
||||
|
||||
@ -81,9 +85,9 @@ def _get_toco_converter(flags):
|
||||
if flags.input_shapes:
|
||||
input_shapes_list = [
|
||||
_parse_array(shape, type_fn=int)
|
||||
for shape in flags.input_shapes.split(":")
|
||||
for shape in six.ensure_str(flags.input_shapes).split(":")
|
||||
]
|
||||
input_shapes = dict(zip(input_arrays, input_shapes_list))
|
||||
input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
|
||||
output_arrays = _parse_array(flags.output_arrays)
|
||||
|
||||
converter_kwargs = {
|
||||
@ -152,7 +156,7 @@ def _convert_tf1_model(flags):
|
||||
"--std_dev_values and --mean_values with multiple input "
|
||||
"tensors in order to map between names and "
|
||||
"values.".format(",".join(input_arrays)))
|
||||
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
|
||||
converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
|
||||
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
|
||||
not None):
|
||||
converter.default_ranges_stats = (flags.default_ranges_min,
|
||||
@ -171,7 +175,7 @@ def _convert_tf1_model(flags):
|
||||
if flags.target_ops:
|
||||
ops_set_options = lite.OpsSet.get_options()
|
||||
converter.target_spec.supported_ops = set()
|
||||
for option in flags.target_ops.split(","):
|
||||
for option in six.ensure_str(flags.target_ops).split(","):
|
||||
if option not in ops_set_options:
|
||||
raise ValueError("Invalid value for --target_ops. Options: "
|
||||
"{0}".format(",".join(ops_set_options)))
|
||||
@ -201,7 +205,7 @@ def _convert_tf1_model(flags):
|
||||
# Convert model.
|
||||
output_data = converter.convert()
|
||||
with open(flags.output_file, "wb") as f:
|
||||
f.write(output_data)
|
||||
f.write(six.ensure_binary(output_data))
|
||||
|
||||
|
||||
def _convert_tf2_model(flags):
|
||||
@ -226,7 +230,7 @@ def _convert_tf2_model(flags):
|
||||
# Convert the model.
|
||||
tflite_model = converter.convert()
|
||||
with open(flags.output_file, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
f.write(six.ensure_binary(tflite_model))
|
||||
|
||||
|
||||
def _check_tf1_flags(flags, unparsed):
|
||||
@ -245,7 +249,7 @@ def _check_tf1_flags(flags, unparsed):
|
||||
|
||||
# Check unparsed flags for common mistakes based on previous TOCO.
|
||||
def _get_message_unparsed(flag, orig_flag, new_flag):
|
||||
if flag.startswith(orig_flag):
|
||||
if six.ensure_str(flag).startswith(orig_flag):
|
||||
return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
|
||||
return ""
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,6 +21,9 @@ from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||
@ -74,7 +78,7 @@ def get_tensor_name(tensor):
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
parts = tensor.name.split(":")
|
||||
parts = six.ensure_str(tensor.name).split(":")
|
||||
if len(parts) > 2:
|
||||
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
|
||||
len(parts) - 1))
|
||||
@ -277,7 +281,8 @@ def is_frozen_graph(sess):
|
||||
Bool.
|
||||
"""
|
||||
for op in sess.graph.get_operations():
|
||||
if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
|
||||
if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
|
||||
op.type).endswith("VariableOp"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -18,6 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import range
|
||||
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python import util
|
||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||
|
Loading…
Reference in New Issue
Block a user