ICM PY3 Migration - //tensorflow/lite [1]

PiperOrigin-RevId: 274090960
This commit is contained in:
Hye Soo Yang 2019-10-10 19:49:22 -07:00 committed by TensorFlower Gardener
parent e8fc11fb63
commit c81c00b6ca
24 changed files with 129 additions and 46 deletions

View File

@ -7,7 +7,7 @@ py_binary(
name = "label_image",
srcs = ["label_image.py"],
main = "label_image.py",
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"//tensorflow/lite/python:lite",

View File

@ -37,7 +37,7 @@ py_test(
name = "unidirectional_sequence_lstm_test",
size = "large",
srcs = ["unidirectional_sequence_lstm_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_oss",
@ -60,7 +60,7 @@ py_test(
name = "unidirectional_sequence_rnn_test",
size = "large",
srcs = ["unidirectional_sequence_rnn_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_oss",
@ -83,7 +83,7 @@ py_test(
name = "bidirectional_sequence_lstm_test",
size = "large",
srcs = ["bidirectional_sequence_lstm_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_oss",
@ -106,7 +106,7 @@ py_test(
name = "bidirectional_sequence_rnn_test",
size = "large",
srcs = ["bidirectional_sequence_rnn_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_oss",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import tempfile
import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import tempfile
import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow import flags

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -153,7 +154,8 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
"reuse": self._reuse,
}
base_config = super(TfLiteRNNCell, self).get_config()
return dict(itertools.chain(base_config.items(), config.items()))
return dict(
itertools.chain(list(base_config.items()), list(config.items())))
@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import tempfile
import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import tempfile
import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow import flags

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,7 @@ from __future__ import print_function
import argparse
import glob
import os
import six
def rename_example_subfolder_files(library_dir):
@ -50,8 +52,9 @@ def move_person_data(library_dir):
with open(new_person_data_path, 'r') as source_file:
file_contents = source_file.read()
file_contents = file_contents.replace(
six.ensure_str(
'#include "tensorflow/lite/experimental/micro/examples/' +
'person_detection/person_detect_model_data.h"',
'person_detection/person_detect_model_data.h"'),
'#include "person_detect_model_data.h"')
with open(new_person_data_path, 'w') as source_file:
source_file.write(file_contents)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,11 +22,12 @@ from __future__ import print_function
import argparse
import os.path
import re
import six
def sanitize_xml(unsanitized):
"""Uses a whitelist to avoid generating bad XML."""
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', unsanitized)
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
def main(unused_args, flags):
@ -33,11 +35,12 @@ def main(unused_args, flags):
with open(flags.input_template, 'r') as input_template_file:
template_file_text = input_template_file.read()
template_file_text = re.sub(r'%{EXECUTABLE}%', flags.executable,
template_file_text = re.sub(r'%{EXECUTABLE}%',
six.ensure_str(flags.executable),
template_file_text)
srcs_list = flags.srcs.split(' ')
hdrs_list = flags.hdrs.split(' ')
srcs_list = six.ensure_str(flags.srcs).split(' ')
hdrs_list = six.ensure_str(flags.hdrs).split(' ')
all_srcs_list = srcs_list + hdrs_list
all_srcs_list.sort()
@ -64,9 +67,10 @@ def main(unused_args, flags):
replace_srcs += ' <FileType>' + ext_index + '</FileType>\n'
replace_srcs += ' <FilePath>' + clean_src + '</FilePath>\n'
replace_srcs += ' </File>\n'
template_file_text = re.sub(r'%{SRCS}%', replace_srcs, template_file_text)
template_file_text = re.sub(r'%{SRCS}%', replace_srcs,
six.ensure_str(template_file_text))
include_paths = re.sub(' ', ';', flags.include_paths)
include_paths = re.sub(' ', ';', six.ensure_str(flags.include_paths))
template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths,
template_file_text)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,7 @@ from __future__ import print_function
import argparse
import re
import sys
import six
def replace_includes(line, supplied_headers_list):
@ -29,10 +31,11 @@ def replace_includes(line, supplied_headers_list):
if include_match:
path = include_match.group(2)
for supplied_header in supplied_headers_list:
if supplied_header.endswith(path):
if six.ensure_str(supplied_header).endswith(path):
path = supplied_header
break
line = include_match.group(1) + path + include_match.group(3)
line = include_match.group(1) + six.ensure_str(path) + include_match.group(
3)
return line
@ -70,8 +73,8 @@ def replace_example_includes(line, _):
# their default locations into the top-level 'examples' folder in the Arduino
# library, we have to update any include references to match.
dir_path = 'tensorflow/lite/experimental/micro/examples/'
include_match = re.match(r'(.*#include.*")' + dir_path + r'([^/]+)/(.*")',
line)
include_match = re.match(
r'(.*#include.*")' + six.ensure_str(dir_path) + r'([^/]+)/(.*")', line)
if include_match:
flattened_name = re.sub(r'/', '_', include_match.group(3))
line = include_match.group(1) + flattened_name
@ -82,7 +85,7 @@ def main(unused_args, flags):
"""Transforms the input source file to work when exported to Arduino."""
input_file_lines = sys.stdin.read().split('\n')
supplied_headers_list = flags.third_party_headers.split(' ')
supplied_headers_list = six.ensure_str(flags.third_party_headers).split(' ')
output_lines = []
for line in input_file_lines:

View File

@ -19,7 +19,7 @@ py_library(
py_test(
name = "ops_util_test",
srcs = ["ops_util_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":ops_util",

View File

@ -15,6 +15,7 @@ py_library(
deps = [
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
@ -25,6 +26,7 @@ py_test(
"//tensorflow/lite/python/testdata:interpreter_test_data",
"//tensorflow/lite/python/testdata:test_delegate.so",
],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_windows",
@ -38,6 +40,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
@ -63,16 +66,20 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_binary(
name = "tflite_convert",
srcs = ["tflite_convert.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [":tflite_convert_main_lib"],
deps = [
":tflite_convert_main_lib",
"@six_archive//:six",
],
)
py_library(
@ -80,7 +87,10 @@ py_library(
srcs = ["tflite_convert.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [":tflite_convert_lib"],
deps = [
":tflite_convert_lib",
"@six_archive//:six",
],
)
py_library(
@ -90,6 +100,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":lite",
"@six_archive//:six",
],
)
@ -97,6 +108,7 @@ py_test(
name = "tflite_convert_test",
srcs = ["tflite_convert_test.py"],
data = [":tflite_convert"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_oss",
@ -129,6 +141,7 @@ py_library(
"//tensorflow/python/keras",
"//tensorflow/python/saved_model:constants",
"//tensorflow/python/saved_model:loader",
"@six_archive//:six",
],
)
@ -136,6 +149,7 @@ py_test(
name = "lite_test",
srcs = ["lite_test.py"],
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
python_version = "PY3",
shard_count = 4,
srcs_version = "PY2AND3",
tags = [
@ -145,12 +159,14 @@ py_test(
":lite",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",
],
)
py_test(
name = "lite_v2_test",
srcs = ["lite_v2_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_windows",
@ -159,12 +175,14 @@ py_test(
":lite",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",
],
)
py_test(
name = "lite_flex_test",
srcs = ["lite_flex_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
# TODO(b/111881877): Enable in oss after resolving op registry issues.
@ -181,6 +199,7 @@ py_test(
py_test(
name = "lite_mlir_test",
srcs = ["lite_mlir_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_windows",
@ -189,6 +208,7 @@ py_test(
":lite",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",
],
)
@ -204,12 +224,14 @@ py_library(
":op_hint",
"//tensorflow/python:tf_optimizer",
"//tensorflow/python/eager:wrap_function",
"@six_archive//:six",
],
)
py_test(
name = "util_test",
srcs = ["util_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_windows",
@ -218,6 +240,7 @@ py_test(
":util",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",
],
)
@ -226,6 +249,7 @@ py_library(
srcs = [
"wrap_toco.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
@ -256,6 +280,7 @@ py_library(
"//tensorflow/lite/toco/python:toco_from_protos",
"//tensorflow/python:dtypes",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)
@ -275,6 +300,7 @@ py_library(
py_test(
name = "convert_test",
srcs = ["convert_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":convert",
@ -306,6 +332,7 @@ py_library(
py_test(
name = "convert_saved_model_test",
srcs = ["convert_saved_model_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_windows",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -19,12 +20,14 @@ from __future__ import division
from __future__ import print_function
import enum # pylint: disable=g-bad-import-order
import os as _os
import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile
import six
from six.moves import map
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
@ -55,7 +58,7 @@ def _try_convert_to_unicode(output):
if isinstance(output, bytes):
try:
return output.decode()
return six.ensure_text(output)
except UnicodeDecodeError:
pass
return output
@ -151,7 +154,7 @@ def toco_convert_protos(model_flags_str,
fp_model.write(model_flags_str)
fp_toco.write(toco_flags_str)
fp_input.write(input_data_str)
fp_input.write(six.ensure_binary(input_data_str))
debug_info_str = debug_info_str if debug_info_str else ""
# if debug_info_str contains a "string value", then the call to
# fp_debug.write(debug_info_str) will fail with the following error
@ -347,7 +350,7 @@ def build_toco_convert_protos(input_tensors,
shape = input_tensor.shape
else:
shape = input_shapes[idx]
input_array.shape.dims.extend(map(int, shape))
input_array.shape.dims.extend(list(map(int, shape)))
for output_tensor in output_tensors:
model.output_arrays.append(util.get_tensor_name(output_tensor))
@ -400,7 +403,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
input_array.mean_value, input_array.std_value = kwargs[
"quantized_input_stats"][idx]
input_array.name = name
input_array.shape.dims.extend(map(int, shape))
input_array.shape.dims.extend(list(map(int, shape)))
for name in output_arrays:
model_flags.output_arrays.append(name)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,7 +21,10 @@ from __future__ import print_function
import ctypes
import platform
import sys
import numpy as np
import six
from six.moves import range
# pylint: disable=g-import-not-at-top
if not __file__.endswith('tflite_runtime/interpreter.py'):
@ -107,7 +111,7 @@ class Delegate(object):
self.message = ''
def report(self, x):
self.message += x if isinstance(x, str) else x.decode('utf-8')
self.message += x if isinstance(x, str) else six.ensure_text(x, 'utf-8')
capture = ErrorMessageCapture()
error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -335,7 +336,8 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
if sys.platform == 'darwin': return
destructions = []
def register_destruction(x):
destructions.append(x if isinstance(x, str) else x.decode('utf-8'))
destructions.append(
x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
return 0
# Make a wrapper for the callback so we can send this to ctypes
delegate = interpreter_wrapper.load_delegate(self._delegate_file)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -18,8 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import enum
import warnings
import six
from six import PY3
from google.protobuf import text_format as _text_format
@ -679,9 +682,9 @@ class TFLiteConverter(TFLiteConverterBase):
if not isinstance(file_content, str):
if PY3:
file_content = file_content.decode("utf-8")
file_content = six.ensure_text(file_content, "utf-8")
else:
file_content = file_content.encode("utf-8")
file_content = six.ensure_binary(file_content, "utf-8")
graph_def = _graph_pb2.GraphDef()
_text_format.Merge(file_content, graph_def)
except (_text_format.ParseError, DecodeError):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from six.moves import zip
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,8 +21,11 @@ from __future__ import print_function
import os
import tempfile
from absl.testing import parameterized
import numpy as np
import six
from six.moves import range
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
@ -1132,7 +1136,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
# Check the add node in the inlined function is included.
func = sess.graph.as_graph_def().library.function[0].signature.name
self.assertIn(('add@' + func), converter._debug_info.traces)
self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)
class FromFrozenGraphFile(test_util.TensorFlowTestCase):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function
import os
import numpy as np
from six.moves import range
from six.moves import zip
from tensorflow.lite.python import lite
from tensorflow.lite.python.interpreter import Interpreter

View File

@ -57,7 +57,7 @@ py_test(
":test_data",
"//tensorflow/lite:testdata/multi_add.bin",
],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = ["no_oss"],
deps = [
@ -68,5 +68,6 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@six_archive//:six",
],
)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -18,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.optimize import calibrator as _calibrator

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -22,6 +23,9 @@ import argparse
import os
import sys
import six
from six.moves import zip
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
@ -32,13 +36,13 @@ from tensorflow.python.platform import app
def _parse_array(values, type_fn=str):
if values is not None:
return [type_fn(val) for val in values.split(",") if val]
return [type_fn(val) for val in six.ensure_str(values).split(",") if val]
return None
def _parse_set(values):
if values is not None:
return set([item for item in values.split(",") if item])
return set([item for item in six.ensure_str(values).split(",") if item])
return None
@ -81,9 +85,9 @@ def _get_toco_converter(flags):
if flags.input_shapes:
input_shapes_list = [
_parse_array(shape, type_fn=int)
for shape in flags.input_shapes.split(":")
for shape in six.ensure_str(flags.input_shapes).split(":")
]
input_shapes = dict(zip(input_arrays, input_shapes_list))
input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
output_arrays = _parse_array(flags.output_arrays)
converter_kwargs = {
@ -152,7 +156,7 @@ def _convert_tf1_model(flags):
"--std_dev_values and --mean_values with multiple input "
"tensors in order to map between names and "
"values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
not None):
converter.default_ranges_stats = (flags.default_ranges_min,
@ -171,7 +175,7 @@ def _convert_tf1_model(flags):
if flags.target_ops:
ops_set_options = lite.OpsSet.get_options()
converter.target_spec.supported_ops = set()
for option in flags.target_ops.split(","):
for option in six.ensure_str(flags.target_ops).split(","):
if option not in ops_set_options:
raise ValueError("Invalid value for --target_ops. Options: "
"{0}".format(",".join(ops_set_options)))
@ -201,7 +205,7 @@ def _convert_tf1_model(flags):
# Convert model.
output_data = converter.convert()
with open(flags.output_file, "wb") as f:
f.write(output_data)
f.write(six.ensure_binary(output_data))
def _convert_tf2_model(flags):
@ -226,7 +230,7 @@ def _convert_tf2_model(flags):
# Convert the model.
tflite_model = converter.convert()
with open(flags.output_file, "wb") as f:
f.write(tflite_model)
f.write(six.ensure_binary(tflite_model))
def _check_tf1_flags(flags, unparsed):
@ -245,7 +249,7 @@ def _check_tf1_flags(flags, unparsed):
# Check unparsed flags for common mistakes based on previous TOCO.
def _get_message_unparsed(flag, orig_flag, new_flag):
if flag.startswith(orig_flag):
if six.ensure_str(flag).startswith(orig_flag):
return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
return ""

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,9 @@ from __future__ import print_function
import sys
import six
from six.moves import range
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
@ -74,7 +78,7 @@ def get_tensor_name(tensor):
Returns:
str
"""
parts = tensor.name.split(":")
parts = six.ensure_str(tensor.name).split(":")
if len(parts) > 2:
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
len(parts) - 1))
@ -277,7 +281,8 @@ def is_frozen_graph(sess):
Bool.
"""
for op in sess.graph.get_operations():
if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
if six.ensure_str(op.type).startswith("Variable") or six.ensure_str(
op.type).endswith("VariableOp"):
return False
return True

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -18,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.toco import types_pb2 as _types_pb2