Fewer TOCO crashes.

PiperOrigin-RevId: 232730162
A. Unique TensorFlower 2019-02-06 13:05:13 -08:00 committed by TensorFlower Gardener
parent f11085ebf2
commit ec7958da20
6 changed files with 70 additions and 44 deletions

View File

@@ -128,7 +128,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
 }
 
 bool GraphTransformationsPass(int increment, Model* model,
-                              const GraphTransformationsSet& transformations) {
+                              const GraphTransformationsSet& transformations,
+                              tensorflow::Status* status) {
   CHECK(increment == 1 || increment == -1);
   bool changed = false;
   if (model->operators.empty()) {
@@ -142,7 +143,10 @@ bool GraphTransformationsPass(int increment, Model* model,
     for (const auto& transformation : transformations) {
      CHECK(!changed_now);
      CHECK(transformation->Messages().empty());
-      CHECK(transformation->Run(model, op_index, &changed_now).ok());
+      *status = transformation->Run(model, op_index, &changed_now);
+      if (!status->ok()) {
+        return false;
+      }
      const char* made_a_change_msg =
          changed_now ? "made a change" : "did NOT make a change";
      const int log_level =
@@ -186,18 +190,21 @@ bool GraphTransformationsPass(int increment, Model* model,
 }  // namespace
 
-void RunGraphTransformations(Model* model, const string& msg,
-                             const GraphTransformationsSet& transformations) {
+tensorflow::Status RunGraphTransformationsWithStatus(
+    Model* model, const string& msg,
+    const GraphTransformationsSet& transformations) {
   PrintModelStats(toco::port::StringF("Before %s", msg), *model);
   int pass_index = 0;
+  tensorflow::Status status;
   while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
-                                  transformations)) {
+                                  transformations, &status)) {
     pass_index++;
     const auto& label =
        toco::port::StringF("After %s pass %d", msg, pass_index);
     PrintModelStats(label, *model);
     CheckInvariants(*model);
   }
+  return status;
 }
 
 }  // namespace toco

View File

@@ -102,9 +102,17 @@ class GraphTransformationsSet {
 // construct GraphTransformation objects by using 'new', pass us
 // the resulting raw pointers, and this RunGraphTransformations
 // takes care of delete'ing these pointers.
-void RunGraphTransformations(Model* model, const string& message,
-                             const GraphTransformationsSet& transformations);
+tensorflow::Status RunGraphTransformationsWithStatus(
+    Model* model, const string& msg,
+    const GraphTransformationsSet& transformations);
+
+inline void RunGraphTransformations(
+    Model* model, const string& msg,
+    const GraphTransformationsSet& transformations) {
+  auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
+  CHECK(s.ok()) << s.error_message();
+}
 
 #define DECLARE_GRAPH_TRANSFORMATION(GTName) \
   class GTName : public GraphTransformation { \
    public: \
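
As a usage sketch (not part of this commit): a caller can now drive the status-returning entry point instead of the CHECK-ing wrapper. RunGraphTransformationsWithStatus, GraphTransformationsSet, Model, and RemoveUnusedOp come from the TOCO sources above; the helper function name and include paths below are assumptions.

// Minimal sketch: build a GraphTransformationsSet from raw 'new' pointers
// (the runner deletes them, per the comment above) and propagate failures
// as a tensorflow::Status instead of CHECK-failing.
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"  // assumed path
#include "tensorflow/lite/toco/model.h"  // assumed path

namespace toco {

tensorflow::Status RemoveUnusedOpsOrFail(Model* model) {  // hypothetical helper
  GraphTransformationsSet transformations;
  transformations.Add(new RemoveUnusedOp);  // ownership passes to the set
  // Returns the first non-OK status produced by any transformation's Run().
  return RunGraphTransformationsWithStatus(model, "Removing unused ops",
                                           transformations);
}

}  // namespace toco

The CHECK-ing RunGraphTransformations wrapper above preserves the old crash-on-error behavior for callers that have no status to return.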

View File

@@ -489,12 +489,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
     }
   }
 
   if (!SupportsQuantization(op)) {
-    LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
-               << HelpfulOperatorTypeName(op)
-               << " for which the quantized form is not yet implemented. "
-                  "Sorry, and patches welcome (that's a relatively fun patch "
-                  "to write, mostly providing the actual quantized arithmetic "
-                  "code for this op).";
+    return tensorflow::errors::InvalidArgument(
+        "Unimplemented: this graph contains an operator of type ",
+        HelpfulOperatorTypeName(op),
+        " for which the quantized form is not yet implemented. Sorry, and "
+        "patches welcome (that's a relatively fun patch to write, mostly "
+        "providing the actual quantized arithmetic code for this op).");
   }
 
   for (const auto& input : op.inputs) {

View File

@@ -77,7 +77,7 @@ tensorflow::Status Convert(const string& graph_def_contents,
                            string* output_file_contents) {
   std::unique_ptr<Model> model =
       Import(toco_flags, model_flags, graph_def_contents);
-  Transform(toco_flags, model.get());
+  TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get()));
   return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
                 output_file_contents);
 }

View File

@@ -236,7 +236,8 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
   return model;
 }
 
-void Transform(const TocoFlags& toco_flags, Model* model) {
+tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
+                                       Model* model) {
   const FileFormat output_format = toco_flags.output_format();
   const IODataType inference_type = toco_flags.inference_type();
@@ -258,8 +259,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   // stop optimizations from crossing the input/output boundaries. For example
   // this will stop BatchNorm fusing if the output node is in between a conv
   // and BatchNorm layers.
-  RunGraphTransformations(model, "Removing unused ops",
-                          {new toco::RemoveUnusedOp});
+  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+      model, "Removing unused ops", {new toco::RemoveUnusedOp}));
 
   GraphTransformationsSet transformations;
   MakeGeneralGraphTransformationsSet(&transformations);
@@ -307,20 +308,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     identify_dilated_conv->set_identify_depthwise_conv(false);
   }
   transformations.Add(identify_dilated_conv);
-  RunGraphTransformations(model, "general graph transformations",
-                          transformations);
+  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+      model, "general graph transformations", transformations));
 
   if (quantize_output) {
     if (toco_flags.propagate_fake_quant_num_bits()) {
-      RunGraphTransformations(model,
-                              "fake quant propagation graph transformations",
-                              {new PropagateFakeQuantNumBits});
+      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+          model, "fake quant propagation graph transformations",
+          {new PropagateFakeQuantNumBits}));
     }
-    RunGraphTransformations(model, "pre-quantization graph transformations",
+    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+        model, "pre-quantization graph transformations",
         {
            new HardcodeMinMax,
            new DropFakeQuant,
-        });
+        }));
   }
 
   // Try to merge bidirectional sequence lstm or rnn if present.
@@ -328,8 +330,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   bidirectional_transformations.Add(new RemoveUnusedOp);
   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
-  RunGraphTransformations(model, "Group bidirectional sequence lstm/rnn",
-                          bidirectional_transformations);
+  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+      model, "Group bidirectional sequence lstm/rnn",
+      bidirectional_transformations));
 
   // Fix any issues with IO edges. This must happen after any transform that
   // may modify the structure of the edges.
@@ -357,12 +360,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
          toco_flags.default_int16_ranges_max());
     }
     if (propagate_default_min_max->has_any_ranges_defined()) {
-      RunGraphTransformations(
+      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "default min-max range propagation graph transformations",
          {
             propagate_default_min_max.release(),
             new HardcodeMinMax,
-         });
+         }));
     }
 
     CheckIsReadyForQuantization(*model);
@@ -372,17 +375,18 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
        toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
     ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
        has_default_ranges_flag);
-    RunGraphTransformations(model, "quantization graph transformations",
+    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+        model, "quantization graph transformations",
        {
           new RemoveTrivialQuantizedActivationFunc,
          new RemoveTrivialQuantizedMinMax,
          new Quantize,
          new RemoveFinalDequantizeOp,
          ensure_safe_for_int8_kernels,
-       });
+       }));
     if (SupportsShuffledFCWeights(output_format)) {
-      RunGraphTransformations(model, "shuffling of FC weights",
-                              {new ShuffleFCWeights});
+      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+          model, "shuffling of FC weights", {new ShuffleFCWeights}));
     }
   } else {
     GraphTransformationsSet dequantization_transformations{new Dequantize};
@@ -392,8 +396,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
      dequantization_transformations.Add(new DropFakeQuant);
     }
-    RunGraphTransformations(model, "dequantization graph transformations",
-                            dequantization_transformations);
+    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+        model, "dequantization graph transformations",
+        dequantization_transformations));
   }
 
   if (output_format == TENSORFLOW_GRAPHDEF) {
@@ -425,6 +430,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
                << " billion (note that a multiply-add is counted as 2 ops).";
   }
   model->ops_count = ops_count;
+  return tensorflow::Status::OK();
 }
 
 tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,

View File

@@ -31,7 +31,12 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
 // Transforms a Model. The resulting Model is ready to be passed
 // to Export with the exact same toco_flags.
-void Transform(const TocoFlags& toco_flags, Model* model);
+tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
+                                       Model* model);
+
+inline void Transform(const TocoFlags& toco_flags, Model* model) {
+  auto s = TransformWithStatus(toco_flags, model);
+  CHECK(s.ok()) << s.error_message();
+}
 
 // Exports the Model, which must be of the 'lowered' form returned by
 // Transform, to a file of the format given by
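
For illustration (a sketch, not part of the commit): the same pattern applied one level up, mirroring the Convert() change, where a caller forwards a Transform failure as a Status rather than crashing. TransformWithStatus, TocoFlags, and Model are declared in the header above; the caller name and include paths are assumptions.

// Hypothetical caller: surfaces graph-transformation failures to its own
// caller as a tensorflow::Status instead of CHECK-failing inside Transform().
#include "tensorflow/core/lib/core/errors.h"    // TF_RETURN_IF_ERROR (assumed path)
#include "tensorflow/lite/toco/toco_tooling.h"  // assumed path

namespace toco {

tensorflow::Status TransformOrReport(const TocoFlags& toco_flags, Model* model) {
  TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model));
  return tensorflow::Status::OK();
}

}  // namespace toco

Existing callers of the inline Transform() wrapper keep the previous behavior: the returned status is CHECKed, so the process still aborts on the same error.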