ICM PY3 Migration - //tensorflow/lite [1]
PiperOrigin-RevId: 274090960
This commit is contained in:
parent
e8fc11fb63
commit
c81c00b6ca
@ -7,7 +7,7 @@ py_binary(
|
|||||||
name = "label_image",
|
name = "label_image",
|
||||||
srcs = ["label_image.py"],
|
srcs = ["label_image.py"],
|
||||||
main = "label_image.py",
|
main = "label_image.py",
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/python:lite",
|
"//tensorflow/lite/python:lite",
|
||||||
|
@ -37,7 +37,7 @@ py_test(
|
|||||||
name = "unidirectional_sequence_lstm_test",
|
name = "unidirectional_sequence_lstm_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["unidirectional_sequence_lstm_test.py"],
|
srcs = ["unidirectional_sequence_lstm_test.py"],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
@ -60,7 +60,7 @@ py_test(
|
|||||||
name = "unidirectional_sequence_rnn_test",
|
name = "unidirectional_sequence_rnn_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["unidirectional_sequence_rnn_test.py"],
|
srcs = ["unidirectional_sequence_rnn_test.py"],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
@ -83,7 +83,7 @@ py_test(
|
|||||||
name = "bidirectional_sequence_lstm_test",
|
name = "bidirectional_sequence_lstm_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["bidirectional_sequence_lstm_test.py"],
|
srcs = ["bidirectional_sequence_lstm_test.py"],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
@ -106,7 +106,7 @@ py_test(
|
|||||||
name = "bidirectional_sequence_rnn_test",
|
name = "bidirectional_sequence_rnn_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["bidirectional_sequence_rnn_test.py"],
|
srcs = ["bidirectional_sequence_rnn_test.py"],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -17,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.examples.tutorials.mnist import input_data
|
from tensorflow.examples.tutorials.mnist import input_data
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -17,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow import flags
|
from tensorflow import flags
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -153,7 +154,8 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
|
|||||||
"reuse": self._reuse,
|
"reuse": self._reuse,
|
||||||
}
|
}
|
||||||
base_config = super(TfLiteRNNCell, self).get_config()
|
base_config = super(TfLiteRNNCell, self).get_config()
|
||||||
return dict(itertools.chain(base_config.items(), config.items()))
|
return dict(
|
||||||
|
itertools.chain(list(base_config.items()), list(config.items())))
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])
|
@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -17,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.examples.tutorials.mnist import input_data
|
from tensorflow.examples.tutorials.mnist import input_data
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -17,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow import flags
|
from tensorflow import flags
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -21,6 +22,7 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
def rename_example_subfolder_files(library_dir):
|
def rename_example_subfolder_files(library_dir):
|
||||||
@ -50,8 +52,9 @@ def move_person_data(library_dir):
|
|||||||
with open(new_person_data_path, 'r') as source_file:
|
with open(new_person_data_path, 'r') as source_file:
|
||||||
file_contents = source_file.read()
|
file_contents = source_file.read()
|
||||||
file_contents = file_contents.replace(
|
file_contents = file_contents.replace(
|
||||||
'#include "tensorflow/lite/experimental/micro/examples/' +
|
six.ensure_str(
|
||||||
'person_detection/person_detect_model_data.h"',
|
'#include "tensorflow/lite/experimental/micro/examples/' +
|
||||||
|
'person_detection/person_detect_model_data.h"'),
|
||||||
'#include "person_detect_model_data.h"')
|
'#include "person_detect_model_data.h"')
|
||||||
with open(new_person_data_path, 'w') as source_file:
|
with open(new_person_data_path, 'w') as source_file:
|
||||||
source_file.write(file_contents)
|
source_file.write(file_contents)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -21,11 +22,12 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
def sanitize_xml(unsanitized):
|
def sanitize_xml(unsanitized):
|
||||||
"""Uses a whitelist to avoid generating bad XML."""
|
"""Uses a whitelist to avoid generating bad XML."""
|
||||||
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', unsanitized)
|
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
|
||||||
|
|
||||||
|
|
||||||
def main(unused_args, flags):
|
def main(unused_args, flags):
|
||||||
@ -33,11 +35,12 @@ def main(unused_args, flags):
|
|||||||
with open(flags.input_template, 'r') as input_template_file:
|
with open(flags.input_template, 'r') as input_template_file:
|
||||||
template_file_text = input_template_file.read()
|
template_file_text = input_template_file.read()
|
||||||
|
|
||||||
template_file_text = re.sub(r'%{EXECUTABLE}%', flags.executable,
|
template_file_text = re.sub(r'%{EXECUTABLE}%',
|
||||||
|
six.ensure_str(flags.executable),
|
||||||
template_file_text)
|
template_file_text)
|
||||||
|
|
||||||
srcs_list = flags.srcs.split(' ')
|
srcs_list = six.ensure_str(flags.srcs).split(' ')
|
||||||
hdrs_list = flags.hdrs.split(' ')
|
hdrs_list = six.ensure_str(flags.hdrs).split(' ')
|
||||||
all_srcs_list = srcs_list + hdrs_list
|
all_srcs_list = srcs_list + hdrs_list
|
||||||
all_srcs_list.sort()
|
all_srcs_list.sort()
|
||||||
|
|
||||||
@ -64,9 +67,10 @@ def main(unused_args, flags):
|
|||||||
replace_srcs += ' <FileType>' + ext_index + '</FileType>\n'
|
replace_srcs += ' <FileType>' + ext_index + '</FileType>\n'
|
||||||
replace_srcs += ' <FilePath>' + clean_src + '</FilePath>\n'
|
replace_srcs += ' <FilePath>' + clean_src + '</FilePath>\n'
|
||||||
replace_srcs += ' </File>\n'
|
replace_srcs += ' </File>\n'
|
||||||
template_file_text = re.sub(r'%{SRCS}%', replace_srcs, template_file_text)
|
template_file_text = re.sub(r'%{SRCS}%', replace_srcs,
|
||||||
|
six.ensure_str(template_file_text))
|
||||||
|
|
||||||
include_paths = re.sub(' ', ';', flags.include_paths)
|
include_paths = re.sub(' ', ';', six.ensure_str(flags.include_paths))
|
||||||
template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths,
|
template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths,
|
||||||
template_file_text)
|
template_file_text)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -21,6 +22,7 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
def replace_includes(line, supplied_headers_list):
|
def replace_includes(line, supplied_headers_list):
|
||||||
@ -29,10 +31,11 @@ def replace_includes(line, supplied_headers_list):
|
|||||||
if include_match:
|
if include_match:
|
||||||
path = include_match.group(2)
|
path = include_match.group(2)
|
||||||
for supplied_header in supplied_headers_list:
|
for supplied_header in supplied_headers_list:
|
||||||
if supplied_header.endswith(path):
|
if six.ensure_str(supplied_header).endswith(path):
|
||||||
path = supplied_header
|
path = supplied_header
|
||||||
break
|
break
|
||||||
line = include_match.group(1) + path + include_match.group(3)
|
line = include_match.group(1) + six.ensure_str(path) + include_match.group(
|
||||||
|
3)
|
||||||
return line
|
return line
|
||||||
|
|
||||||
|
|
||||||
@ -70,8 +73,8 @@ def replace_example_includes(line, _):
|
|||||||
# their default locations into the top-level 'examples' folder in the Arduino
|
# their default locations into the top-level 'examples' folder in the Arduino
|
||||||
# library, we have to update any include references to match.
|
# library, we have to update any include references to match.
|
||||||
dir_path = 'tensorflow/lite/experimental/micro/examples/'
|
dir_path = 'tensorflow/lite/experimental/micro/examples/'
|
||||||
include_match = re.match(r'(.*#include.*")' + dir_path + r'([^/]+)/(.*")',
|
include_match = re.match(
|
||||||
line)
|
r'(.*#include.*")' + six.ensure_str(dir_path) + r'([^/]+)/(.*")', line)
|
||||||
if include_match:
|
if include_match:
|
||||||
flattened_name = re.sub(r'/', '_', include_match.group(3))
|
flattened_name = re.sub(r'/', '_', include_match.group(3))
|
||||||
line = include_match.group(1) + flattened_name
|
line = include_match.group(1) + flattened_name
|
||||||
@ -82,7 +85,7 @@ def main(unused_args, flags):
|
|||||||
"""Transforms the input source file to work when exported to Arduino."""
|
"""Transforms the input source file to work when exported to Arduino."""
|
||||||
input_file_lines = sys.stdin.read().split('\n')
|
input_file_lines = sys.stdin.read().split('\n')
|
||||||
|
|
||||||
supplied_headers_list = flags.third_party_headers.split(' ')
|
supplied_headers_list = six.ensure_str(flags.third_party_headers).split(' ')
|
||||||
|
|
||||||
output_lines = []
|
output_lines = []
|
||||||
for line in input_file_lines:
|
for line in input_file_lines:
|
||||||
|
@ -19,7 +19,7 @@ py_library(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "ops_util_test",
|
name = "ops_util_test",
|
||||||
srcs = ["ops_util_test.py"],
|
srcs = ["ops_util_test.py"],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":ops_util",
|
":ops_util",
|
||||||
|
@ -15,6 +15,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
|
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ py_test(
|
|||||||
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
||||||
"//tensorflow/lite/python/testdata:test_delegate.so",
|
"//tensorflow/lite/python/testdata:test_delegate.so",
|
||||||
],
|
],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
@ -38,6 +40,7 @@ py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,16 +66,20 @@ py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "tflite_convert",
|
name = "tflite_convert",
|
||||||
srcs = ["tflite_convert.py"],
|
srcs = ["tflite_convert.py"],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [":tflite_convert_main_lib"],
|
deps = [
|
||||||
|
":tflite_convert_main_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
@ -80,7 +87,10 @@ py_library(
|
|||||||
srcs = ["tflite_convert.py"],
|
srcs = ["tflite_convert.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [":tflite_convert_lib"],
|
deps = [
|
||||||
|
":tflite_convert_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
@ -90,6 +100,7 @@ py_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":lite",
|
":lite",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -97,6 +108,7 @@ py_test(
|
|||||||
name = "tflite_convert_test",
|
name = "tflite_convert_test",
|
||||||
srcs = ["tflite_convert_test.py"],
|
srcs = ["tflite_convert_test.py"],
|
||||||
data = [":tflite_convert"],
|
data = [":tflite_convert"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
@ -129,6 +141,7 @@ py_library(
|
|||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
"//tensorflow/python/saved_model:constants",
|
"//tensorflow/python/saved_model:constants",
|
||||||
"//tensorflow/python/saved_model:loader",
|
"//tensorflow/python/saved_model:loader",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -136,6 +149,7 @@ py_test(
|
|||||||
name = "lite_test",
|
name = "lite_test",
|
||||||
srcs = ["lite_test.py"],
|
srcs = ["lite_test.py"],
|
||||||
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
|
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
|
||||||
|
python_version = "PY3",
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
@ -145,12 +159,14 @@ py_test(
|
|||||||
":lite",
|
":lite",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "lite_v2_test",
|
name = "lite_v2_test",
|
||||||
srcs = ["lite_v2_test.py"],
|
srcs = ["lite_v2_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
@ -159,12 +175,14 @@ py_test(
|
|||||||
":lite",
|
":lite",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "lite_flex_test",
|
name = "lite_flex_test",
|
||||||
srcs = ["lite_flex_test.py"],
|
srcs = ["lite_flex_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
# TODO(b/111881877): Enable in oss after resolving op registry issues.
|
# TODO(b/111881877): Enable in oss after resolving op registry issues.
|
||||||
@ -181,6 +199,7 @@ py_test(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "lite_mlir_test",
|
name = "lite_mlir_test",
|
||||||
srcs = ["lite_mlir_test.py"],
|
srcs = ["lite_mlir_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
@ -189,6 +208,7 @@ py_test(
|
|||||||
":lite",
|
":lite",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -204,12 +224,14 @@ py_library(
|
|||||||
":op_hint",
|
":op_hint",
|
||||||
"//tensorflow/python:tf_optimizer",
|
"//tensorflow/python:tf_optimizer",
|
||||||
"//tensorflow/python/eager:wrap_function",
|
"//tensorflow/python/eager:wrap_function",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "util_test",
|
name = "util_test",
|
||||||
srcs = ["util_test.py"],
|
srcs = ["util_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
@ -218,6 +240,7 @@ py_test(
|
|||||||
":util",
|
":util",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -226,6 +249,7 @@ py_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"wrap_toco.py",
|
"wrap_toco.py",
|
||||||
],
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
@ -256,6 +280,7 @@ py_library(
|
|||||||
"//tensorflow/lite/toco/python:toco_from_protos",
|
"//tensorflow/lite/toco/python:toco_from_protos",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -275,6 +300,7 @@ py_library(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "convert_test",
|
name = "convert_test",
|
||||||
srcs = ["convert_test.py"],
|
srcs = ["convert_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":convert",
|
":convert",
|
||||||
@ -306,6 +332,7 @@ py_library(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "convert_saved_model_test",
|
name = "convert_saved_model_test",
|
||||||
srcs = ["convert_saved_model_test.py"],
|
srcs = ["convert_saved_model_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -19,12 +20,14 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import enum # pylint: disable=g-bad-import-order
|
import enum # pylint: disable=g-bad-import-order
|
||||||
|
|
||||||
import os as _os
|
import os as _os
|
||||||
import platform as _platform
|
import platform as _platform
|
||||||
import subprocess as _subprocess
|
import subprocess as _subprocess
|
||||||
import tempfile as _tempfile
|
import tempfile as _tempfile
|
||||||
|
|
||||||
|
import six
|
||||||
|
from six.moves import map
|
||||||
|
|
||||||
from tensorflow.lite.python import lite_constants
|
from tensorflow.lite.python import lite_constants
|
||||||
from tensorflow.lite.python import util
|
from tensorflow.lite.python import util
|
||||||
from tensorflow.lite.python import wrap_toco
|
from tensorflow.lite.python import wrap_toco
|
||||||
@ -55,7 +58,7 @@ def _try_convert_to_unicode(output):
|
|||||||
|
|
||||||
if isinstance(output, bytes):
|
if isinstance(output, bytes):
|
||||||
try:
|
try:
|
||||||
return output.decode()
|
return six.ensure_text(output)
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
pass
|
pass
|
||||||
return output
|
return output
|
||||||
@ -151,7 +154,7 @@ def toco_convert_protos(model_flags_str,
|
|||||||
|
|
||||||
fp_model.write(model_flags_str)
|
fp_model.write(model_flags_str)
|
||||||
fp_toco.write(toco_flags_str)
|
fp_toco.write(toco_flags_str)
|
||||||
fp_input.write(input_data_str)
|
fp_input.write(six.ensure_binary(input_data_str))
|
||||||
debug_info_str = debug_info_str if debug_info_str else ""
|
debug_info_str = debug_info_str if debug_info_str else ""
|
||||||
# if debug_info_str contains a "string value", then the call to
|
# if debug_info_str contains a "string value", then the call to
|
||||||
# fp_debug.write(debug_info_str) will fail with the following error
|
# fp_debug.write(debug_info_str) will fail with the following error
|
||||||
@ -347,7 +350,7 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
shape = input_tensor.shape
|
shape = input_tensor.shape
|
||||||
else:
|
else:
|
||||||
shape = input_shapes[idx]
|
shape = input_shapes[idx]
|
||||||
input_array.shape.dims.extend(map(int, shape))
|
input_array.shape.dims.extend(list(map(int, shape)))
|
||||||
|
|
||||||
for output_tensor in output_tensors:
|
for output_tensor in output_tensors:
|
||||||
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
||||||
@ -400,7 +403,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
|||||||
input_array.mean_value, input_array.std_value = kwargs[
|
input_array.mean_value, input_array.std_value = kwargs[
|
||||||
"quantized_input_stats"][idx]
|
"quantized_input_stats"][idx]
|
||||||
input_array.name = name
|
input_array.name = name
|
||||||
input_array.shape.dims.extend(map(int, shape))
|
input_array.shape.dims.extend(list(map(int, shape)))
|
||||||
|
|
||||||
for name in output_arrays:
|
for name in output_arrays:
|
||||||
model_flags.output_arrays.append(name)
|
model_flags.output_arrays.append(name)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,7 +21,10 @@ from __future__ import print_function
|
|||||||
import ctypes
|
import ctypes
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
# pylint: disable=g-import-not-at-top
|
# pylint: disable=g-import-not-at-top
|
||||||
if not __file__.endswith('tflite_runtime/interpreter.py'):
|
if not __file__.endswith('tflite_runtime/interpreter.py'):
|
||||||
@ -107,7 +111,7 @@ class Delegate(object):
|
|||||||
self.message = ''
|
self.message = ''
|
||||||
|
|
||||||
def report(self, x):
|
def report(self, x):
|
||||||
self.message += x if isinstance(x, str) else x.decode('utf-8')
|
self.message += x if isinstance(x, str) else six.ensure_text(x, 'utf-8')
|
||||||
|
|
||||||
capture = ErrorMessageCapture()
|
capture = ErrorMessageCapture()
|
||||||
error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)
|
error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -335,7 +336,8 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
|
|||||||
if sys.platform == 'darwin': return
|
if sys.platform == 'darwin': return
|
||||||
destructions = []
|
destructions = []
|
||||||
def register_destruction(x):
|
def register_destruction(x):
|
||||||
destructions.append(x if isinstance(x, str) else x.decode('utf-8'))
|
destructions.append(
|
||||||
|
x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
|
||||||
return 0
|
return 0
|
||||||
# Make a wrapper for the callback so we can send this to ctypes
|
# Make a wrapper for the callback so we can send this to ctypes
|
||||||
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
|
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -18,8 +19,10 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import warnings
|
|
||||||
import enum
|
import enum
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import six
|
||||||
from six import PY3
|
from six import PY3
|
||||||
|
|
||||||
from google.protobuf import text_format as _text_format
|
from google.protobuf import text_format as _text_format
|
||||||
@ -679,9 +682,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
|
|
||||||
if not isinstance(file_content, str):
|
if not isinstance(file_content, str):
|
||||||
if PY3:
|
if PY3:
|
||||||
file_content = file_content.decode("utf-8")
|
file_content = six.ensure_text(file_content, "utf-8")
|
||||||
else:
|
else:
|
||||||
file_content = file_content.encode("utf-8")
|
file_content = six.ensure_binary(file_content, "utf-8")
|
||||||
graph_def = _graph_pb2.GraphDef()
|
graph_def = _graph_pb2.GraphDef()
|
||||||
_text_format.Merge(file_content, graph_def)
|
_text_format.Merge(file_content, graph_def)
|
||||||
except (_text_format.ParseError, DecodeError):
|
except (_text_format.ParseError, DecodeError):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,6 +21,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import zip
|
||||||
|
|
||||||
from tensorflow.lite.python import lite
|
from tensorflow.lite.python import lite
|
||||||
from tensorflow.lite.python import lite_constants
|
from tensorflow.lite.python import lite_constants
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,8 +21,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from tensorflow.lite.python import lite
|
from tensorflow.lite.python import lite
|
||||||
from tensorflow.lite.python import lite_constants
|
from tensorflow.lite.python import lite_constants
|
||||||
@ -1132,7 +1136,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
|
|
||||||
# Check the add node in the inlined function is included.
|
# Check the add node in the inlined function is included.
|
||||||
func = sess.graph.as_graph_def().library.function[0].signature.name
|
func = sess.graph.as_graph_def().library.function[0].signature.name
|
||||||
self.assertIn(('add@' + func), converter._debug_info.traces)
|
self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)
|
||||||
|
|
||||||
|
|
||||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,6 +21,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
from six.moves import zip
|
||||||
|
|
||||||
from tensorflow.lite.python import lite
|
from tensorflow.lite.python import lite
|
||||||
from tensorflow.lite.python.interpreter import Interpreter
|
from tensorflow.lite.python.interpreter import Interpreter
|
||||||
|
@ -57,7 +57,7 @@ py_test(
|
|||||||
":test_data",
|
":test_data",
|
||||||
"//tensorflow/lite:testdata/multi_add.bin",
|
"//tensorflow/lite:testdata/multi_add.bin",
|
||||||
],
|
],
|
||||||
python_version = "PY2",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = ["no_oss"],
|
tags = ["no_oss"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -68,5 +68,6 @@ py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -18,6 +19,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from tensorflow.lite.python import lite_constants as constants
|
from tensorflow.lite.python import lite_constants as constants
|
||||||
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -22,6 +23,9 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import six
|
||||||
|
from six.moves import zip
|
||||||
|
|
||||||
from tensorflow.lite.python import lite
|
from tensorflow.lite.python import lite
|
||||||
from tensorflow.lite.python import lite_constants
|
from tensorflow.lite.python import lite_constants
|
||||||
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
||||||
@ -32,13 +36,13 @@ from tensorflow.python.platform import app
|
|||||||
|
|
||||||
def _parse_array(values, type_fn=str):
|
def _parse_array(values, type_fn=str):
|
||||||
if values is not None:
|
if values is not None:
|
||||||
return [type_fn(val) for val in values.split(",") if val]
|
return [type_fn(val) for val in six.ensure_str(values).split(",") if val]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _parse_set(values):
|
def _parse_set(values):
|
||||||
if values is not None:
|
if values is not None:
|
||||||
return set([item for item in values.split(",") if item])
|
return set([item for item in six.ensure_str(values).split(",") if item])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -81,9 +85,9 @@ def _get_toco_converter(flags):
|
|||||||
if flags.input_shapes:
|
if flags.input_shapes:
|
||||||
input_shapes_list = [
|
input_shapes_list = [
|
||||||
_parse_array(shape, type_fn=int)
|
_parse_array(shape, type_fn=int)
|
||||||
for shape in flags.input_shapes.split(":")
|
for shape in six.ensure_str(flags.input_shapes).split(":")
|
||||||
]
|
]
|
||||||
input_shapes = dict(zip(input_arrays, input_shapes_list))
|
input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
|
||||||
output_arrays = _parse_array(flags.output_arrays)
|
output_arrays = _parse_array(flags.output_arrays)
|
||||||
|
|
||||||
converter_kwargs = {
|
converter_kwargs = {
|
||||||
@ -152,7 +156,7 @@ def _convert_tf1_model(flags):
|
|||||||
"--std_dev_values and --mean_values with multiple input "
|
"--std_dev_values and --mean_values with multiple input "
|
||||||
"tensors in order to map between names and "
|
"tensors in order to map between names and "
|
||||||
"values.".format(",".join(input_arrays)))
|
"values.".format(",".join(input_arrays)))
|
||||||
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
|
converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
|
||||||
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
|
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
|
||||||
not None):
|
not None):
|
||||||
converter.default_ranges_stats = (flags.default_ranges_min,
|
converter.default_ranges_stats = (flags.default_ranges_min,
|
||||||
@ -171,7 +175,7 @@ def _convert_tf1_model(flags):
|
|||||||
if flags.target_ops:
|
if flags.target_ops:
|
||||||
ops_set_options = lite.OpsSet.get_options()
|
ops_set_options = lite.OpsSet.get_options()
|
||||||
converter.target_spec.supported_ops = set()
|
converter.target_spec.supported_ops = set()
|
||||||
for option in flags.target_ops.split(","):
|
for option in six.ensure_str(flags.target_ops).split(","):
|
||||||
if option not in ops_set_options:
|
if option not in ops_set_options:
|
||||||
raise ValueError("Invalid value for --target_ops. Options: "
|
raise ValueError("Invalid value for --target_ops. Options: "
|
||||||
"{0}".format(",".join(ops_set_options)))
|
"{0}".format(",".join(ops_set_options)))
|
||||||
@ -201,7 +205,7 @@ def _convert_tf1_model(flags):
|
|||||||
# Convert model.
|
# Convert model.
|
||||||
output_data = converter.convert()
|
output_data = converter.convert()
|
||||||
with open(flags.output_file, "wb") as f:
|
with open(flags.output_file, "wb") as f:
|
||||||
f.write(output_data)
|
f.write(six.ensure_binary(output_data))
|
||||||
|
|
||||||
|
|
||||||
def _convert_tf2_model(flags):
|
def _convert_tf2_model(flags):
|
||||||
@ -226,7 +230,7 @@ def _convert_tf2_model(flags):
|
|||||||
# Convert the model.
|
# Convert the model.
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
with open(flags.output_file, "wb") as f:
|
with open(flags.output_file, "wb") as f:
|
||||||
f.write(tflite_model)
|
f.write(six.ensure_binary(tflite_model))
|
||||||
|
|
||||||
|
|
||||||
def _check_tf1_flags(flags, unparsed):
|
def _check_tf1_flags(flags, unparsed):
|
||||||
@ -245,7 +249,7 @@ def _check_tf1_flags(flags, unparsed):
|
|||||||
|
|
||||||
# Check unparsed flags for common mistakes based on previous TOCO.
|
# Check unparsed flags for common mistakes based on previous TOCO.
|
||||||
def _get_message_unparsed(flag, orig_flag, new_flag):
|
def _get_message_unparsed(flag, orig_flag, new_flag):
|
||||||
if flag.startswith(orig_flag):
|
if six.ensure_str(flag).startswith(orig_flag):
|
||||||
return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
|
return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,6 +21,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import six
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||||
@ -74,7 +78,7 @@ def get_tensor_name(tensor):
|
|||||||
Returns:
|
Returns:
|
||||||
str
|
str
|
||||||
"""
|
"""
|
||||||
parts = tensor.name.split(":")
|
parts = six.ensure_str(tensor.name).split(":")
|
||||||
if len(parts) > 2:
|
if len(parts) > 2:
|
||||||
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
|
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
|
||||||
len(parts) - 1))
|
len(parts) - 1))
|
||||||
@ -277,7 +281,8 @@ def is_frozen_graph(sess):
|
|||||||
Bool.
|
Bool.
|
||||||
"""
|
"""
|
||||||
for op in sess.graph.get_operations():
|
for op in sess.graph.get_operations():
|
||||||
if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
|
if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
|
||||||
|
op.type).endswith("VariableOp"):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# Lint as: python2, python3
|
||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -18,6 +19,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from tensorflow.lite.python import lite_constants
|
from tensorflow.lite.python import lite_constants
|
||||||
from tensorflow.lite.python import util
|
from tensorflow.lite.python import util
|
||||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||||
|
Loading…
Reference in New Issue
Block a user