Update the remaining unchanged spots, which read builtin code
PiperOrigin-RevId: 338183423 Change-Id: Ib208da7e61475165b58dbac3771c9c330ca6f101
This commit is contained in:
parent
cc0e32e209
commit
62feaa576a
tensorflow
compiler/mlir/lite/tests/flatbuffer2mlir
lite
micro
toco/tflite
tools/versioning
@ -54,6 +54,7 @@ tf_native_cc_binary(
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/schema:schema_utils",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
@ -70,6 +71,7 @@ tf_native_cc_binary(
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/schema:schema_utils",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_utils.h"
|
||||
|
||||
using llvm::Optional;
|
||||
using llvm::cl::opt;
|
||||
@ -95,7 +96,8 @@ Optional<std::unique_ptr<tflite::ModelT>> RemoveConstantOpInReshape(
|
||||
// Find the reshape ops and make it single operand.
|
||||
for (auto& sub_graph : model->subgraphs) {
|
||||
for (auto& op : sub_graph->operators) {
|
||||
if (model->operator_codes[op->opcode_index]->builtin_code ==
|
||||
if (tflite::GetBuiltinCode(
|
||||
model->operator_codes[op->opcode_index].get()) ==
|
||||
tflite::BuiltinOperator_RESHAPE) {
|
||||
auto& output_tensor = sub_graph->tensors[op->outputs[0]];
|
||||
auto shape = output_tensor->shape;
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_utils.h"
|
||||
|
||||
using llvm::Optional;
|
||||
using llvm::cl::opt;
|
||||
@ -114,7 +115,8 @@ Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
|
||||
// Find the tensors and inject the min and max to the input and output
|
||||
for (auto& sub_graph : model->subgraphs) {
|
||||
for (auto& op : sub_graph->operators) {
|
||||
if (model->operator_codes[op->opcode_index]->builtin_code ==
|
||||
if (tflite::GetBuiltinCode(
|
||||
model->operator_codes[op->opcode_index].get()) ==
|
||||
tflite::BuiltinOperator_FULLY_CONNECTED) {
|
||||
// inject min/max to the input and output tensors
|
||||
auto& input_tensor = sub_graph->tensors[op->inputs[0]];
|
||||
|
@ -818,6 +818,8 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
|
||||
GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_,
|
||||
&(node_and_registrations[i].registration));
|
||||
if (status != kTfLiteOk) {
|
||||
// TODO(b/171278094): Use the GetBuiltinCode method in the schema utilitly
|
||||
// to get builtin code from op code.
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"Failed to get registration from op code %s\n ",
|
||||
EnumNameBuiltinOperator(opcode->builtin_code()));
|
||||
|
@ -662,10 +662,10 @@ TEST_F(VersionedOpExportTest, Export) {
|
||||
// different versions.
|
||||
EXPECT_EQ(2, operator_codes->size());
|
||||
EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
|
||||
(*operator_codes)[0]->builtin_code());
|
||||
GetBuiltinCode((*operator_codes)[0]));
|
||||
EXPECT_EQ(1, (*operator_codes)[0]->version());
|
||||
EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
|
||||
(*operator_codes)[1]->builtin_code());
|
||||
GetBuiltinCode((*operator_codes)[1]));
|
||||
EXPECT_EQ(2, (*operator_codes)[1]->version());
|
||||
|
||||
// Verify that the 2 operators points to the correct indices of the operation
|
||||
|
@ -26,6 +26,7 @@ cc_library(
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/schema:schema_fbs_with_mutable",
|
||||
"//tensorflow/lite/schema:schema_utils",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@ -621,9 +622,10 @@ TensorType GetTensorType(int32_t idx, const SubGraph* subgraph) {
|
||||
// options to decide op version.
|
||||
OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
||||
const SubGraph* subgraph) {
|
||||
OpSignature op_sig = {op_code->builtin_code()};
|
||||
auto builtin_code = GetBuiltinCode(op_code);
|
||||
OpSignature op_sig = {builtin_code};
|
||||
|
||||
switch (op_code->builtin_code()) {
|
||||
switch (builtin_code) {
|
||||
case BuiltinOperator_DEPTHWISE_CONV_2D: {
|
||||
auto conv_option = op->builtin_options_as_DepthwiseConv2DOptions();
|
||||
if (conv_option) {
|
||||
@ -797,14 +799,15 @@ void UpdateOpVersion(uint8_t* model_buffer_pointer) {
|
||||
OperatorCode* op_code =
|
||||
model->mutable_operator_codes()->GetMutableObject(op->opcode_index());
|
||||
|
||||
if (op_code->builtin_code() != BuiltinOperator_CUSTOM) {
|
||||
auto builtin_code = GetBuiltinCode(op_code);
|
||||
if (builtin_code != BuiltinOperator_CUSTOM) {
|
||||
OpSignature op_sig = GetOpSignature(op_code, op, subgraph);
|
||||
// Update builtin operator version.
|
||||
int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
|
||||
if (!op_code->mutate_version(op_ver)) {
|
||||
LOG(ERROR) << "Can't set operator "
|
||||
<< EnumNameBuiltinOperator(op_code->builtin_code())
|
||||
<< " to version " << op_ver;
|
||||
<< EnumNameBuiltinOperator(builtin_code) << " to version "
|
||||
<< op_ver;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/schema/mutable/schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@ -340,7 +341,7 @@ void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer) {
|
||||
const OperatorCode* op_code =
|
||||
model->operator_codes()->Get(op->opcode_index());
|
||||
std::string runtime_version = FindMinimumRuntimeVersionForOp(
|
||||
op_code->builtin_code(), op_code->version());
|
||||
GetBuiltinCode(op_code), op_code->version());
|
||||
if (runtime_version.empty() ||
|
||||
runtime_version == kPendingReleaseVersion) {
|
||||
// In case we didn't find the current op in the map, or the operator
|
||||
|
Loading…
Reference in New Issue
Block a user