Minor toco changes to support new features in tfmini.
PiperOrigin-RevId: 176694498
This commit is contained in:
parent
8067aa0862
commit
c5b8a5ed86
@ -35,8 +35,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using tensorflow::DT_BOOL;
|
||||
using tensorflow::DT_FLOAT;
|
||||
using tensorflow::DT_INT32;
|
||||
using tensorflow::DT_INT64;
|
||||
using tensorflow::DT_UINT8;
|
||||
using tensorflow::GraphDef;
|
||||
using tensorflow::TensorProto;
|
||||
|
||||
@ -1500,10 +1503,29 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
}
|
||||
}
|
||||
|
||||
void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
|
||||
void AddPlaceholder(const string& name, ArrayDataType type,
|
||||
GraphDef* tensorflow_graph) {
|
||||
auto* placeholder = tensorflow_graph->add_node();
|
||||
placeholder->set_op("Placeholder");
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
|
||||
switch (type) {
|
||||
case ArrayDataType::kBool:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
|
||||
break;
|
||||
case ArrayDataType::kFloat:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
|
||||
break;
|
||||
case ArrayDataType::kUint8:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
|
||||
break;
|
||||
case ArrayDataType::kInt32:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
break;
|
||||
case ArrayDataType::kInt64:
|
||||
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
|
||||
}
|
||||
placeholder->set_name(name);
|
||||
}
|
||||
|
||||
@ -1531,7 +1553,9 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
|
||||
void ExportTensorFlowGraphDefImplementation(const Model& model,
|
||||
GraphDef* tensorflow_graph) {
|
||||
for (const auto& input_array : model.flags.input_arrays()) {
|
||||
AddPlaceholder(input_array.name(), tensorflow_graph);
|
||||
AddPlaceholder(input_array.name(),
|
||||
model.arrays.at(input_array.name())->data_type,
|
||||
tensorflow_graph);
|
||||
}
|
||||
for (const auto& rnn_state : model.flags.rnn_states()) {
|
||||
AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
|
||||
|
@ -85,38 +85,57 @@ void MakeGeneralGraphTransformationsSet(
|
||||
transformations->Add(new MakeInitialDequantizeOperator);
|
||||
}
|
||||
|
||||
void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
|
||||
const bool output_supports_only_float =
|
||||
toco_flags.output_format() == TENSORFLOW_GRAPHDEF;
|
||||
bool SupportsQuantization(FileFormat format) {
|
||||
return (format == GRAPHVIZ_DOT || format == TFLITE);
|
||||
;
|
||||
}
|
||||
|
||||
ArrayDataType specified_final_data_type = ArrayDataType::kNone;
|
||||
bool SupportsFusedActivationFunction(FileFormat format) {
|
||||
return (format == GRAPHVIZ_DOT || format == TFLITE);
|
||||
}
|
||||
|
||||
bool SupportsLstmCell(FileFormat format) {
|
||||
return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT);
|
||||
}
|
||||
|
||||
bool SupportsPreallocatedWorkspace(FileFormat format) {
|
||||
return (format == GRAPHVIZ_DOT || format == TFLITE);
|
||||
}
|
||||
|
||||
bool IsRealValued(toco::ArrayDataType type) {
|
||||
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
|
||||
type == toco::ArrayDataType::kUint8);
|
||||
}
|
||||
|
||||
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
|
||||
const FileFormat output_format = toco_flags.output_format();
|
||||
ArrayDataType type;
|
||||
if (toco_flags.has_inference_input_type()) {
|
||||
specified_final_data_type =
|
||||
ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
|
||||
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
|
||||
} else if (toco_flags.has_inference_type()) {
|
||||
specified_final_data_type =
|
||||
ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
|
||||
}
|
||||
ArrayDataType final_data_type = ArrayDataType::kNone;
|
||||
if (output_supports_only_float) {
|
||||
QCHECK(specified_final_data_type == ArrayDataType::kNone ||
|
||||
specified_final_data_type == ArrayDataType::kFloat);
|
||||
final_data_type = ArrayDataType::kFloat;
|
||||
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
|
||||
} else if (!SupportsQuantization(output_format)) {
|
||||
// Data type is implicitly float for non-quantized formats
|
||||
type = ArrayDataType::kFloat;
|
||||
} else {
|
||||
final_data_type = specified_final_data_type;
|
||||
// Nothing to do. Data types stay as-is.
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < model->flags.input_arrays_size(); i++) {
|
||||
auto* array = model->arrays[model->flags.input_arrays(i).name()].get();
|
||||
string const& array_name = model->flags.input_arrays(i).name();
|
||||
auto* array = model->arrays[array_name].get();
|
||||
// Note that the notion of changing data types only applies to real-numbers
|
||||
// arrays (see the documentation for inference_input_type).
|
||||
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
|
||||
// i.e. represent real numbers by means of quantization parameters,
|
||||
// and not plain integer uint8 input arrays.
|
||||
const bool is_real_numbers = array->data_type == ArrayDataType::kFloat ||
|
||||
array->data_type == ArrayDataType::kUint8;
|
||||
if (is_real_numbers) {
|
||||
array->final_data_type = final_data_type;
|
||||
if (!IsRealValued(array->data_type)) {
|
||||
// Ignore non-real data types.
|
||||
continue;
|
||||
}
|
||||
|
||||
array->final_data_type = type;
|
||||
}
|
||||
}
|
||||
|
||||
@ -155,23 +174,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||
const FileFormat output_format = toco_flags.output_format();
|
||||
const IODataType inference_type = toco_flags.inference_type();
|
||||
|
||||
const bool output_is_tflite = output_format == TFLITE;
|
||||
const bool quantize_output =
|
||||
SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;
|
||||
|
||||
const bool output_is_tflite_quantized =
|
||||
output_is_tflite && inference_type == QUANTIZED_UINT8;
|
||||
|
||||
if (output_is_tflite_quantized) {
|
||||
if (quantize_output) {
|
||||
QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
|
||||
<< "Quantized inference is not allowed with float inputs.";
|
||||
}
|
||||
|
||||
SetArrayFinalDataTypes(toco_flags, model);
|
||||
SetFinalDataTypeOnInputs(toco_flags, model);
|
||||
|
||||
GraphTransformationsSet transformations;
|
||||
MakeGeneralGraphTransformationsSet(&transformations);
|
||||
auto* remove_trivial_reshape = new RemoveTrivialReshape;
|
||||
transformations.Add(remove_trivial_reshape);
|
||||
if (output_format == TFLITE) {
|
||||
if (SupportsFusedActivationFunction(output_format)) {
|
||||
transformations.Add(new FuseActivationFunctions);
|
||||
} else {
|
||||
transformations.Add(new UnfuseActivationFunctions);
|
||||
@ -190,25 +207,24 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||
// easy to pass a new toco flag. Once that is resolved on the DarwiNN
|
||||
// tests side, the special-casing of DarwiNN here can go away.
|
||||
// TODO(benoitjacob): so drop it when we can.
|
||||
if ((output_is_tflite_quantized &&
|
||||
toco_flags.reorder_across_fake_quant())) {
|
||||
if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
|
||||
transformations.Add(new DropFakeQuant);
|
||||
}
|
||||
}
|
||||
transformations.Add(new ConvertPureConvToDepthwise);
|
||||
// TFLite export does not yet support fused LSTM cell.
|
||||
if (output_format == TENSORFLOW_GRAPHDEF) {
|
||||
if (SupportsLstmCell(output_format)) {
|
||||
transformations.Add(new IdentifyLstmCell);
|
||||
}
|
||||
transformations.Add(new ResolveConstantConcatenation);
|
||||
RunGraphTransformations(model, "general graph transformations",
|
||||
transformations);
|
||||
if (output_is_tflite_quantized) {
|
||||
if (quantize_output) {
|
||||
RunGraphTransformations(model, "pre-quantization graph transformations",
|
||||
{new HardcodeMinMax, new DropFakeQuant});
|
||||
}
|
||||
|
||||
if (output_is_tflite_quantized) {
|
||||
if (quantize_output) {
|
||||
if (toco_flags.has_default_ranges_min() &&
|
||||
toco_flags.has_default_ranges_max()) {
|
||||
UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
|
||||
@ -239,7 +255,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||
CheckUnsupportedOperations(*model);
|
||||
}
|
||||
|
||||
if (output_is_tflite) {
|
||||
if (SupportsPreallocatedWorkspace(output_format)) {
|
||||
AllocateTransientArrays(model, kDefaultTransientDataAlignment);
|
||||
LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
|
||||
}
|
||||
|
@ -294,6 +294,7 @@ void LogArray(int log_level, const Model& model, const string& name) {
|
||||
VLOG(log_level) << "Array: " << name;
|
||||
switch (array.data_type) {
|
||||
case ArrayDataType::kNone:
|
||||
VLOG(log_level) << " Data type:";
|
||||
break;
|
||||
case ArrayDataType::kFloat:
|
||||
VLOG(log_level) << " Data type: kFloat";
|
||||
@ -309,6 +310,24 @@ void LogArray(int log_level, const Model& model, const string& name) {
|
||||
<< static_cast<int>(array.data_type) << ")";
|
||||
break;
|
||||
}
|
||||
switch (array.final_data_type) {
|
||||
case ArrayDataType::kNone:
|
||||
VLOG(log_level) << " Final type:";
|
||||
break;
|
||||
case ArrayDataType::kFloat:
|
||||
VLOG(log_level) << " Final type: kFloat";
|
||||
break;
|
||||
case ArrayDataType::kInt32:
|
||||
VLOG(log_level) << " Final type: kInt32";
|
||||
break;
|
||||
case ArrayDataType::kUint8:
|
||||
VLOG(log_level) << " Final type: kUint8";
|
||||
break;
|
||||
default:
|
||||
VLOG(log_level) << " Final type: other (numerical value: "
|
||||
<< static_cast<int>(array.data_type) << ")";
|
||||
break;
|
||||
}
|
||||
if (array.buffer) {
|
||||
VLOG(log_level) << " Constant Buffer";
|
||||
}
|
||||
@ -1562,7 +1581,11 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
|
||||
for (const auto& array_entry : model.arrays) {
|
||||
const auto& array = *array_entry.second;
|
||||
if (array.final_data_type != ArrayDataType::kNone) {
|
||||
CHECK(array.final_data_type == array.data_type);
|
||||
CHECK(array.final_data_type == array.data_type)
|
||||
<< "Array \"" << array_entry.first
|
||||
<< "\" has mis-matching actual and final data types ("
|
||||
<< static_cast<int>(array.data_type) << ","
|
||||
<< static_cast<int>(array.final_data_type) << ").";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user