2443 lines
87 KiB
C++
2443 lines
87 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
#include "tensorflow/lite/toco/tooling_util.h"
|
|
|
|
#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
|
|
|
|
#include "absl/strings/ascii.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "absl/strings/str_replace.h"
|
|
#include "absl/strings/str_split.h"
|
|
#include "re2/re2.h"
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
#include "tensorflow/lite/toco/dump_graphviz.h"
|
|
#include "tensorflow/lite/toco/model_flags.pb.h"
|
|
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
|
|
|
|
namespace toco {
|
|
|
|
// Find the longest common prefix of two strings.
|
|
absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                          absl::string_view b) {
  // An empty input trivially yields an empty prefix.
  if (a.empty() || b.empty()) return absl::string_view();

  // Walk both strings in lockstep until they diverge or the shorter one ends.
  const size_t limit = std::min(a.size(), b.size());
  size_t match_len = 0;
  while (match_len < limit && a[match_len] == b[match_len]) {
    ++match_len;
  }
  // The prefix is returned as a view into `a`.
  return absl::string_view(a.data(), match_len);
}
|
|
|
|
// Builds a short human-readable identifier for `op`, used in log messages.
// Includes the first output name when the op has any outputs.
string LogName(const Operator& op) {
  const string& opname = HelpfulOperatorTypeName(op);
  return op.outputs.empty()
             ? toco::port::StringF("{%s operator}", opname)
             : toco::port::StringF("{%s operator with output %s}", opname,
                                   op.outputs[0]);
}
|
|
|
|
// Returns a human-readable name for `data_type`, for use in logs and error
// messages. Aborts (LOG(FATAL)) on enum values not handled below.
string ArrayDataTypeName(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return "float";
    case ArrayDataType::kInt8:
      return "int8";
    case ArrayDataType::kUint8:
      return "uint8";
    case ArrayDataType::kInt16:
      return "int16";
    case ArrayDataType::kUint16:
      return "uint16";
    case ArrayDataType::kInt32:
      return "int32";
    case ArrayDataType::kUint32:
      return "uint32";
    case ArrayDataType::kInt64:
      return "int64";
    case ArrayDataType::kUint64:
      return "uint64";
    case ArrayDataType::kString:
      return "string";
    case ArrayDataType::kBool:
      return "bool";
    case ArrayDataType::kComplex64:
      return "complex64";
    case ArrayDataType::kNone:
      return "None";
    default:
      LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
  }
}
|
|
|
|
bool IsInputArray(const Model& model, const string& array_name) {
|
|
for (const auto& input_array : model.flags.input_arrays()) {
|
|
if (array_name == input_array.name()) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool IsOutputArray(const Model& model, const string& array_name) {
|
|
for (const auto& output_array : model.flags.output_arrays()) {
|
|
if (array_name == output_array) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool IsArrayConsumed(const Model& model, const string& name) {
|
|
if (GetOpWithInput(model, name)) {
|
|
return true;
|
|
}
|
|
if (IsOutputArray(model, name)) {
|
|
return true;
|
|
}
|
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
|
if (rnn_state.back_edge_source_array() == name) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
int CountTrueOutputs(const Model& model, const Operator& op) {
|
|
int count = 0;
|
|
for (const string& output : op.outputs) {
|
|
if (IsArrayConsumed(model, output)) {
|
|
++count;
|
|
}
|
|
}
|
|
return count;
|
|
}
|
|
|
|
int CountOpsWithInput(const Model& model, const string& array_name) {
|
|
int count = 0;
|
|
for (const auto& op : model.operators) {
|
|
for (auto& input : op->inputs) {
|
|
if (input == array_name) {
|
|
count++;
|
|
// Breaking here is important: some graphs have ops that use the
|
|
// same array as more than one of their inputs, and in that case
|
|
// we want it counted only once.
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
return count;
|
|
}
|
|
|
|
bool DeleteArrayIfUnused(const string& array_name, Model* model) {
|
|
if (IsDiscardableArray(*model, array_name) &&
|
|
CountOpsWithInput(*model, array_name) == 0 &&
|
|
GetOpWithOutput(*model, array_name) == nullptr) {
|
|
model->EraseArray(array_name);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool DeleteArrayIfUnusedOutsideOfOp(const string& array_name,
|
|
const Operator* op, Model* model) {
|
|
if (!IsDiscardableArray(*model, array_name)) {
|
|
return false;
|
|
}
|
|
if (CountOpsWithInput(*model, array_name) > 1) {
|
|
return false;
|
|
}
|
|
const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
|
|
if (op_having_this_as_input && op_having_this_as_input != op) {
|
|
return false;
|
|
}
|
|
const Operator* op_having_this_as_output =
|
|
GetOpWithOutput(*model, array_name);
|
|
if (op_having_this_as_output && op_having_this_as_output != op) {
|
|
return false;
|
|
}
|
|
model->EraseArray(array_name);
|
|
return true;
|
|
}
|
|
|
|
// Removes `op` from the model, together with any of its input/output arrays
// that are used by no other op.
void DeleteOpAndArrays(Model* model, const Operator* op) {
  // Inputs first, then outputs — mirroring the op's data flow.
  for (const auto* array_list : {&op->inputs, &op->outputs}) {
    for (const string& array_name : *array_list) {
      DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
    }
  }
  const auto op_it = FindOp(*model, op);
  CHECK(op_it != model->operators.end());
  model->operators.erase(op_it);
}
|
|
|
|
// Returns an iterator to the first op producing `array_name`, or end().
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    const auto& outputs = (*it)->outputs;
    if (std::find(outputs.begin(), outputs.end(), array_name) !=
        outputs.end()) {
      return it;
    }
  }
  return model.operators.end();
}
|
|
|
|
// Mutable-iterator overload of FindOpWithOutput above.
std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    const auto& outputs = (*it)->outputs;
    if (std::find(outputs.begin(), outputs.end(), array_name) !=
        outputs.end()) {
      return it;
    }
  }
  return model.operators.end();
}
|
|
|
|
// Returns the op producing `array_name`, or nullptr if none does.
Operator* GetOpWithOutput(const Model& model, const string& array_name) {
  const auto it = FindOpWithOutput(model, array_name);
  if (it == model.operators.end()) {
    return nullptr;
  }
  return it->get();
}
|
|
|
|
// GetFirstOpWithInput assumes that this finds the first op.
|
|
// Returns an iterator to the first op consuming `array_name`, or end().
// GetFirstOpWithInput relies on this returning the *first* such op.
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    const auto& inputs = (*it)->inputs;
    if (std::find(inputs.begin(), inputs.end(), array_name) != inputs.end()) {
      return it;
    }
  }
  return model.operators.end();
}
|
|
|
|
// Mutable-iterator overload of FindOpWithInput above.
std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    const auto& inputs = (*it)->inputs;
    if (std::find(inputs.begin(), inputs.end(), array_name) != inputs.end()) {
      return it;
    }
  }
  return model.operators.end();
}
|
|
|
|
// Locates `op` (by pointer identity) in the model's operator list.
std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
    const Model& model, const Operator* op) {
  return std::find_if(
      model.operators.begin(), model.operators.end(),
      [op](const std::unique_ptr<Operator>& p) { return p.get() == op; });
}
|
|
|
|
// Mutable-iterator overload of FindOp above (pointer-identity lookup).
std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
                                                        const Operator* op) {
  return std::find_if(
      model.operators.begin(), model.operators.end(),
      [op](const std::unique_ptr<Operator>& p) { return p.get() == op; });
}
|
|
|
|
// Returns some op consuming `array_name` (the first found), or nullptr.
Operator* GetOpWithInput(const Model& model, const string& array_name) {
  const auto it = FindOpWithInput(model, array_name);
  if (it == model.operators.end()) {
    return nullptr;
  }
  return it->get();
}
|
|
|
|
// Returns the first op (in operator-list order) consuming `array_name`,
// or nullptr if none does.
Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
  const auto it = FindOpWithInput(model, array_name);
  if (it == model.operators.end()) {
    return nullptr;
  }
  return it->get();
}
|
|
|
|
// Rewrites every occurrence of `old_array_name` in any op's inputs or
// outputs to `new_array_name`.
void ReplaceArrayUsage(Model* model, const string& old_array_name,
                       const string& new_array_name) {
  for (auto& op : model->operators) {
    std::replace(op->inputs.begin(), op->inputs.end(), old_array_name,
                 new_array_name);
    std::replace(op->outputs.begin(), op->outputs.end(), old_array_name,
                 new_array_name);
  }
}
|
|
|
|
// Formats a list of array names for logging: "[]" when empty, the bare name
// for a single element, and "[ a, b, ... ]" for several.
string FormatArraysList(const Model& model, const std::vector<string>& list) {
  if (list.empty()) {
    return "[]";
  }
  const string joined = absl::StrJoin(list, ", ");
  if (list.size() == 1) {
    return joined;
  }
  return absl::StrCat("[ ", joined, " ]");
}
|
|
|
|
// Maps an OperatorType enum value to its name as a C string (e.g. kConv ->
// "Conv"). Aborts (LOG(FATAL)) on enum values missing from the list below;
// new operator types must be added here.
const char* OperatorTypeName(OperatorType type) {
  switch (type) {
// Expands to `case OperatorType::kFoo: return "Foo";` for each entry.
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Abs)
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(HardSwish)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Elu)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(SplitV)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Ceil)
    HANDLE_OPERATORTYPENAME_CASE(Round)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(GatherNd)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
    HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
    HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
    HANDLE_OPERATORTYPENAME_CASE(Unique)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}
|
|
|
|
// Like OperatorTypeName, but for unsupported ops it also reports the
// original TensorFlow op name to make error messages actionable.
string HelpfulOperatorTypeName(const Operator& op) {
  if (op.type != OperatorType::kUnsupported) {
    return OperatorTypeName(op.type);
  }
  const auto& unsupported_op =
      static_cast<const TensorFlowUnsupportedOperator&>(op);
  return toco::port::StringF("(Unsupported TensorFlow op: %s)",
                             unsupported_op.tensorflow_op);
}
|
|
|
|
// Returns true for operator types whose TFLite kernels accept a fused
// activation function (e.g. Relu fused into a Conv). Graph transformations
// use this to decide whether an activation op can be folded into its
// producer.
bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kAdd:
    case OperatorType::kAveragePool:
    case OperatorType::kBatchNormalization:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv:
    case OperatorType::kDiv:
    case OperatorType::kFullyConnected:
    case OperatorType::kL2Pool:
    case OperatorType::kMaxPool:
    case OperatorType::kMul:
    case OperatorType::kSub:
    case OperatorType::kSquaredDifference:
      return true;
    default:
      return false;
  }
}
|
|
|
|
// Logs, at `log_level`, the total operator count followed by one line per
// distinct operator type with its number of occurrences.
void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  std::unordered_multiset<OperatorType> ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
  auto it = ops_by_type.begin();
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
    // Equal keys are stored contiguously in a multiset, so advancing by
    // `count` lands on the first element of the next distinct type.
    std::advance(it, count);
  }
}
|
|
|
|
// Logs, at `log_level`, everything known about the array `name`: data types,
// storage kind (constant buffer vs. transient allocation), shape, min/max
// and quantization parameters. Safe to call for arrays that do not exist.
void LogArray(int log_level, const Model& model, const string& name) {
  VLOG(log_level) << "Array: " << name;
  if (!model.HasArray(name)) {
    VLOG(log_level) << " DOES NOT EXIST";
    return;
  }
  const auto& array = model.GetArray(name);
  VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
  VLOG(log_level) << " Final type: "
                  << ArrayDataTypeName(array.final_data_type);
  // An array has a buffer when it holds constant data, and an alloc when it
  // has been assigned transient workspace memory.
  if (array.buffer) {
    VLOG(log_level) << " Constant Buffer";
  }
  if (array.alloc) {
    VLOG(log_level) << " Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << " (Zero dimensions)";
    } else {
      // Format the dims as a comma-separated list, e.g. "Dims: 1, 224, 224, 3".
      string message = " Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << " QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}
|
|
|
|
// When --dump_graphviz_video is enabled, writes the current model as a
// numbered graphviz "video frame" file (toco_video_NNNNN.dot) into the
// --dump_graphviz directory. Frames identical to a previously dumped one
// (by hash of the dot text) are skipped so only actual graph changes
// produce new files.
void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // really only is intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned-up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump,
               toco::port::StringF("VIDEO frame:%05d", dump_id));
  // Hash the dot text to deduplicate frames; only changed graphs are dumped.
  std::size_t hash = std::hash<string>{}(graphviz_dump);
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            toco::port::StringF("toco_video_%05d.dot", dump_id)),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
    dump_id++;
  }
}
|
|
|
|
// Dumps the model state labeled with `message`: emits a graphviz video frame
// and/or a named .dot file when --dump_graphviz is set, then (if VLOG is
// enabled at `log_level`) logs a textual dump of every operator and every
// array, each array being printed at most once.
void LogDump(int log_level, const string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump, message);
    // File name is derived from `message` with spaces replaced by
    // underscores, e.g. "toco_after_removal_of_unused_ops.dot".
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
                         ".dot")),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
  }

  // The textual dump below is comparatively expensive; skip it entirely
  // unless verbose logging at this level is active.
  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  // Tracks arrays already printed so each is logged only once even if used
  // by several ops.
  std::unordered_set<string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << " (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}
|
|
|
|
// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
|
|
// Grows `shape` to `new_shape_size` dimensions by prepending 1s
// (i.e. [3, 4] extended to size 4 becomes [1, 1, 3, 4]).
void ExtendShape(Shape* shape, int new_shape_size) {
  CHECK_GE(new_shape_size, shape->dimensions_count());
  const int leading_ones_to_add = new_shape_size - shape->dimensions_count();
  auto* dims = shape->mutable_dims();
  dims->insert(dims->begin(), leading_ones_to_add, 1);
}
|
|
|
|
// TODO(b/62904716) Remove along with remaining uses.
|
|
// Shrinks `shape` to `new_shape_size` dimensions by dropping leading
// dimensions, all of which are required to be 1. Inverse of ExtendShape.
// TODO(b/62904716) Remove along with remaining uses.
void UnextendShape(Shape* shape, int new_shape_size) {
  CHECK_LE(new_shape_size, shape->dimensions_count());
  const int leading_dims_to_drop = shape->dimensions_count() - new_shape_size;
  // Only degenerate (size-1) dimensions may be removed.
  for (int i = 0; i < leading_dims_to_drop; i++) {
    CHECK_EQ(shape->dims(i), 1);
  }
  auto& dims = *shape->mutable_dims();
  dims.erase(dims.begin(), dims.begin() + leading_dims_to_drop);
}
|
|
|
|
// In general, zero-sized dimensions are disallowed, but there are exceptions,
|
|
// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
|
|
// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
|
|
// strict, and is appropriate for ops and comparisons where an empty shape
|
|
// doesn't make sense.
|
|
// CHECK-fails unless every dimension is >= 1, with the single exception of
// the one-dimensional shape [0], which is allowed (it encodes scalar tensor
// data — see the comment above this function).
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  // [0] is the permitted "scalar data" encoding.
  if (dims.size() == 1 && dims[0] == 0) {
    return;
  }
  for (const auto& dim : dims) {
    CHECK_GE(dim, 1);
  }
}
|
|
|
|
// Shape-typed convenience wrapper around CheckValidShapeDimensions.
void CheckValidShape(const Shape& shape) {
  CheckValidShapeDimensions(shape.dims());
}
|
|
|
|
bool IsNonEmpty(const Shape& shape) {
|
|
for (int i = 0; i < shape.dimensions_count(); ++i) {
|
|
if (shape.dims(i) < 1) return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void CheckNonEmptyShapeDimensions(const Shape& shape) {
|
|
for (int i = 0; i < shape.dimensions_count(); ++i) {
|
|
CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
|
|
<< ". shape = " << ShapeToString(shape);
|
|
}
|
|
}
|
|
|
|
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
|
|
CheckNonEmptyShapeDimensions(shape0);
|
|
CheckNonEmptyShapeDimensions(shape1);
|
|
|
|
const Shape* longer = &shape0;
|
|
const Shape* shorter = &shape1;
|
|
if (shape1.dimensions_count() > shape0.dimensions_count()) {
|
|
longer = &shape1;
|
|
shorter = &shape0;
|
|
}
|
|
|
|
// Walk dimensions back to front until we run out of dimensions in the shorter
|
|
// shape.
|
|
int longer_index = longer->dimensions_count() - 1;
|
|
int shorter_index = shorter->dimensions_count() - 1;
|
|
while (shorter_index >= 0) {
|
|
const int d_long = longer->dims(longer_index);
|
|
const int d_short = shorter->dims(shorter_index);
|
|
// Broadcasting fails if the dimensions are different *and* neither is 1.
|
|
if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
|
|
return false;
|
|
}
|
|
longer_index--;
|
|
shorter_index--;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
|
|
CheckNonEmptyShapeDimensions(shape0);
|
|
CheckNonEmptyShapeDimensions(shape1);
|
|
|
|
const Shape* longer = &shape0;
|
|
const Shape* shorter = &shape1;
|
|
if (shape1.dimensions_count() > shape0.dimensions_count()) {
|
|
longer = &shape1;
|
|
shorter = &shape0;
|
|
}
|
|
|
|
// Walk dimensions back to front until we run out of dimensions in the shorter
|
|
// shape.
|
|
int longer_index = longer->dimensions_count() - 1;
|
|
int shorter_index = shorter->dimensions_count() - 1;
|
|
while (shorter_index >= 0) {
|
|
const int d_long = longer->dims(longer_index);
|
|
const int d_short = shorter->dims(shorter_index);
|
|
// Extending fails if the dimensions are different.
|
|
if (d_long != d_short) {
|
|
return false;
|
|
}
|
|
longer_index--;
|
|
shorter_index--;
|
|
}
|
|
|
|
// The remaining dimensions in the longer shape must be 1.
|
|
while (longer_index >= 0) {
|
|
const int d_long = longer->dims(longer_index);
|
|
if (d_long != 1) {
|
|
return false;
|
|
}
|
|
longer_index--;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
int RequiredBufferSizeForShape(const Shape& shape) {
|
|
CheckValidShape(shape);
|
|
int max_offset = 1;
|
|
for (const auto& dim : shape.dims()) {
|
|
max_offset *= dim;
|
|
}
|
|
return max_offset;
|
|
}
|
|
|
|
bool IsConstantParameterArray(const Model& model, const string& name) {
|
|
if (!model.HasArray(name)) {
|
|
return false;
|
|
}
|
|
|
|
return !!model.GetArray(name).buffer;
|
|
}
|
|
|
|
namespace {
|
|
template <ArrayDataType A>
|
|
bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
|
|
CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
|
|
CHECK(lhs_array.buffer) << "LHS must be constant";
|
|
CHECK(rhs_array.buffer) << "RHS must be constant";
|
|
const auto& lhs_data = lhs_array.GetBuffer<A>().data;
|
|
const auto& rhs_data = rhs_array.GetBuffer<A>().data;
|
|
CHECK_EQ(lhs_data.size(), rhs_data.size())
|
|
<< "Buffer sizes must match in element count";
|
|
for (int i = 0; i < lhs_data.size(); ++i) {
|
|
if (lhs_data[i] != rhs_data[i]) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
|
|
if (lhs_array.minmax || rhs_array.minmax) {
|
|
if (!lhs_array.minmax || !rhs_array.minmax) {
|
|
return false;
|
|
}
|
|
if (!(*lhs_array.minmax == *rhs_array.minmax)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool HaveSameQuantizationParams(const Array& lhs_array,
|
|
const Array& rhs_array) {
|
|
if (lhs_array.quantization_params || rhs_array.quantization_params) {
|
|
if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
|
|
return false;
|
|
}
|
|
if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Returns true iff two constant arrays are fully equivalent: same shape,
// data types, min/max, quantization parameters and narrow_range flag, and
// element-wise identical buffer contents. Both arrays must be constant
// (enforced by CompareArrayBuffers). Used e.g. to deduplicate weights.
bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
  // Cheap attribute comparison first; only compare buffers if it passes.
  bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
                     lhs_array.data_type == rhs_array.data_type &&
                     lhs_array.final_data_type == rhs_array.final_data_type &&
                     HaveSameMinMax(lhs_array, rhs_array) &&
                     HaveSameQuantizationParams(lhs_array, rhs_array) &&
                     lhs_array.narrow_range == rhs_array.narrow_range;
  if (!attrs_equal) {
    return false;
  }
  // Dispatch to the correctly-typed buffer comparison.
  switch (lhs_array.data_type) {
    case ArrayDataType::kBool:
      return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
    case ArrayDataType::kFloat:
      return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
    case ArrayDataType::kInt8:
      return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
    case ArrayDataType::kUint8:
      return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
    case ArrayDataType::kInt16:
      return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
    case ArrayDataType::kUint16:
      return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
    case ArrayDataType::kInt32:
      return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
    case ArrayDataType::kUint32:
      return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
    case ArrayDataType::kInt64:
      return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
    case ArrayDataType::kUint64:
      return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
    case ArrayDataType::kString:
      return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
    case ArrayDataType::kComplex64:
      return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
                                                            rhs_array);
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(lhs_array.data_type);
      return false;
  }
}
|
|
|
|
namespace {
|
|
// Take an array name, which may be something like "name:3_5" and make it
|
|
// acceptable as a TF node name, say "name_3_5";
|
|
// Take an array name, which may be something like "name:3_5" and make it
// acceptable as a TF node name, say "name_3_5";
string SanitizeNameForTFNode(const string& array_name) {
  string node_name = array_name;
  for (char& c : node_name) {
    if (c == ':') {
      c = '_';
    }
  }
  return node_name;
}
|
|
|
|
void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
|
|
for (const auto& input_array : model_flags.input_arrays()) {
|
|
for (const string& output_array : model_flags.output_arrays()) {
|
|
QCHECK_NE(input_array.name(), output_array)
|
|
<< "The array " << output_array
|
|
<< " is listed in both --input_arrays and --output_arrays.";
|
|
}
|
|
}
|
|
}
|
|
|
|
bool IsAsciiPrintable(const string& name) {
|
|
for (char c : name) {
|
|
if (!absl::ascii_isprint(c)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Renders `name` as a character-by-character ASCII/hex table for diagnostic
// output, flagging any non-printable character. Used by CheckNonAsciiIOArrays
// to show the user exactly which bytes are problematic.
string DumpAscii(const string& name) {
  string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c | %x\n", c, c);
    } else {
      port::AppendF(&result, " | %x Not ASCII printable!\n", c);
    }
  }
  return result;
}
|
|
|
|
// QCHECK-fails if any --input_arrays/--output_arrays name contains a
// non-printable-ASCII character, unless --allow_nonascii_arrays is set.
// The failure message includes a hex dump of the offending name.
void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
  if (model_flags.allow_nonascii_arrays()) {
    return;
  }
  for (const auto& input_array : model_flags.input_arrays()) {
    QCHECK(IsAsciiPrintable(input_array.name()))
        << "Non-ASCII-printable character found in --input_arrays: "
        << input_array.name()
        << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(input_array.name());
  }
  for (const string& output_array : model_flags.output_arrays()) {
    QCHECK(IsAsciiPrintable(output_array))
        << "Non-ASCII-printable character found in --output_arrays: "
        << output_array << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(output_array);
  }
}
|
|
|
|
void CheckNonExistentIOArrays(const Model& model) {
|
|
// "non-existent" is interpreted in the stronger sense of
|
|
// "not actually produced/consumed by an op".
|
|
// Rationale: we have to artificially fix up TensorFlow graphs by creating
|
|
// any array that it refers to, so just checking that arrays exist isn't
|
|
// sufficient. The real invariant here is whether arrays are produced/consumed
|
|
// by something.
|
|
if (model.flags.allow_nonexistent_arrays()) {
|
|
return;
|
|
}
|
|
static constexpr char general_comment[] =
|
|
"Is it a typo? This should not happen. If you trigger this error "
|
|
"please send a bug report (with code to reporduce this error), to the "
|
|
"TensorFlow Lite team.";
|
|
for (const string& output_array : model.flags.output_arrays()) {
|
|
if (IsConstantParameterArray(model, output_array)) {
|
|
continue; // It is OK to request that a constant be an output.
|
|
}
|
|
QCHECK(GetOpWithOutput(model, output_array))
|
|
<< "Specified output array \"" << output_array
|
|
<< "\" is not produced by any op in this graph. " << general_comment;
|
|
}
|
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
|
if (!rnn_state.discardable()) {
|
|
// Check that all RNN states are consumed
|
|
QCHECK(GetOpWithInput(model, rnn_state.state_array()))
|
|
<< "Specified RNN state \"" << rnn_state.state_array()
|
|
<< "\" is not consumed by any op in this graph. " << general_comment;
|
|
// Check that all RNN back-edge source arrays are produced
|
|
QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
|
|
<< "Specified RNN back-edge source array \""
|
|
<< rnn_state.back_edge_source_array()
|
|
<< "\" is not produced by any op in this graph. " << general_comment;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void CheckNoMissingArray(const Model& model) {
|
|
for (const auto& op : model.operators) {
|
|
for (const auto& input : op->inputs) {
|
|
CHECK(model.HasArray(input) || model.optional_arrays.count(input))
|
|
<< "Input: " << input << " missing for op: " << op->outputs[0] << ".";
|
|
}
|
|
for (const auto& output : op->outputs) {
|
|
CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
|
|
}
|
|
}
|
|
CheckNonExistentIOArrays(model);
|
|
}
|
|
|
|
// Creates any array that an operator references but that is absent from the
// model. When nonexistent arrays are allowed, also materializes the declared
// output arrays and RNN state arrays.
void FixNoMissingArray(Model* model) {
  const auto ensure_array_exists = [model](const string& array_name) {
    if (!model->HasArray(array_name) && !model->IsOptionalArray(array_name)) {
      model->GetOrCreateArray(array_name);
    }
  };
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      ensure_array_exists(input);
    }
    for (const auto& output : op->outputs) {
      ensure_array_exists(output);
    }
  }
  if (!model->flags.allow_nonexistent_arrays()) {
    return;
  }
  for (const string& output_array : model->flags.output_arrays()) {
    model->GetOrCreateArray(output_array);
  }
  for (const auto& rnn_state : model->flags.rnn_states()) {
    model->GetOrCreateArray(rnn_state.state_array());
    model->GetOrCreateArray(rnn_state.back_edge_source_array());
  }
}
|
|
|
|
void CheckNoOrphanedArray(const Model& model) {
|
|
std::unordered_set<string> arrays_without_known_use;
|
|
for (const auto& array : model.GetArrayMap()) {
|
|
if (IsDiscardableArray(model, array.first)) {
|
|
arrays_without_known_use.insert(array.first);
|
|
}
|
|
}
|
|
for (const auto& op : model.operators) {
|
|
for (const auto& input : op->inputs) {
|
|
arrays_without_known_use.erase(input);
|
|
}
|
|
for (const auto& output : op->outputs) {
|
|
arrays_without_known_use.erase(output);
|
|
}
|
|
}
|
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
|
arrays_without_known_use.erase(rnn_state.state_array());
|
|
arrays_without_known_use.erase(rnn_state.back_edge_source_array());
|
|
}
|
|
if (!arrays_without_known_use.empty()) {
|
|
for (const auto& array : arrays_without_known_use) {
|
|
LOG(INFO) << "Error: Orphaned array: " << array;
|
|
}
|
|
}
|
|
CHECK(arrays_without_known_use.empty());
|
|
}
|
|
|
|
// Erases every discardable array that is not used by any operator and not
// referenced by an RNN state.
void FixNoOrphanedArray(Model* model) {
  std::unordered_set<string> unused_arrays;
  for (const auto& array_entry : model->GetArrayMap()) {
    unused_arrays.insert(array_entry.first);
  }
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      unused_arrays.erase(input);
    }
    for (const auto& output : op->outputs) {
      unused_arrays.erase(output);
    }
  }
  for (const auto& rnn_state : model->flags.rnn_states()) {
    unused_arrays.erase(rnn_state.state_array());
    unused_arrays.erase(rnn_state.back_edge_source_array());
  }
  // Only discardable arrays may actually be removed (IO edges etc. stay).
  for (const auto& array_name : unused_arrays) {
    if (IsDiscardableArray(*model, array_name)) {
      model->EraseArray(array_name);
    }
  }
}
|
|
|
|
// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc);
    if (array->buffer) {
      // If there is a buffer, its type should be consistent with data_type.
      CHECK(array->buffer->type == array->data_type);
      // The presence of a fixed buffer should imply the presence of a fixed
      // shape.
      CHECK(array->has_shape());
      // A constant buffer should have a valid shape.
      CheckValidShape(array->shape());
      // The shape flat-size should agree with the buffer length.
      CHECK_EQ(array->buffer->Length(),
               RequiredBufferSizeForShape(array->shape()));
    }

    // Check name. Either "name_with_suffix_8", "name_with_port:3", but not
    // "name_with_both:3_8".
    const string& name = array_entry.first;
    auto colon_pos = name.find_first_of(":");
    if (colon_pos != string::npos) {
      // Everything after the colon must be a (possibly empty) run of digits.
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
               string::npos)
          << "Array '" << name << "' has non-digit characters after colon.";
    }
    // When no colon is present, colon_pos is string::npos (a huge unsigned
    // value), so this CHECK_GT passes; it only rejects names that start
    // with ':' (colon_pos == 0).
    CHECK_GT(colon_pos, 0) << "Array '" << name
                           << "' must not start with a colon.";
  }
}
|
|
|
|
void CheckOperatorOrdering(const Model& model) {
|
|
std::unordered_set<string> arrays_behind_us;
|
|
for (const auto& array_entry : model.GetArrayMap()) {
|
|
if (!GetOpWithOutput(model, array_entry.first)) {
|
|
arrays_behind_us.insert(array_entry.first);
|
|
}
|
|
}
|
|
arrays_behind_us.insert(model.optional_arrays.begin(),
|
|
model.optional_arrays.end());
|
|
for (const auto& op : model.operators) {
|
|
for (const auto& input : op->inputs) {
|
|
if (!IsConstantParameterArray(model, input)) {
|
|
CHECK(arrays_behind_us.count(input));
|
|
}
|
|
}
|
|
for (const auto& output : op->outputs) {
|
|
CHECK(!arrays_behind_us.count(output));
|
|
arrays_behind_us.insert(output);
|
|
}
|
|
}
|
|
for (const string& output_array : model.flags.output_arrays()) {
|
|
CHECK(arrays_behind_us.count(output_array));
|
|
}
|
|
}
|
|
|
|
// Rebuilds model->operators in a topologically-sorted order, so that every
// operator appears after the producers of all of its non-constant inputs.
// If no valid ordering exists (an input is never produced, or the graph has
// a cycle), logs a "backtrace" of the offending subgraph and dies with
// LOG(FATAL).
void FixOperatorOrdering(Model* model) {
  // Arrays available before any operator runs: arrays not produced by any
  // op (graph inputs / constants) plus optional arrays.
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  // Indices into old_operators not yet scheduled into model->operators.
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  // Maps an op's output array to the unsatisfied input that blocked that op;
  // used only to build the diagnostic trace below.
  std::unordered_map<string, string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        // erase() invalidates iterators on std::set, so break out and
        // restart the scan over `remaining`.
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has as "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      bool found_bad_output = false;
      string bad_output;
      for (const string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      // Revisiting an input we already traced means we walked a cycle.
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      bad_op = nullptr;
      // Walk back to the (still unscheduled) operator producing bad_input.
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      // No producer at all: that's the root cause.
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}
|
|
|
|
// Runs the full battery of graph-invariant checks defined above. Every
// violation is fatal (CHECK/QCHECK), so returning normally means the model
// passed all checks.
void CheckInvariants(const Model& model) {
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}
|
|
|
|
// Dies (CHECK) unless `count` lies within the [count_min, count_max] range
// of `model_check`. A negative count_min disables the lower bound; the upper
// bound only applies when count_max > count_min.
void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
                       const int count, const string& count_description) {
  const int min_count = model_check.count_min();
  const int max_count = model_check.count_max();
  if (min_count >= 0) {
    // When min == max the check pins an exact "value" rather than a minimum.
    const char* bound_label = max_count > min_count ? "minimum" : "value";
    CHECK_GE(count, min_count)
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified " << bound_label << " was " << min_count
        << ".";
  }
  if (max_count > min_count) {
    CHECK_LE(count, max_count)
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified maximum was " << max_count << ".";
  }
}
|
|
|
|
void CheckModelCounts(const Model& model) {
|
|
std::unordered_multiset<OperatorType> ops_by_type;
|
|
std::unordered_map<string, OperatorType> op_type_by_name;
|
|
if (model.flags.model_checks_size() == 0) {
|
|
return;
|
|
}
|
|
|
|
for (const auto& op : model.operators) {
|
|
ops_by_type.insert(op->type);
|
|
op_type_by_name[OperatorTypeName(op->type)] = op->type;
|
|
}
|
|
for (const auto& model_check : model.flags.model_checks()) {
|
|
string count_type = model_check.count_type();
|
|
if (count_type == "None") {
|
|
continue;
|
|
} else if (count_type == "Arrays") {
|
|
CheckCountInRange(model_check, model.GetArrayMap().size(),
|
|
"count of arrays");
|
|
} else if (count_type == "Total") {
|
|
CheckCountInRange(model_check, model.operators.size(),
|
|
"count of all operator instances");
|
|
} else {
|
|
// The check type is not itself checked against the set of valid
|
|
// operators, mainly because the enum set cannot be iterated in C++.
|
|
const int found_count =
|
|
op_type_by_name.count(count_type) > 0
|
|
? ops_by_type.count(op_type_by_name[count_type])
|
|
: 0;
|
|
CheckCountInRange(model_check, found_count,
|
|
"count of instances of " + count_type + " operator");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ensures every declared output array is produced by an operator. An output
// that is a bare constant gets cloned to an intermediate array and a runtime
// copy operator is inserted to produce the original name.
void FixEdgeArrays(Model* model) {
  for (const string& output_name : model->flags.output_arrays()) {
    if (GetOpWithOutput(*model, output_name)) {
      continue;  // Already produced by some operator; nothing to fix.
    }
    LOG(WARNING) << "Fixing constant output array " << output_name
                 << " by inserting a copy. This is not optimal.";
    const string copy_name =
        AvailableArrayName(*model, output_name + "_copy");
    CloneArray(model, output_name, copy_name);
    InsertCopyOperator(model, copy_name, output_name);
  }
}
|
|
|
|
// Erases constant arrays that are exact duplicates of another constant
// array, redirecting all uses to the surviving copy. Arrays whose buffer is
// smaller than min_size bytes are skipped; string arrays are always
// considered because their byte size cannot be estimated.
void DedupeConstantArrays(Model* model, size_t min_size) {
  // Walk all 0..N and compare with the remaining n+1..N.
  // This lets us avoid N^2 comparisons and erase duplicate arrays while
  // iterating.
  const auto& array_map = model->GetArrayMap();
  for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
       ++lhs_array_it) {
    const auto& lhs_array_name = lhs_array_it->first;
    const auto& lhs_array = *lhs_array_it->second;
    if (!IsConstantParameterArray(*model, lhs_array_name)) {
      // Not a constant array; skip.
      continue;
    }
    // Size estimation uses final_data_type when set, since that is the type
    // the data will eventually be stored as.
    ArrayDataType final_data_type =
        lhs_array.final_data_type != ArrayDataType::kNone
            ? lhs_array.final_data_type
            : lhs_array.data_type;
    // Ignore small arrays, don't check string arrays because it is not possible
    // to estimate its size.
    if (final_data_type != ArrayDataType::kString) {
      size_t array_byte_size =
          lhs_array.buffer->Length() * ElementSize(final_data_type);
      if (array_byte_size < min_size) {
        // Too small; skip.
        continue;
      }
    }

    auto next_lhs_array_it = lhs_array_it;
    ++next_lhs_array_it;
    for (auto rhs_array_it = next_lhs_array_it;
         rhs_array_it != array_map.end();) {
      const auto& rhs_array_name = rhs_array_it->first;
      const auto& rhs_array = *rhs_array_it->second;
      // Advance before the potential EraseArray below, so the iterator stays
      // valid when the current entry is removed from the map.
      ++rhs_array_it;
      if (!IsConstantParameterArray(*model, rhs_array_name)) {
        // Not a constant array; skip.
        continue;
      }
      if (!IsDiscardableArray(*model, rhs_array_name)) {
        // Can't remove the array as it's not discardable (such as an IO edge).
        continue;
      }
      if (!CompareConstantArrays(lhs_array, rhs_array)) {
        // Arrays aren't equal; skip.
        continue;
      }

      // Arrays can be deduped!
      VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
              << " in place of " << rhs_array_name;
      ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
      // Note: rhs_array_it above is already incremented so this is safe.
      model->EraseArray(rhs_array_name);
    }
  }
}
|
|
|
|
namespace {
|
|
// Copies type, shape, minmax and quantization metadata (but NOT the data
// buffer) from source_array onto target_array. Fields that are absent on the
// source are explicitly cleared on the target, so the target ends up with
// exactly the source's attributes.
void CopyArrayAttribs(const Array& source_array, Array* target_array) {
  target_array->data_type = source_array.data_type;
  target_array->final_data_type = source_array.final_data_type;
  if (source_array.has_shape()) {
    target_array->copy_shape(source_array.shape());
  }

  if (source_array.minmax) {
    target_array->GetOrCreateMinMax() = source_array.GetMinMax();
  } else {
    target_array->minmax.reset();
  }

  if (source_array.quantization_params) {
    target_array->GetOrCreateQuantizationParams() =
        source_array.GetQuantizationParams();
  } else {
    target_array->quantization_params.reset();
  }
}
|
|
} // namespace
|
|
|
|
// Inserts a runtime copy from source_array_name to target_array_name,
// implemented as a no-op Reshape to the source's own shape. The target array
// is created if needed; any constant buffer it held is dropped since the
// copy now happens at runtime.
void InsertCopyOperator(Model* model, const string& source_array_name,
                        const string& target_array_name) {
  // Reshape to the same size. This should be a no-op.
  // NOTE(review): the dims are copied by value here, before
  // GetOrCreateArray below may mutate the array map — presumably to avoid
  // relying on reference stability; confirm before reordering.
  const Array& source_array = model->GetArray(source_array_name);
  std::vector<int> shape = source_array.shape().dims();

  // Drop constant data from the target array as the copy will be done at
  // runtime.
  Array& target_array = model->GetOrCreateArray(target_array_name);
  target_array.buffer.reset();
  CopyArrayAttribs(source_array, &target_array);

  // Insert copy operator.
  auto* copy_op = new TensorFlowReshapeOperator;
  copy_op->inputs = {
      source_array_name,
      CreateInt32Array(
          model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
          shape)};
  copy_op->outputs = {target_array_name};
  if (target_array.has_shape()) {
    copy_op->shape = target_array.shape().dims();
  }
  // Ownership of copy_op transfers to model->operators.
  model->operators.emplace_back(copy_op);
}
|
|
|
|
// Creates target_array_name as a full clone of source_array_name: all
// attributes (via CopyArrayAttribs) plus, if present, a deep copy of the
// constant data buffer. The target must not already exist.
void CloneArray(Model* model, const string& source_array_name,
                const string& target_array_name) {
  CHECK(!model->HasArray(target_array_name));
  const Array& source_array = model->GetArray(source_array_name);
  Array& target_array = model->GetOrCreateArray(target_array_name);
  CopyArrayAttribs(source_array, &target_array);

  // No constant data to clone for transient arrays.
  if (!source_array.buffer) {
    return;
  }

  // Buffers are typed, so dispatch on the runtime data type to copy with the
  // matching template instantiation.
  switch (source_array.data_type) {
    case ArrayDataType::kBool:
      CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
      break;
    case ArrayDataType::kFloat:
      CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
      break;
    case ArrayDataType::kInt8:
      CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
      break;
    case ArrayDataType::kUint8:
      CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
      break;
    case ArrayDataType::kInt16:
      CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
      break;
    case ArrayDataType::kUint16:
      CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
      break;
    case ArrayDataType::kInt32:
      CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
      break;
    case ArrayDataType::kUint32:
      CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
      break;
    case ArrayDataType::kInt64:
      CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
      break;
    case ArrayDataType::kUint64:
      CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
      break;
    case ArrayDataType::kString:
      CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
      break;
    case ArrayDataType::kComplex64:
      CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
      break;
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(source_array.data_type);
      return;
  }
}
|
|
|
|
// Fills *out_dims with a dims vector of rank num_dims (0..4) built from the
// given batch/height/width/depth values, following toco's layout convention:
// 1D={depth}, 2D={batch,depth}, 3D={height,width,depth},
// 4D={batch,height,width,depth}. For ranks 1 and 3 the batch must be 1.
// out_dims must be empty on entry; any other rank is fatal.
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
                   std::vector<int>* out_dims) {
  CHECK(out_dims->empty());
  switch (num_dims) {
    case 0:
      break;
    case 1:
      CHECK_EQ(batch, 1);
      *out_dims = {depth};
      break;
    case 2:
      *out_dims = {batch, depth};
      break;
    case 3:
      CHECK_EQ(batch, 1);
      *out_dims = {height, width, depth};
      break;
    case 4:
      *out_dims = {batch, height, width, depth};
      break;
    default:
      LOG(FATAL) << "Should not get here: " << num_dims;
  }
}
|
|
|
|
// Ensures the RNN state array `name` exists and has a shape. The rank comes
// from state_num_dims when positive; otherwise it is inferred from the model
// input arrays (preferring an input with a matching name, else the last one
// scanned), along with the batch size from that input's first dimension. If
// the array already has a shape, that shape wins and nothing is changed.
void CreateOrCheckRnnStateArray(const string& name, int size,
                                int state_num_dims, Model* model) {
  int batch = 1;
  int num_dims = -1;
  if (state_num_dims > 0) {
    num_dims = state_num_dims;
  } else {
    // state_num_dims is not given. We will infer it from an input tensor.
    for (const auto& input_array : model->flags.input_arrays()) {
      // Pick 'num_dims' and 'batch' from the first input_arrays, unless we
      // find a better match by name.
      if (input_array.name() == name || num_dims == -1) {
        num_dims = input_array.shape().dims_size();
        if (num_dims > 0) {
          batch = input_array.shape().dims(0);
        }
      }
    }
  }
  Array& array = model->GetOrCreateArray(name);
  if (array.has_shape()) {
    // An existing shape takes precedence over anything inferred above.
    num_dims = array.shape().dimensions_count();
  }
  if (!array.has_shape() && num_dims >= 0) {
    Shape* shape = array.mutable_shape();
    std::vector<int> dims;
    MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
    *shape->mutable_dims() = dims;
  }
}
|
|
|
|
// Merges the command-line ModelFlags into model->flags and reconciles them
// with the arrays actually present in the model: input array attributes
// (std/mean, shape, data type) are matched by name and checked for
// consistency (QCHECK on conflict), output arrays / rnn states / model
// checks are copied over, input arrays get default types and derived minmax
// ranges, and RNN state arrays are created.
void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
  // Merge info about input_arrays from model_flags into model->flags
  for (const auto& specified_input_array : model_flags.input_arrays()) {
    toco::InputArray* dst_input_array = nullptr;
    for (int i = 0; i < model->flags.input_arrays_size(); i++) {
      toco::InputArray* candidate_dst_input_array =
          model->flags.mutable_input_arrays(i);
      if (candidate_dst_input_array->name() == specified_input_array.name()) {
        // specified_input_array from model_flags maps to dst_input_array
        // in model->flags
        dst_input_array = candidate_dst_input_array;
        break;
      }
    }
    if (!dst_input_array) {
      // Specified_input_array from model_flags is not found in model->flags.
      // Match a name-less specified input array when there can be no ambiguity
      // as there is only 1 input array.
      if (model->flags.input_arrays_size() == 1 &&
          model_flags.input_arrays_size() == 1 &&
          !specified_input_array.has_name()) {
        dst_input_array = model->flags.mutable_input_arrays(0);
      }
    }
    if (!dst_input_array) {
      // Still no match, so create a new input array to copy
      // specified_input_array into.
      dst_input_array = model->flags.add_input_arrays();
      dst_input_array->set_name(specified_input_array.name());
    }

// Copies one scalar field from specified_input_array into dst_input_array,
// dying if both sides define it with different values.
// NOTE(review): on mismatch the message prints the specified value twice;
// the last operand was presumably meant to be dst_input_array->field_name()
// — confirm before changing (QCHECK_EQ itself also prints both operands).
#define RESOLVE_MODEL_FLAG(field_name)                                       \
  if (specified_input_array.has_##field_name()) {                            \
    if (dst_input_array->has_##field_name()) {                               \
      QCHECK_EQ(dst_input_array->field_name(),                               \
                specified_input_array.field_name())                          \
          << "For input array '" << dst_input_array->name() << "', "         \
          << "specified " #field_name " flag with value: "                   \
          << specified_input_array.field_name()                              \
          << " does not agree with already defined " #field_name             \
             " of this model, with value: "                                  \
          << specified_input_array.field_name();                             \
    } else {                                                                 \
      dst_input_array->set_##field_name(specified_input_array.field_name()); \
    }                                                                        \
  }
    RESOLVE_MODEL_FLAG(std_value);
    RESOLVE_MODEL_FLAG(mean_value);
#undef RESOLVE_MODEL_FLAG

    if (specified_input_array.has_shape()) {
      if (dst_input_array->has_shape()) {
        QCHECK_EQ(specified_input_array.shape().dims_size(),
                  dst_input_array->shape().dims_size())
            << "For input array '" << specified_input_array.name() << "', "
            << "size of specified input shape flag with size: "
            << specified_input_array.shape().dims_size()
            << " does not agree with already defined input shape"
               " of this model, with size: "
            << dst_input_array->shape().dims_size();
        // We treat the first dimension as a special case, since it is often
        // a batch size and the input_shape flag is effectively overriding
        // the model.
        for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
          QCHECK_EQ(specified_input_array.shape().dims(i),
                    dst_input_array->shape().dims(i))
              << "At dimension number " << i << " of input array "
              << specified_input_array.name() << ", the specified shape's "
              << "dimension flag with dimension: "
              << specified_input_array.shape().dims(i)
              << " does not agree with already defined shape"
              << " of this model, with dimension: "
              << dst_input_array->shape().dims(i);
        }
      } else {
        *dst_input_array->mutable_shape() = specified_input_array.shape();
      }
    }

    if (specified_input_array.has_data_type()) {
      QCHECK(!dst_input_array->has_data_type());
      dst_input_array->set_data_type(specified_input_array.data_type());
    }
  }

  if (model_flags.output_arrays_size() > 0) {
    model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
  }

// Copies one scalar model-level flag, dying if both sides define it with
// different values.
#define RESOLVE_MODEL_FLAG(name)                                           \
  if (model_flags.has_##name()) {                                          \
    if (model->flags.has_##name()) {                                       \
      QCHECK_EQ(model_flags.name(), model->flags.name())                   \
          << "Specified " #name " flag with value: " << model_flags.name() \
          << " does not agree with already defined " #name                 \
             " of this model, with value: "                                \
          << model->flags.name();                                          \
    } else {                                                               \
      model->flags.set_##name(model_flags.name());                         \
    }                                                                      \
  }

  RESOLVE_MODEL_FLAG(variable_batch)

#undef RESOLVE_MODEL_FLAG

  if (!model_flags.rnn_states().empty()) {
    model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
  }

  if (model->flags.model_checks_size() == 0) {
    model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
  }

  QCHECK_GT(model->flags.output_arrays_size(), 0)
      << "This model does not define output arrays, so a "
         "--output_arrays flag must be given on the command-line.";

  for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
    auto& input_array = model->GetOrCreateArray(input_array_proto.name());
    if (input_array_proto.has_data_type()) {
      const ArrayDataType specified_type =
          ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
      QCHECK(specified_type != ArrayDataType::kNone);
      if (input_array.data_type != ArrayDataType::kNone) {
        QCHECK(specified_type == input_array.data_type)
            << "For input array " << input_array_proto.name()
            << " the specified input data type "
            << IODataType_Name(input_array_proto.data_type())
            << " conflicts with the existing type.";
      }
      input_array.data_type = specified_type;
    }

    if (input_array.data_type == ArrayDataType::kNone) {
      // We start out with a float input array;
      // that may get replaced by a uint8 array later, by
      // MakeInitialDequantizeOp.
      input_array.data_type = ArrayDataType::kFloat;
    }

    // Compare/merge the model->flags describing the input_shape with
    // the actual input array's shape.
    if (!input_array.has_shape()) {
      if (input_array_proto.has_shape()) {
        auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
        CheckValidShapeDimensions(input_array_proto.shape().dims());
        for (const auto& dim : input_array_proto.shape().dims()) {
          input_array_dims.push_back(dim);
        }
      }
    } else {
      if (input_array_proto.has_shape()) {
        // If an input shape was specified on the flags ensure that it matches
        // the actual shape in the model.
        const auto& input_array_dims =
            *input_array.mutable_shape()->mutable_dims();
        CHECK_EQ(input_array_dims.size(),
                 input_array_proto.shape().dims_size());
        for (int i = 0; i < input_array_dims.size(); i++) {
          CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
        }
      } else {
        // No flag shape: copy the model's shape back onto the proto.
        for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
          input_array_proto.mutable_shape()->add_dims(
              input_array.shape().dims(i));
        }
      }
    }

    // Derive a default minmax from the quantization range and the
    // mean/std flags: real = (quantized - mean) / std.
    // NOTE(review): std_value of 0 would produce +/-inf here — presumably
    // callers guarantee a nonzero std; confirm.
    const float mean_value = input_array_proto.mean_value();
    const float std_value = input_array_proto.std_value();
    MinMax input_minmax;
    float qmin = 0, qmax = 255;
    if (input_array.data_type == ArrayDataType::kInt16) {
      qmin = -32768;
      qmax = 32767;
    }
    input_minmax.min = (qmin - mean_value) / std_value;
    input_minmax.max = (qmax - mean_value) / std_value;
    if (!input_array.minmax) {
      input_array.GetOrCreateMinMax() = input_minmax;
    }
  }

  // Creation of the RNN state arrays
  for (const auto& rnn_state : model->flags.rnn_states()) {
    CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
                               rnn_state.num_dims(), model);
  }

  model->flags.set_change_concat_input_ranges(
      model_flags.change_concat_input_ranges());
  model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
  model->flags.set_allow_nonexistent_arrays(
      model_flags.allow_nonexistent_arrays());

  CHECK(!model->flags.has_arrays_extra_info());
  *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
}
|
|
|
|
void CheckIsReadyForQuantization(const Model& model) {
|
|
for (const auto& op : model.operators) {
|
|
for (const auto& input : op->inputs) {
|
|
const auto& input_array = model.GetArray(input);
|
|
if (input_array.data_type != ArrayDataType::kFloat) {
|
|
// The array is not floats, no quantization needed.
|
|
continue;
|
|
}
|
|
if (input_array.minmax) {
|
|
// The array has minmax, we're good.
|
|
continue;
|
|
}
|
|
if (input_array.buffer) {
|
|
// The array has a constant buffer, so we can
|
|
// fall back to computing the minmax from actual array entries
|
|
// (with a WARNING about possible accuracy implications).
|
|
continue;
|
|
}
|
|
LOG(FATAL)
|
|
<< "Array " << input << ", which is an input to the "
|
|
<< HelpfulOperatorTypeName(*op) << " operator producing the output "
|
|
<< "array " << op->outputs[0] << ", is lacking min/max data, "
|
|
<< "which is necessary for quantization. If accuracy matters, either "
|
|
<< "target a non-quantized output format, or run quantized training "
|
|
<< "with your model from a floating point checkpoint to change the "
|
|
<< "input graph to contain min/max information. If you don't care "
|
|
<< "about accuracy, you can pass --default_ranges_min= and "
|
|
<< "--default_ranges_max= for easy experimentation.";
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns the size in bytes of one element of the given data type. String
// elements have no fixed size, and unknown types are a bug: both are fatal.
int ElementSize(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kBool:
      return sizeof(bool);
    case ArrayDataType::kInt8:
    case ArrayDataType::kUint8:
      return 1;
    case ArrayDataType::kInt16:
    case ArrayDataType::kUint16:
      return 2;
    case ArrayDataType::kFloat:
    case ArrayDataType::kInt32:
    case ArrayDataType::kUint32:
      return 4;
    case ArrayDataType::kInt64:
    case ArrayDataType::kUint64:
    case ArrayDataType::kComplex64:
      return 8;
    // Usually not critical limitation because strings are only input and/or
    // output.
    case ArrayDataType::kString:
      LOG(FATAL) << "Transient arrays with strings are not supported yet";
      return 0;
    default:
      LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
      return 0;
  }
}
|
|
|
|
// Removes any MinMax information from the named array, logging a warning
// since this can degrade quantized-inference accuracy.
void DropMinMax(Model* model, const string& array_name) {
  auto& array = model->GetArray(array_name);
  if (!array.minmax) {
    return;
  }
  LOG(WARNING) << "Dropping MinMax information in array " << array_name
               << ". Expect inaccuracy in quantized inference.";
  array.minmax = nullptr;
}
|
|
|
|
bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
|
|
// Optional array is not transient
|
|
if (model.IsOptionalArray(array_name)) return false;
|
|
// The model's input and output arrays are externally allocated.
|
|
// They are not transient arrays.
|
|
if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
|
|
return false;
|
|
}
|
|
const auto& array = &model.GetArray(array_name);
|
|
// An array with a constant buffer isn't a transient array.
|
|
if (!!array->buffer) {
|
|
return false;
|
|
}
|
|
// An array without shape isn't allocatable.
|
|
if (!array->has_shape()) {
|
|
return false;
|
|
}
|
|
|
|
// The size of string tensors is rarely known ahead of time, so all transient
|
|
// tensors of this type will need to be dynamically allocated.
|
|
if (array->final_data_type == ArrayDataType::kString ||
|
|
array->data_type == ArrayDataType::kString) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Produces a sanitized version of `name` that is not already used in `model`
// (neither as an array nor as an optional array). If the sanitized name is
// taken, numeric suffixes _0 .. _999 are tried; failure to find a free name
// is fatal.
string AvailableArrayName(const Model& model, const string& name) {
  const string sanitized_name = SanitizeNameForTFNode(name);
  // A candidate is usable if it collides with nothing in the model.
  auto is_free = [&model](const string& candidate) {
    return !model.HasArray(candidate) && !model.IsOptionalArray(candidate);
  };
  if (is_free(sanitized_name)) {
    return sanitized_name;
  }
  const int kNumSuffixesToTry = 1000;
  for (int i = 0; i < kNumSuffixesToTry; i++) {
    const string candidate = toco::port::StringF("%s_%d", sanitized_name, i);
    if (is_free(candidate)) {
      return candidate;
    }
  }
  LOG(FATAL) << "Could not find an available array name starting with "
             << sanitized_name << ". Tried " << kNumSuffixesToTry
             << " suffixes, all were taken!";
  return "";
}
|
|
|
|
// Renders a Shape for logging, e.g. "[ 1, 2, 3 ]"; a 0-D shape renders as
// "[]".
string ShapeToString(const Shape& shape) {
  if (shape.dimensions_count() == 0) return "[]";
  return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
}
|
|
|
|
// Logs the shape of the named array, or a note when no shape is known yet.
void PrintArrayShape(Model* model, const string& name) {
  const auto& array = model->GetArray(name);
  if (!array.has_shape()) {
    LOG(INFO) << name << " has no shape";
    return;
  }
  LOG(INFO) << name << " has shape: " << ShapeToString(array.shape());
}
|
|
|
|
bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
|
|
bool is_fc_weights = false;
|
|
bool is_something_else = false;
|
|
for (const auto& op : model.operators) {
|
|
for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
|
|
if (op->inputs[input_index] == name) {
|
|
if (op->type == OperatorType::kFullyConnected && input_index == 1) {
|
|
is_fc_weights = true;
|
|
} else {
|
|
is_something_else = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
CHECK(!(is_fc_weights && is_something_else));
|
|
return is_fc_weights;
|
|
}
|
|
|
|
// Creates a new constant 1-D int32 array in `model` holding `value`, named
// after `param_name` (made unique if that name is taken). Returns the name
// actually used.
string CreateInt32Array(Model* model, const string& param_name,
                        const std::vector<int>& value) {
  auto param_array_name = AvailableArrayName(*model, param_name);
  auto& param_array = model->GetOrCreateArray(param_array_name);
  // Single dimension: the number of elements in `value`.
  param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
  param_array.data_type = ArrayDataType::kInt32;
  auto& param_array_data =
      param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
  param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
  // Use size_t to avoid the signed/unsigned comparison against value.size().
  for (size_t i = 0; i < value.size(); ++i) {
    param_array_data[i] = value[i];
  }
  return param_array_name;
}
|
|
|
|
// Estimates the number of scalar arithmetic operations performed by `op`
// (a multiply-add counts as 2 ops), writing the estimate to *result.
// Returns false when a shape required for the estimate is missing.
// Operator types not handled by the switch get an estimate of 0 and return
// true.
bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64* result) {
  switch (op.type) {
    case OperatorType::kFullyConnected:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!output_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      // Number of output positions: product of all output dims except the
      // innermost (depth) dimension.
      int64 cols = 1;
      for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
        cols *= output_array.shape().dims(i);
      }
      // Two ops (multiply + add) per weight element, per output position.
      const int64 cost_per_col =
          2 * RequiredBufferSizeForShape(weights_array.shape());
      *result = cost_per_col * cols;
      if (op.inputs.size() > 2) {
        // There is a bias vector. One more op per output value.
        *result += RequiredBufferSizeForShape(output_array.shape());
      }
      break;
    }
    case OperatorType::kTransposeConv: {
      const auto& input_array = model.GetArray(op.inputs[2]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!input_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      const Shape& input = input_array.shape();
      const Shape& weights = weights_array.shape();
      // Compute op count from the seven nested loops of
      // tflite::reference_ops::TransposeConv():
      *result = 2 * input.dims(0) * input.dims(1) * input.dims(2) *
                input.dims(3) * weights.dims(1) * weights.dims(2) *
                weights.dims(0);
      // Note that tflite::optimized_ops::TransposeConv() uses an im2col matrix
      // and has a higher op count, by a factor of (output_height*output_width)
      // vs. (input_height*input_width). Yet it generally performs better
      // because of coherent memory access. (At least for 2x2 striding. But not
      // likely for all cases.)
      break;
    }
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One elementwise op per output value.
      *result = RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kAddN: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // AddN cost is roughly the same cost as N-1 Adds.
      const int64 num_adds = op.inputs.size() - 1;
      *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kLogistic:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kTanh: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // As a very rough ballpark, the cost of evaluating a math function
      // such as tanh or logistic is about 32 multiplications, and about as
      // many additions/subtractions. (Just a power-of-two order-of-magnitude
      // from looking at actual implementations that we use in runtime/ code).
      *result = 64 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kMaxPool: {
      const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One comparison per kernel element, per output value.
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                maxpool.kheight * maxpool.kwidth;
      break;
    }
    case OperatorType::kAveragePool: {
      const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One addition per kernel element, per output value.
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                avgpool.kheight * avgpool.kwidth;
      break;
    }
    case OperatorType::kL2Pool: {
      const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // The sum of squares requires (kheight*kwidth) multiply-adds,
      // and then there is the sqrt which we ballpark at 32 ops.
      const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
      *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
      break;
    }
    case OperatorType::kL2Normalization: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // Computing the squared L2 norm is N multiply-adds so 2N ops,
      // then the single inverse-sqrt is negligible, then we multiply each
      // value by the resulting multiplier, so an extra N ops. count 3N ops.
      *result = 3 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    default:
      // Unhandled operator types are treated as free.
      *result = 0;
      break;
  }
  return true;
}
|
|
|
|
bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
|
|
int64 total = 0;
|
|
for (const auto& op : model.operators) {
|
|
int64 num_ops;
|
|
if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
|
|
return false;
|
|
}
|
|
total += num_ops;
|
|
}
|
|
*result = total;
|
|
return true;
|
|
}
|
|
|
|
// Formats a count for human-readable display: raw below 10000, otherwise
// scaled to millions ("M") or billions ("G") with three decimal places.
string FormattedNumber(int64 x) {
  const int64 million = 1000000;
  const int64 billion = 1000000000;
  if (x < 10000) {
    return toco::port::StringF("%d ", x);
  }
  if (x < billion) {
    return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
  }
  return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
}
|
|
|
|
// Computes the axis permutation taking input_axes_order to
// output_axes_order: (*shuffle)[i] is the input axis that lands at output
// position i. Unknown order pairs are fatal.
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
                     std::vector<int>* shuffle) {
  CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
  // Start from the 4-D identity permutation.
  shuffle->assign({0, 1, 2, 3});
  if (input_axes_order == output_axes_order) {
    // Identity: nothing to do.
    return;
  }
  if (AxesCount(input_axes_order) == 2) {
    // Any reordering of two axes is a plain transposition.
    shuffle->assign({1, 0});
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWIO) {
    // 3210 <- 3210
    // HWIO <- OHWI
    *shuffle = {1, 2, 3, 0};
  } else if (input_axes_order == AxesOrder::kHWIO &&
             output_axes_order == AxesOrder::kOHWI) {
    // 3210 <- 3210
    // OHWI <- HWIO
    *shuffle = {3, 0, 1, 2};
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWOI) {
    *shuffle = {1, 2, 0, 3};
  } else {
    LOG(FATAL) << "Bad shuffle";
  }
}
|
|
|
|
// Extends a permutation to `newdim` dimensions by prepending identity
// entries for the new leading axes and shifting the original entries up by
// the pad amount.
void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
                   std::vector<int>* extended_shuffle) {
  CHECK(newdim >= input_shuffle.size());
  const int pad_size = newdim - input_shuffle.size();
  extended_shuffle->clear();
  extended_shuffle->reserve(newdim);
  // New leading dimensions map to themselves.
  for (int i = 0; i < pad_size; ++i) {
    extended_shuffle->push_back(i);
  }
  // Original permutation entries are offset by the pad amount.
  for (const int axis : input_shuffle) {
    extended_shuffle->push_back(axis + pad_size);
  }
}
|
|
|
|
void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
|
|
AxesOrder output_axes_order, Shape* output_shape) {
|
|
if (input_axes_order == AxesOrder::kHWIM &&
|
|
output_axes_order == AxesOrder::k1HWO) {
|
|
// This special case isn't just a permutation, the IM pair of dims get
|
|
// merged into the 3 dim, so we have to special-case it.
|
|
*output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
|
|
input_shape.dims(3) * input_shape.dims(2)});
|
|
} else {
|
|
std::vector<int> shuffle;
|
|
GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
|
|
std::vector<int>* output_dims = output_shape->mutable_dims();
|
|
output_dims->resize(input_shape.dimensions_count());
|
|
for (int i = 0; i < input_shape.dimensions_count(); i++) {
|
|
(*output_dims)[i] = input_shape.dims(shuffle[i]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Permutes input_data (laid out per input_shape / input_axes_order) into
// output_data (laid out per output_shape / output_axes_order). Supports up
// to 4 dimensions; lower-rank shapes are padded with leading unit dims
// before shuffling.
template <typename T>
void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
                          AxesOrder output_axes_order,
                          const Shape& output_shape, const T* input_data,
                          T* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the O dim, so we have to special-case it. Fortunately,
    // as far as array shuffling is concerned, it's just the identity
    // transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  // Sanity-check that the permutation is valid and consistent with the
  // given input/output shapes.
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  // Work in a canonical 4-D space: pad both shapes and the permutation with
  // leading unit / identity entries.
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major strides of the input in its own layout (index 3 is innermost)...
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  // ...then reordered through the permutation so that input_stride_k
  // advances along the input axis that feeds output axis k (0 = innermost).
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  // The output is traversed contiguously, innermost axis first.
  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  for (int i3 = 0; i3 < output_size_3; i3++) {
    const T* const input_ptr_3 = input_data + i3 * input_stride_3;
    T* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
        T* output_ptr = output_ptr_2 + i1 * output_stride_1;
        T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
        // Fill one contiguous output row, gathering from strided input.
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}
|
|
|
|
// uint8 overload: dispatches to ShuffleArrayTemplate<uint8>.
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const uint8* input_data, uint8* output_data) {
  ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
                              output_shape, input_data, output_data);
}
|
|
|
|
// float overload: dispatches to ShuffleArrayTemplate<float>.
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const float* input_data, float* output_data) {
  ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
                              output_shape, input_data, output_data);
}
|
|
|
|
// Returns the number of axes implied by the given axes-ordering convention.
int AxesCount(AxesOrder axes_order) {
  switch (axes_order) {
    case AxesOrder::kOneAxis:
      return 1;
    // 2-D orders.
    case AxesOrder::kRC:
    case AxesOrder::kCR:
      return 2;
    // 4-D orders.
    case AxesOrder::kHWIO:
    case AxesOrder::kOHWI:
    case AxesOrder::kHWIM:
    case AxesOrder::k1HWO:
    case AxesOrder::kNHWC:
    case AxesOrder::kHWOI:
      return 4;
    default:
      LOG(FATAL) << "Bad AxesOrder";
      return 0;
  }
}
|
|
|
|
bool IsDiscardableArray(const Model& model, const string& array_name) {
|
|
if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
|
|
return false;
|
|
}
|
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
|
if (!rnn_state.discardable()) {
|
|
if (array_name == rnn_state.state_array()) {
|
|
return false;
|
|
}
|
|
if (array_name == rnn_state.back_edge_source_array()) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool ReshapeIsEquivalentToTranspose(const Model& model,
|
|
const TensorFlowReshapeOperator* op,
|
|
bool allow_extra_unary_dims) {
|
|
CHECK(!op->shape.empty());
|
|
CHECK(model.HasArray(op->inputs[0]));
|
|
CHECK(model.HasArray(op->outputs[0]));
|
|
|
|
const auto& input_array = model.GetArray(op->inputs[0]);
|
|
const auto& output_array = model.GetArray(op->outputs[0]);
|
|
|
|
CHECK(input_array.has_shape());
|
|
CHECK(output_array.has_shape());
|
|
|
|
std::vector<int> in_shape = input_array.shape().dims();
|
|
std::vector<int> out_shape = output_array.shape().dims();
|
|
|
|
// If the reshape changes the number of dimensions so it cannot be interpreted
|
|
// as a transpose.
|
|
if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
|
|
return false;
|
|
}
|
|
|
|
in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
|
|
in_shape.end());
|
|
out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
|
|
out_shape.end());
|
|
return in_shape == out_shape;
|
|
}
|
|
|
|
void CheckFinalDataTypesSatisfied(const Model& model) {
|
|
for (const auto& array_entry : model.GetArrayMap()) {
|
|
const auto& array = *array_entry.second;
|
|
if (array.data_type == ArrayDataType::kBool) {
|
|
// Boolean values are never quantized.
|
|
continue;
|
|
}
|
|
|
|
// If the final data type is int16, the data type may be float, for example
|
|
// after dequantization.
|
|
if (array.final_data_type != ArrayDataType::kNone &&
|
|
array.final_data_type != ArrayDataType::kInt16) {
|
|
CHECK(array.data_type == array.final_data_type)
|
|
<< "Array \"" << array_entry.first
|
|
<< "\" has mis-matching actual and final data types (data_type="
|
|
<< ArrayDataTypeName(array.data_type)
|
|
<< ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
|
|
<< ").";
|
|
}
|
|
}
|
|
}
|
|
|
|
// Maps the external IODataType (from the toco flags proto) onto the
// corresponding internal ArrayDataType; unrecognized values map to kNone.
ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
  switch (type) {
    case FLOAT: return ArrayDataType::kFloat;
    case QUANTIZED_UINT8: return ArrayDataType::kUint8;
    case INT8: return ArrayDataType::kInt8;
    case QUANTIZED_INT16: return ArrayDataType::kInt16;
    case INT32: return ArrayDataType::kInt32;
    case INT64: return ArrayDataType::kInt64;
    case BOOL: return ArrayDataType::kBool;
    case STRING: return ArrayDataType::kString;
    case COMPLEX64: return ArrayDataType::kComplex64;
    default: return ArrayDataType::kNone;
  }
}
|
|
|
|
// Post-processing for RNN states declared in the model flags: CHECK-fails
// when either endpoint array (state or back-edge source) is missing, and
// defaults the state array's data type to float when neither endpoint has a
// type yet.
void FinishBuildingRNNStates(Model* model) {
  for (const auto& rnn_state : model->flags.rnn_states()) {
    if (!model->HasArray(rnn_state.back_edge_source_array()) ||
        !model->HasArray(rnn_state.state_array())) {
      // Re-check each array individually so the CHECK failure message names
      // the one that is missing.
      CHECK(model->HasArray(rnn_state.back_edge_source_array()));
      CHECK(model->HasArray(rnn_state.state_array()));
      continue;
    }
    const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
    auto& dst_array = model->GetArray(rnn_state.state_array());
    if (src_array.data_type == ArrayDataType::kNone &&
        dst_array.data_type == ArrayDataType::kNone) {
      // Neither endpoint is typed yet: default the state to float.
      dst_array.data_type = ArrayDataType::kFloat;
    }
  }
}
|
|
|
|
// Returns the array names that match the ArraysExtraInfo's name and
|
|
// name_regexp. The regexp match is for a full match.
|
|
std::unordered_set<string> ScanArrayNames(
|
|
const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
|
|
std::unordered_set<string> matches;
|
|
if (model.HasArray(entry.name())) {
|
|
matches.insert(entry.name());
|
|
}
|
|
if (!entry.name_regexp().empty()) {
|
|
const auto& arrays = model.GetArrayMap();
|
|
const RE2 name_regexp = {entry.name_regexp()};
|
|
for (auto it = arrays.begin(); it != arrays.end(); ++it) {
|
|
if (RE2::FullMatch(it->first, name_regexp)) {
|
|
matches.insert(it->first);
|
|
}
|
|
}
|
|
}
|
|
return matches;
|
|
}
|
|
|
|
// Applies the model flags' arrays_extra_info overrides to every matching
// array: min/max, final data type (only when quantize_output is set),
// shape, and a constant float fill value.
void UseArraysExtraInfo(Model* model, bool quantize_output) {
  for (const auto& entry : model->flags.arrays_extra_info().entries()) {
    const auto matches = ScanArrayNames(*model, entry);
    for (const auto& matched_name : matches) {
      auto& array = model->GetArray(matched_name);
      if (entry.has_min() || entry.has_max()) {
        // min and max must be provided together.
        CHECK_EQ(entry.has_min(), entry.has_max());
        auto& minmax = array.GetOrCreateMinMax();
        minmax.min = entry.min();
        minmax.max = entry.max();
      }
      if (entry.has_data_type() && quantize_output) {
        array.final_data_type =
            ConvertIODataTypeToArrayDataType(entry.data_type());
      }
      if (entry.has_shape()) {
        array.clear_shape();
        // Make sure to create the shape even if there are no dims, to
        // correctly record 0-D shapes.
        array.mutable_shape();
        for (const auto& dim : entry.shape().dims()) {
          array.mutable_shape()->mutable_dims()->push_back(dim);
        }
      }
      if (entry.has_constant_float_value()) {
        CHECK(array.has_shape());
        // Only float arrays are filled; other data types are left untouched.
        if (array.data_type == ArrayDataType::kFloat) {
          auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
          data.resize(RequiredBufferSizeForShape(array.shape()));
          for (float& f : data) {
            f = entry.constant_float_value();
          }
        }
      }
    }
  }
}
|
|
|
|
// Reverts the 4x16-block weights shuffling applied to uint8 FullyConnected
// weights (and the accompanying sign-bit flip), restoring the default
// row-major layout in place. QCHECK-fails if the weights array does not
// meet the shuffled-format preconditions (uint8, 2-D, rows % 4 == 0,
// cols % 16 == 0, used by exactly one op).
void UndoWeightsShuffling(Model* model) {
  for (const auto& op : model->operators) {
    if (op->type != toco::OperatorType::kFullyConnected) {
      continue;
    }
    const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
    if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
      // Weights are already in the default layout; nothing to undo.
      continue;
    }
    const string& weights_name = fc_op.inputs[1];
    QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
    auto& weights_array = model->GetArray(weights_name);
    QCHECK(weights_array.data_type == ArrayDataType::kUint8);
    auto& weights_data =
        weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
    const auto& weights_shape = weights_array.shape();
    QCHECK_EQ(weights_shape.dimensions_count(), 2);
    const int rows = weights_shape.dims(0);
    const int cols = weights_shape.dims(1);
    QCHECK_EQ(rows % 4, 0);
    QCHECK_EQ(cols % 16, 0);
    CHECK_EQ(rows * cols, weights_data.size());
    // Compute the de-shuffled weights
    std::vector<uint8> deshuffled_data(weights_data.size());
    uint8* shuffled_data_ptr = weights_data.data();
    // Shuffled storage is a linear sequence of 4x16 blocks; walk it once and
    // scatter each block back to its row-major position.
    for (int r = 0; r < rows; r += 4) {
      for (int c = 0; c < cols; c += 16) {
        for (int i = 0; i < 4; i++) {
          uint8* deshuffled_data_ptr =
              deshuffled_data.data() + (r + i) * cols + c;
          for (int j = 0; j < 16; j++) {
            uint8 shuffled_val = *shuffled_data_ptr++;
            // Deshuffling isn't only about deshuffling the storage layout,
            // it's also about undoing the flipping of the sign bit, which is
            // performed on the shuffled weights.
            uint8 deshuffled_val = shuffled_val ^ 0x80;
            *deshuffled_data_ptr++ = deshuffled_val;
          }
        }
      }
    }
    // The walk must have consumed the shuffled buffer exactly once.
    CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
    // Switch this FC op to using the deshuffled weights.
    weights_data = std::move(deshuffled_data);
  }
}
|
|
|
|
// Copies quantization-related metadata from src to dst: the MinMax (when
// present), the quantization params (when present), and the narrow_range
// flag. Creates the destination fields as needed.
void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
  if (src.minmax) {
    dst->GetOrCreateMinMax() = src.GetMinMax();
  }
  if (src.quantization_params) {
    dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
  }
  dst->narrow_range = src.narrow_range;
}
|
|
|
|
} // namespace toco
|