Update the remaining unchanged spots, which read builtin code
PiperOrigin-RevId: 338183423 Change-Id: Ib208da7e61475165b58dbac3771c9c330ca6f101
This commit is contained in:
parent
cc0e32e209
commit
62feaa576a
@ -54,6 +54,7 @@ tf_native_cc_binary(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/lite/schema:schema_utils",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
],
|
],
|
||||||
@ -70,6 +71,7 @@ tf_native_cc_binary(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/lite/schema:schema_utils",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
],
|
],
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_utils.h"
|
||||||
|
|
||||||
using llvm::Optional;
|
using llvm::Optional;
|
||||||
using llvm::cl::opt;
|
using llvm::cl::opt;
|
||||||
@ -95,7 +96,8 @@ Optional<std::unique_ptr<tflite::ModelT>> RemoveConstantOpInReshape(
|
|||||||
// Find the reshape ops and make it single operand.
|
// Find the reshape ops and make it single operand.
|
||||||
for (auto& sub_graph : model->subgraphs) {
|
for (auto& sub_graph : model->subgraphs) {
|
||||||
for (auto& op : sub_graph->operators) {
|
for (auto& op : sub_graph->operators) {
|
||||||
if (model->operator_codes[op->opcode_index]->builtin_code ==
|
if (tflite::GetBuiltinCode(
|
||||||
|
model->operator_codes[op->opcode_index].get()) ==
|
||||||
tflite::BuiltinOperator_RESHAPE) {
|
tflite::BuiltinOperator_RESHAPE) {
|
||||||
auto& output_tensor = sub_graph->tensors[op->outputs[0]];
|
auto& output_tensor = sub_graph->tensors[op->outputs[0]];
|
||||||
auto shape = output_tensor->shape;
|
auto shape = output_tensor->shape;
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_utils.h"
|
||||||
|
|
||||||
using llvm::Optional;
|
using llvm::Optional;
|
||||||
using llvm::cl::opt;
|
using llvm::cl::opt;
|
||||||
@ -114,7 +115,8 @@ Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
|
|||||||
// Find the tensors and inject the min and max to the input and output
|
// Find the tensors and inject the min and max to the input and output
|
||||||
for (auto& sub_graph : model->subgraphs) {
|
for (auto& sub_graph : model->subgraphs) {
|
||||||
for (auto& op : sub_graph->operators) {
|
for (auto& op : sub_graph->operators) {
|
||||||
if (model->operator_codes[op->opcode_index]->builtin_code ==
|
if (tflite::GetBuiltinCode(
|
||||||
|
model->operator_codes[op->opcode_index].get()) ==
|
||||||
tflite::BuiltinOperator_FULLY_CONNECTED) {
|
tflite::BuiltinOperator_FULLY_CONNECTED) {
|
||||||
// inject min/max to the input and output tensors
|
// inject min/max to the input and output tensors
|
||||||
auto& input_tensor = sub_graph->tensors[op->inputs[0]];
|
auto& input_tensor = sub_graph->tensors[op->inputs[0]];
|
||||||
|
@ -818,6 +818,8 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
|
|||||||
GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_,
|
GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_,
|
||||||
&(node_and_registrations[i].registration));
|
&(node_and_registrations[i].registration));
|
||||||
if (status != kTfLiteOk) {
|
if (status != kTfLiteOk) {
|
||||||
|
// TODO(b/171278094): Use the GetBuiltinCode method in the schema utilitly
|
||||||
|
// to get builtin code from op code.
|
||||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
"Failed to get registration from op code %s\n ",
|
"Failed to get registration from op code %s\n ",
|
||||||
EnumNameBuiltinOperator(opcode->builtin_code()));
|
EnumNameBuiltinOperator(opcode->builtin_code()));
|
||||||
|
@ -662,10 +662,10 @@ TEST_F(VersionedOpExportTest, Export) {
|
|||||||
// different versions.
|
// different versions.
|
||||||
EXPECT_EQ(2, operator_codes->size());
|
EXPECT_EQ(2, operator_codes->size());
|
||||||
EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
|
EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
|
||||||
(*operator_codes)[0]->builtin_code());
|
GetBuiltinCode((*operator_codes)[0]));
|
||||||
EXPECT_EQ(1, (*operator_codes)[0]->version());
|
EXPECT_EQ(1, (*operator_codes)[0]->version());
|
||||||
EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
|
EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
|
||||||
(*operator_codes)[1]->builtin_code());
|
GetBuiltinCode((*operator_codes)[1]));
|
||||||
EXPECT_EQ(2, (*operator_codes)[1]->version());
|
EXPECT_EQ(2, (*operator_codes)[1]->version());
|
||||||
|
|
||||||
// Verify that the 2 operators points to the correct indices of the operation
|
// Verify that the 2 operators points to the correct indices of the operation
|
||||||
|
@ -26,6 +26,7 @@ cc_library(
|
|||||||
"//tensorflow/lite/kernels/internal:compatibility",
|
"//tensorflow/lite/kernels/internal:compatibility",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"//tensorflow/lite/schema:schema_fbs_with_mutable",
|
"//tensorflow/lite/schema:schema_fbs_with_mutable",
|
||||||
|
"//tensorflow/lite/schema:schema_utils",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@flatbuffers",
|
"@flatbuffers",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_utils.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace {
|
namespace {
|
||||||
@ -621,9 +622,10 @@ TensorType GetTensorType(int32_t idx, const SubGraph* subgraph) {
|
|||||||
// options to decide op version.
|
// options to decide op version.
|
||||||
OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
||||||
const SubGraph* subgraph) {
|
const SubGraph* subgraph) {
|
||||||
OpSignature op_sig = {op_code->builtin_code()};
|
auto builtin_code = GetBuiltinCode(op_code);
|
||||||
|
OpSignature op_sig = {builtin_code};
|
||||||
|
|
||||||
switch (op_code->builtin_code()) {
|
switch (builtin_code) {
|
||||||
case BuiltinOperator_DEPTHWISE_CONV_2D: {
|
case BuiltinOperator_DEPTHWISE_CONV_2D: {
|
||||||
auto conv_option = op->builtin_options_as_DepthwiseConv2DOptions();
|
auto conv_option = op->builtin_options_as_DepthwiseConv2DOptions();
|
||||||
if (conv_option) {
|
if (conv_option) {
|
||||||
@ -797,14 +799,15 @@ void UpdateOpVersion(uint8_t* model_buffer_pointer) {
|
|||||||
OperatorCode* op_code =
|
OperatorCode* op_code =
|
||||||
model->mutable_operator_codes()->GetMutableObject(op->opcode_index());
|
model->mutable_operator_codes()->GetMutableObject(op->opcode_index());
|
||||||
|
|
||||||
if (op_code->builtin_code() != BuiltinOperator_CUSTOM) {
|
auto builtin_code = GetBuiltinCode(op_code);
|
||||||
|
if (builtin_code != BuiltinOperator_CUSTOM) {
|
||||||
OpSignature op_sig = GetOpSignature(op_code, op, subgraph);
|
OpSignature op_sig = GetOpSignature(op_code, op, subgraph);
|
||||||
// Update builtin operator version.
|
// Update builtin operator version.
|
||||||
int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
|
int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
|
||||||
if (!op_code->mutate_version(op_ver)) {
|
if (!op_code->mutate_version(op_ver)) {
|
||||||
LOG(ERROR) << "Can't set operator "
|
LOG(ERROR) << "Can't set operator "
|
||||||
<< EnumNameBuiltinOperator(op_code->builtin_code())
|
<< EnumNameBuiltinOperator(builtin_code) << " to version "
|
||||||
<< " to version " << op_ver;
|
<< op_ver;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "tensorflow/lite/minimal_logging.h"
|
#include "tensorflow/lite/minimal_logging.h"
|
||||||
#include "tensorflow/lite/schema/mutable/schema_generated.h"
|
#include "tensorflow/lite/schema/mutable/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_utils.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace {
|
namespace {
|
||||||
@ -340,7 +341,7 @@ void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer) {
|
|||||||
const OperatorCode* op_code =
|
const OperatorCode* op_code =
|
||||||
model->operator_codes()->Get(op->opcode_index());
|
model->operator_codes()->Get(op->opcode_index());
|
||||||
std::string runtime_version = FindMinimumRuntimeVersionForOp(
|
std::string runtime_version = FindMinimumRuntimeVersionForOp(
|
||||||
op_code->builtin_code(), op_code->version());
|
GetBuiltinCode(op_code), op_code->version());
|
||||||
if (runtime_version.empty() ||
|
if (runtime_version.empty() ||
|
||||||
runtime_version == kPendingReleaseVersion) {
|
runtime_version == kPendingReleaseVersion) {
|
||||||
// In case we didn't find the current op in the map, or the operator
|
// In case we didn't find the current op in the map, or the operator
|
||||||
|
Loading…
Reference in New Issue
Block a user