Minor toco changes to support new features in tfmini.

PiperOrigin-RevId: 176694498
This commit is contained in:
A. Unique TensorFlower 2017-11-22 13:29:47 -08:00 committed by TensorFlower Gardener
parent 8067aa0862
commit c5b8a5ed86
3 changed files with 100 additions and 37 deletions

View File

@ -35,8 +35,11 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::TensorProto;
@ -1500,10 +1503,29 @@ void ConvertOperator(const Model& model, const Operator& src_op,
}
}
void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
void AddPlaceholder(const string& name, ArrayDataType type,
GraphDef* tensorflow_graph) {
auto* placeholder = tensorflow_graph->add_node();
placeholder->set_op("Placeholder");
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
switch (type) {
case ArrayDataType::kBool:
(*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
break;
case ArrayDataType::kFloat:
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
break;
case ArrayDataType::kUint8:
(*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
break;
case ArrayDataType::kInt32:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
break;
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;
default:
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
}
placeholder->set_name(name);
}
@ -1531,7 +1553,9 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
void ExportTensorFlowGraphDefImplementation(const Model& model,
GraphDef* tensorflow_graph) {
for (const auto& input_array : model.flags.input_arrays()) {
AddPlaceholder(input_array.name(), tensorflow_graph);
AddPlaceholder(input_array.name(),
model.arrays.at(input_array.name())->data_type,
tensorflow_graph);
}
for (const auto& rnn_state : model.flags.rnn_states()) {
AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),

View File

@ -85,38 +85,57 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new MakeInitialDequantizeOperator);
}
void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
const bool output_supports_only_float =
toco_flags.output_format() == TENSORFLOW_GRAPHDEF;
// Returns whether `format` can represent quantized arrays (only GRAPHVIZ_DOT
// and TFLITE can).
bool SupportsQuantization(FileFormat format) {
  // Fixed: dropped a stray empty statement (";") that followed the return.
  return (format == GRAPHVIZ_DOT || format == TFLITE);
}
ArrayDataType specified_final_data_type = ArrayDataType::kNone;
// Returns whether `format` supports operators with fused activation
// functions (GRAPHVIZ_DOT and TFLITE do).
bool SupportsFusedActivationFunction(FileFormat format) {
  switch (format) {
    case GRAPHVIZ_DOT:
    case TFLITE:
      return true;
    default:
      return false;
  }
}
// Returns whether `format` supports the fused LSTM cell operator
// (TENSORFLOW_GRAPHDEF and GRAPHVIZ_DOT do).
bool SupportsLstmCell(FileFormat format) {
  switch (format) {
    case TENSORFLOW_GRAPHDEF:
    case GRAPHVIZ_DOT:
      return true;
    default:
      return false;
  }
}
// Returns whether `format` supports preallocating transient workspace
// memory (GRAPHVIZ_DOT and TFLITE do).
bool SupportsPreallocatedWorkspace(FileFormat format) {
  switch (format) {
    case GRAPHVIZ_DOT:
    case TFLITE:
      return true;
    default:
      return false;
  }
}
// Returns whether `type` denotes real numbers: either plain float, or uint8
// (treated here as quantized real values — see the TODO in
// SetFinalDataTypeOnInputs about plain-integer uint8 arrays).
bool IsRealValued(toco::ArrayDataType type) {
  // Fixed: removed a redundant static_cast<bool> — the || expression is
  // already of type bool.
  return type == toco::ArrayDataType::kFloat ||
         type == toco::ArrayDataType::kUint8;
}
// Sets final_data_type on each real-valued input array of `model`.
//
// The chosen type, in decreasing precedence:
//   1. toco_flags.inference_input_type, when set;
//   2. toco_flags.inference_type, when set;
//   3. kFloat, when the output format does not support quantization;
//   4. otherwise nothing is set and the arrays' types stay as-is.
//
// Non-real-valued input arrays (see IsRealValued) are left untouched.
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
  const FileFormat output_format = toco_flags.output_format();
  ArrayDataType type;
  if (toco_flags.has_inference_input_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
  } else if (toco_flags.has_inference_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
  } else if (!SupportsQuantization(output_format)) {
    // Data type is implicitly float for non-quantized formats
    type = ArrayDataType::kFloat;
  } else {
    // Nothing to do. Data types stay as-is.
    return;
  }
  for (int i = 0; i < model->flags.input_arrays_size(); i++) {
    string const& array_name = model->flags.input_arrays(i).name();
    auto* array = model->arrays[array_name].get();
    // Note that the notion of changing data types only applies to real-numbers
    // arrays (see the documentation for inference_input_type).
    // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
    // i.e. represent real numbers by means of quantization parameters,
    // and not plain integer uint8 input arrays.
    if (!IsRealValued(array->data_type)) {
      // Ignore non-real data types.
      continue;
    }
    array->final_data_type = type;
  }
}
@ -155,23 +174,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
const bool output_is_tflite = output_format == TFLITE;
const bool quantize_output =
SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;
const bool output_is_tflite_quantized =
output_is_tflite && inference_type == QUANTIZED_UINT8;
if (output_is_tflite_quantized) {
if (quantize_output) {
QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
<< "Quantized inference is not allowed with float inputs.";
}
SetArrayFinalDataTypes(toco_flags, model);
SetFinalDataTypeOnInputs(toco_flags, model);
GraphTransformationsSet transformations;
MakeGeneralGraphTransformationsSet(&transformations);
auto* remove_trivial_reshape = new RemoveTrivialReshape;
transformations.Add(remove_trivial_reshape);
if (output_format == TFLITE) {
if (SupportsFusedActivationFunction(output_format)) {
transformations.Add(new FuseActivationFunctions);
} else {
transformations.Add(new UnfuseActivationFunctions);
@ -190,25 +207,24 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
// easy to pass a new toco flag. Once that is resolved on the DarwiNN
// tests side, the special-casing of DarwiNN here can go away.
// TODO(benoitjacob): so drop it when we can.
if ((output_is_tflite_quantized &&
toco_flags.reorder_across_fake_quant())) {
if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
transformations.Add(new DropFakeQuant);
}
}
transformations.Add(new ConvertPureConvToDepthwise);
// TFLite export does not yet support fused LSTM cell.
if (output_format == TENSORFLOW_GRAPHDEF) {
if (SupportsLstmCell(output_format)) {
transformations.Add(new IdentifyLstmCell);
}
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
if (output_is_tflite_quantized) {
if (quantize_output) {
RunGraphTransformations(model, "pre-quantization graph transformations",
{new HardcodeMinMax, new DropFakeQuant});
}
if (output_is_tflite_quantized) {
if (quantize_output) {
if (toco_flags.has_default_ranges_min() &&
toco_flags.has_default_ranges_max()) {
UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
@ -239,7 +255,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
CheckUnsupportedOperations(*model);
}
if (output_is_tflite) {
if (SupportsPreallocatedWorkspace(output_format)) {
AllocateTransientArrays(model, kDefaultTransientDataAlignment);
LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
}

View File

@ -294,6 +294,7 @@ void LogArray(int log_level, const Model& model, const string& name) {
VLOG(log_level) << "Array: " << name;
switch (array.data_type) {
case ArrayDataType::kNone:
VLOG(log_level) << " Data type:";
break;
case ArrayDataType::kFloat:
VLOG(log_level) << " Data type: kFloat";
@ -309,6 +310,24 @@ void LogArray(int log_level, const Model& model, const string& name) {
<< static_cast<int>(array.data_type) << ")";
break;
}
// Log the array's requested final data type (set by SetFinalDataTypeOnInputs).
switch (array.final_data_type) {
  case ArrayDataType::kNone:
    VLOG(log_level) << " Final type:";
    break;
  case ArrayDataType::kFloat:
    VLOG(log_level) << " Final type: kFloat";
    break;
  case ArrayDataType::kInt32:
    VLOG(log_level) << " Final type: kInt32";
    break;
  case ArrayDataType::kUint8:
    VLOG(log_level) << " Final type: kUint8";
    break;
  default:
    // Fixed copy-paste bug: this case previously logged
    // static_cast<int>(array.data_type), i.e. the current data type, instead
    // of the final data type being reported here.
    VLOG(log_level) << " Final type: other (numerical value: "
                    << static_cast<int>(array.final_data_type) << ")";
    break;
}
if (array.buffer) {
VLOG(log_level) << " Constant Buffer";
}
@ -1562,7 +1581,11 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
// Verify that every array whose final data type was requested
// (final_data_type != kNone) has actually reached that type. The span
// previously contained a duplicated, message-less copy of this CHECK
// (diff residue); only the message-bearing form is kept.
for (const auto& array_entry : model.arrays) {
  const auto& array = *array_entry.second;
  if (array.final_data_type != ArrayDataType::kNone) {
    CHECK(array.final_data_type == array.data_type)
        << "Array \"" << array_entry.first
        << "\" has mis-matching actual and final data types ("
        << static_cast<int>(array.data_type) << ","
        << static_cast<int>(array.final_data_type) << ").";
  }
}
}