Minor toco changes to support new features in tfmini.
PiperOrigin-RevId: 176694498
commit c5b8a5ed86 (parent 8067aa0862)
@@ -35,8 +35,11 @@ limitations under the License.
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/logging.h"

+using tensorflow::DT_BOOL;
 using tensorflow::DT_FLOAT;
 using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_UINT8;
 using tensorflow::GraphDef;
 using tensorflow::TensorProto;

@@ -1500,10 +1503,29 @@ void ConvertOperator(const Model& model, const Operator& src_op,
   }
 }

-void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
+void AddPlaceholder(const string& name, ArrayDataType type,
+                    GraphDef* tensorflow_graph) {
   auto* placeholder = tensorflow_graph->add_node();
   placeholder->set_op("Placeholder");
+  switch (type) {
+    case ArrayDataType::kBool:
+      (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
+      break;
+    case ArrayDataType::kFloat:
       (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+      break;
+    case ArrayDataType::kUint8:
+      (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
+      break;
+    case ArrayDataType::kInt32:
+      (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
+      break;
+    case ArrayDataType::kInt64:
+      (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
+      break;
+    default:
+      LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
+  }
   placeholder->set_name(name);
 }

@@ -1531,7 +1553,9 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
 void ExportTensorFlowGraphDefImplementation(const Model& model,
                                             GraphDef* tensorflow_graph) {
   for (const auto& input_array : model.flags.input_arrays()) {
-    AddPlaceholder(input_array.name(), tensorflow_graph);
+    AddPlaceholder(input_array.name(),
+                   model.arrays.at(input_array.name())->data_type,
+                   tensorflow_graph);
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
     AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
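For context, the switch added to AddPlaceholder above boils down to a mapping from toco's ArrayDataType to the corresponding TensorFlow dtype, with any other type treated as fatal. Below is a minimal standalone sketch of that mapping; the enums here are simplified stand-ins rather than toco's or TensorFlow's real headers, so it compiles on its own and is not the library's actual API.

#include <cstdio>
#include <cstdlib>

// Stand-ins for toco::ArrayDataType and tensorflow::DataType (assumption:
// simplified for illustration only).
enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 };
enum TfDataType { DT_BOOL, DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64 };

// Mirrors the shape of the new switch: each supported array type picks the
// matching placeholder dtype; anything unexpected aborts.
TfDataType PlaceholderDtype(ArrayDataType type, const char* name) {
  switch (type) {
    case ArrayDataType::kBool:  return DT_BOOL;
    case ArrayDataType::kFloat: return DT_FLOAT;
    case ArrayDataType::kUint8: return DT_UINT8;
    case ArrayDataType::kInt32: return DT_INT32;
    case ArrayDataType::kInt64: return DT_INT64;
    default:
      std::fprintf(stderr, "Unexpected data type in array \"%s\"\n", name);
      std::abort();
  }
}

int main() {
  // A float input array would get a DT_FLOAT placeholder, as before; other
  // array types now get matching dtypes instead of always DT_FLOAT.
  std::printf("dtype enum value: %d\n", PlaceholderDtype(ArrayDataType::kFloat, "input"));
  return 0;
}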
@@ -85,38 +85,57 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new MakeInitialDequantizeOperator);
 }

-void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
-  const bool output_supports_only_float =
-      toco_flags.output_format() == TENSORFLOW_GRAPHDEF;
+bool SupportsQuantization(FileFormat format) {
+  return (format == GRAPHVIZ_DOT || format == TFLITE);
+  ;
+}

-  ArrayDataType specified_final_data_type = ArrayDataType::kNone;
+bool SupportsFusedActivationFunction(FileFormat format) {
+  return (format == GRAPHVIZ_DOT || format == TFLITE);
+}
+
+bool SupportsLstmCell(FileFormat format) {
+  return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT);
+}
+
+bool SupportsPreallocatedWorkspace(FileFormat format) {
+  return (format == GRAPHVIZ_DOT || format == TFLITE);
+}
+
+bool IsRealValued(toco::ArrayDataType type) {
+  return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
+                           type == toco::ArrayDataType::kUint8);
+}
+
+void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
+  const FileFormat output_format = toco_flags.output_format();
+  ArrayDataType type;
   if (toco_flags.has_inference_input_type()) {
-    specified_final_data_type =
-        ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
+    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
   } else if (toco_flags.has_inference_type()) {
-    specified_final_data_type =
-        ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
-  }
-  ArrayDataType final_data_type = ArrayDataType::kNone;
-  if (output_supports_only_float) {
-    QCHECK(specified_final_data_type == ArrayDataType::kNone ||
-           specified_final_data_type == ArrayDataType::kFloat);
-    final_data_type = ArrayDataType::kFloat;
+    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
+  } else if (!SupportsQuantization(output_format)) {
+    // Data type is implicitly float for non-quantized formats
+    type = ArrayDataType::kFloat;
   } else {
-    final_data_type = specified_final_data_type;
+    // Nothing to do. Data types stay as-is.
+    return;
   }
+
   for (int i = 0; i < model->flags.input_arrays_size(); i++) {
-    auto* array = model->arrays[model->flags.input_arrays(i).name()].get();
+    string const& array_name = model->flags.input_arrays(i).name();
+    auto* array = model->arrays[array_name].get();
     // Note that the notion of changing data types only applies to real-numbers
     // arrays (see the documentation for inference_input_type).
     // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
     // i.e. represent real numbers by means of quantization parameters,
     // and not plain integer uint8 input arrays.
-    const bool is_real_numbers = array->data_type == ArrayDataType::kFloat ||
-                                 array->data_type == ArrayDataType::kUint8;
-    if (is_real_numbers) {
-      array->final_data_type = final_data_type;
+    if (!IsRealValued(array->data_type)) {
+      // Ignore non-real data types.
+      continue;
     }
+
+    array->final_data_type = type;
   }
 }

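Taken together, the new helpers implement a simple priority scheme for the final data type stamped on real-valued input arrays: an explicit inference_input_type wins, then inference_type, then float when the output format cannot represent quantized data, and otherwise the existing types are left untouched. A standalone sketch of that decision order follows; the enums and the helper name FinalInputType are simplified stand-ins for illustration, not toco's real types or API.

#include <cstdio>
#include <optional>

// Simplified stand-ins for toco's enums (assumption: reduced to the cases
// needed to show the priority logic).
enum class ArrayDataType { kNone, kFloat, kUint8 };
enum FileFormat { TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT };

bool SupportsQuantization(FileFormat format) {
  return format == GRAPHVIZ_DOT || format == TFLITE;
}

// Returns the final data type to set on real-valued input arrays, or nullopt
// when the data types should stay as-is (mirrors SetFinalDataTypeOnInputs).
std::optional<ArrayDataType> FinalInputType(
    std::optional<ArrayDataType> inference_input_type,
    std::optional<ArrayDataType> inference_type, FileFormat output_format) {
  if (inference_input_type) return inference_input_type;  // highest priority
  if (inference_type) return inference_type;
  if (!SupportsQuantization(output_format)) return ArrayDataType::kFloat;
  return std::nullopt;  // nothing to do
}

int main() {
  // GraphDef output with no explicit flags: inputs are forced to float.
  auto t = FinalInputType(std::nullopt, std::nullopt, TENSORFLOW_GRAPHDEF);
  std::printf("force float: %d\n", t && *t == ArrayDataType::kFloat);
  return 0;
}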
@@ -155,23 +174,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   const FileFormat output_format = toco_flags.output_format();
   const IODataType inference_type = toco_flags.inference_type();

-  const bool output_is_tflite = output_format == TFLITE;
+  const bool quantize_output =
+      SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;

-  const bool output_is_tflite_quantized =
-      output_is_tflite && inference_type == QUANTIZED_UINT8;
-
-  if (output_is_tflite_quantized) {
+  if (quantize_output) {
     QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
         << "Quantized inference is not allowed with float inputs.";
   }

-  SetArrayFinalDataTypes(toco_flags, model);
+  SetFinalDataTypeOnInputs(toco_flags, model);

   GraphTransformationsSet transformations;
   MakeGeneralGraphTransformationsSet(&transformations);
   auto* remove_trivial_reshape = new RemoveTrivialReshape;
   transformations.Add(remove_trivial_reshape);
-  if (output_format == TFLITE) {
+  if (SupportsFusedActivationFunction(output_format)) {
     transformations.Add(new FuseActivationFunctions);
   } else {
     transformations.Add(new UnfuseActivationFunctions);
@@ -190,25 +207,24 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     // easy to pass a new toco flag. Once that is resolved on the DarwiNN
     // tests side, the special-casing of DarwiNN here can go away.
     // TODO(benoitjacob): so drop it when we can.
-    if ((output_is_tflite_quantized &&
-         toco_flags.reorder_across_fake_quant())) {
+    if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
       transformations.Add(new DropFakeQuant);
     }
   }
   transformations.Add(new ConvertPureConvToDepthwise);
   // TFLite export does not yet support fused LSTM cell.
-  if (output_format == TENSORFLOW_GRAPHDEF) {
+  if (SupportsLstmCell(output_format)) {
     transformations.Add(new IdentifyLstmCell);
   }
   transformations.Add(new ResolveConstantConcatenation);
   RunGraphTransformations(model, "general graph transformations",
                           transformations);
-  if (output_is_tflite_quantized) {
+  if (quantize_output) {
     RunGraphTransformations(model, "pre-quantization graph transformations",
                             {new HardcodeMinMax, new DropFakeQuant});
   }

-  if (output_is_tflite_quantized) {
+  if (quantize_output) {
     if (toco_flags.has_default_ranges_min() &&
         toco_flags.has_default_ranges_max()) {
       UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
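The net effect in Transform is that graph passes are now gated on what the output format can express rather than on the format identity itself. A condensed standalone sketch of that gating; the enum, the predicates, and the Plan helper are simplified stand-ins for illustration, not toco's actual Transform() code.

#include <cstdio>

// Simplified stand-in for toco's FileFormat enum.
enum FileFormat { TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT };

// Capability predicates in the same spirit as the new helpers above.
bool SupportsFusedActivationFunction(FileFormat f) { return f == GRAPHVIZ_DOT || f == TFLITE; }
bool SupportsLstmCell(FileFormat f) { return f == TENSORFLOW_GRAPHDEF || f == GRAPHVIZ_DOT; }
bool SupportsQuantization(FileFormat f) { return f == GRAPHVIZ_DOT || f == TFLITE; }

// Mirrors the branch structure of Transform() after this change: the
// pipeline asks about capabilities instead of comparing formats directly.
void Plan(FileFormat output_format, bool quantized_uint8_inference) {
  std::printf("%s activation functions\n",
              SupportsFusedActivationFunction(output_format) ? "fuse" : "unfuse");
  if (SupportsLstmCell(output_format)) std::printf("identify LSTM cells\n");
  if (SupportsQuantization(output_format) && quantized_uint8_inference)
    std::printf("run quantization passes\n");
}

int main() {
  Plan(TFLITE, /*quantized_uint8_inference=*/true);
  return 0;
}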
@@ -239,7 +255,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     CheckUnsupportedOperations(*model);
   }

-  if (output_is_tflite) {
+  if (SupportsPreallocatedWorkspace(output_format)) {
     AllocateTransientArrays(model, kDefaultTransientDataAlignment);
     LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
   }
@@ -294,6 +294,7 @@ void LogArray(int log_level, const Model& model, const string& name) {
   VLOG(log_level) << "Array: " << name;
   switch (array.data_type) {
     case ArrayDataType::kNone:
+      VLOG(log_level) << "  Data type:";
       break;
     case ArrayDataType::kFloat:
       VLOG(log_level) << "  Data type: kFloat";
@@ -309,6 +310,24 @@ void LogArray(int log_level, const Model& model, const string& name) {
                       << static_cast<int>(array.data_type) << ")";
       break;
   }
+  switch (array.final_data_type) {
+    case ArrayDataType::kNone:
+      VLOG(log_level) << "  Final type:";
+      break;
+    case ArrayDataType::kFloat:
+      VLOG(log_level) << "  Final type: kFloat";
+      break;
+    case ArrayDataType::kInt32:
+      VLOG(log_level) << "  Final type: kInt32";
+      break;
+    case ArrayDataType::kUint8:
+      VLOG(log_level) << "  Final type: kUint8";
+      break;
+    default:
+      VLOG(log_level) << "  Final type: other (numerical value: "
+                      << static_cast<int>(array.data_type) << ")";
+      break;
+  }
   if (array.buffer) {
     VLOG(log_level) << "  Constant Buffer";
   }
@@ -1562,7 +1581,11 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
   for (const auto& array_entry : model.arrays) {
     const auto& array = *array_entry.second;
     if (array.final_data_type != ArrayDataType::kNone) {
-      CHECK(array.final_data_type == array.data_type);
+      CHECK(array.final_data_type == array.data_type)
+          << "Array \"" << array_entry.first
+          << "\" has mis-matching actual and final data types ("
+          << static_cast<int>(array.data_type) << ","
+          << static_cast<int>(array.final_data_type) << ").";
     }
   }
 }