Update Python codes that read builtin code after schema change.
Also added a Python utility for reading builtin code from operator code. PiperOrigin-RevId: 338422128 Change-Id: I40d2d77cca6ddc2fa161a1107dd73e0adfffba5e
This commit is contained in:
parent
48f594cba4
commit
786bd17c6d
tensorflow/lite/python
@ -234,6 +234,7 @@ py_library(
|
||||
deps = [
|
||||
":op_hint",
|
||||
":schema_py",
|
||||
":schema_util",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_py",
|
||||
"//tensorflow/python:convert_to_constants",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -402,3 +403,13 @@ sh_test(
|
||||
srcs = ["convert_file_to_c_source_test.sh"],
|
||||
data = [":convert_file_to_c_source"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "schema_util",
|
||||
srcs = ["schema_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow/lite/schema:utils_friends"],
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
50
tensorflow/lite/python/schema_util.py
Normal file
50
tensorflow/lite/python/schema_util.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Schema utilities to get builtin code from operator code."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.util import all_util
|
||||
|
||||
|
||||
def get_builtin_code_from_operator_code(opcode):
|
||||
"""Return the builtin code of the given operator code.
|
||||
|
||||
The following method is introduced to resolve op builtin code shortage
|
||||
problem. The new builtin opreator will be assigned to the extended builtin
|
||||
code field in the flatbuffer schema. Those methods helps to hide builtin code
|
||||
details.
|
||||
|
||||
Args:
|
||||
opcode: Operator code.
|
||||
|
||||
Returns:
|
||||
The builtin code of the given operator code.
|
||||
"""
|
||||
# Access BuiltinCode() method first if available.
|
||||
if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode):
|
||||
return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode())
|
||||
|
||||
return max(opcode.builtinCode, opcode.deprecatedBuiltinCode)
|
||||
|
||||
|
||||
_allowed_symbols = [
|
||||
'get_builtin_code_from_operator_code',
|
||||
]
|
||||
|
||||
all_util.remove_undocumented(__name__, _allowed_symbols)
|
@ -32,6 +32,7 @@ from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||
from tensorflow.core.protobuf import graph_debug_info_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
from tensorflow.lite.python import schema_util
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
|
||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||
@ -641,7 +642,8 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
||||
# Find all quantize operators
|
||||
quant_opcode_idxs = []
|
||||
for idx, opcode in enumerate(model.operatorCodes):
|
||||
if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE:
|
||||
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
|
||||
if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
|
||||
quant_opcode_idxs.append(idx)
|
||||
if not quant_opcode_idxs:
|
||||
raise ValueError("Model input is not quantized.")
|
||||
@ -721,7 +723,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
# Find all dequantize operators
|
||||
dequant_opcode_idxs = []
|
||||
for idx, opcode in enumerate(model.operatorCodes):
|
||||
if opcode.builtinCode == schema_fb.BuiltinOperator.DEQUANTIZE:
|
||||
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
|
||||
if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
|
||||
dequant_opcode_idxs.append(idx)
|
||||
if not dequant_opcode_idxs:
|
||||
raise ValueError("Model output is not dequantized.")
|
||||
@ -769,7 +772,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
# Find a quantize operator
|
||||
quant_opcode_idx = -1
|
||||
for idx, opcode in enumerate(model.operatorCodes):
|
||||
if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE:
|
||||
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
|
||||
if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
|
||||
quant_opcode_idx = idx
|
||||
break
|
||||
# Create a quantize operator, if none exist
|
||||
@ -843,4 +847,3 @@ def modify_model_io_type(
|
||||
_modify_model_output_type(model_object, inference_output_type)
|
||||
|
||||
return _convert_model_from_object_to_bytearray(model_object)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user