From ec7958da20ab7259b133d4e3c76d170fdb9f699b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Feb 2019 13:05:13 -0800 Subject: [PATCH] Fewer TOCO crashes. PiperOrigin-RevId: 232730162 --- .../graph_transformations.cc | 17 +++-- .../graph_transformations.h | 12 +++- .../toco/graph_transformations/quantize.cc | 12 ++-- tensorflow/lite/toco/toco_convert.cc | 2 +- tensorflow/lite/toco/toco_tooling.cc | 64 ++++++++++--------- tensorflow/lite/toco/toco_tooling.h | 7 +- 6 files changed, 70 insertions(+), 44 deletions(-) diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc index a0260e24013..e4eb7698597 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc @@ -128,7 +128,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { } bool GraphTransformationsPass(int increment, Model* model, - const GraphTransformationsSet& transformations) { + const GraphTransformationsSet& transformations, + tensorflow::Status* status) { CHECK(increment == 1 || increment == -1); bool changed = false; if (model->operators.empty()) { @@ -142,7 +143,10 @@ bool GraphTransformationsPass(int increment, Model* model, for (const auto& transformation : transformations) { CHECK(!changed_now); CHECK(transformation->Messages().empty()); - CHECK(transformation->Run(model, op_index, &changed_now).ok()); + *status = transformation->Run(model, op_index, &changed_now); + if (!status->ok()) { + return false; + } const char* made_a_change_msg = changed_now ? "made a change" : "did NOT make a change"; const int log_level = @@ -186,18 +190,21 @@ bool GraphTransformationsPass(int increment, Model* model, } // namespace -void RunGraphTransformations(Model* model, const string& msg, - const GraphTransformationsSet& transformations) { +tensorflow::Status RunGraphTransformationsWithStatus( + Model* model, const string& msg, + const GraphTransformationsSet& transformations) { PrintModelStats(toco::port::StringF("Before %s", msg), *model); int pass_index = 0; + tensorflow::Status status; while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model, - transformations)) { + transformations, &status)) { pass_index++; const auto& label = toco::port::StringF("After %s pass %d", msg, pass_index); PrintModelStats(label, *model); CheckInvariants(*model); } + return status; } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index 4008bbdb4d3..491a3e7cfb6 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -102,8 +102,16 @@ class GraphTransformationsSet { // construct GraphTransformation objects by using 'new', pass us // the resulting raw pointers, and this RunGraphTransformations // takes care of delete'ing these pointers. -void RunGraphTransformations(Model* model, const string& message, - const GraphTransformationsSet& transformations); +tensorflow::Status RunGraphTransformationsWithStatus( + Model* model, const string& msg, + const GraphTransformationsSet& transformations); + +inline void RunGraphTransformations( + Model* model, const string& msg, + const GraphTransformationsSet& transformations) { + auto s = RunGraphTransformationsWithStatus(model, msg, transformations); + CHECK(s.ok()) << s.error_message(); +} #define DECLARE_GRAPH_TRANSFORMATION(GTName) \ class GTName : public GraphTransformation { \ diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index ee65f92e00c..5a5c9bbd61a 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -489,12 +489,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation, } } if (!SupportsQuantization(op)) { - LOG(FATAL) << "Unimplemented: this graph contains an operator of type " - << HelpfulOperatorTypeName(op) - << " for which the quantized form is not yet implemented. " - "Sorry, and patches welcome (that's a relatively fun patch " - "to write, mostly providing the actual quantized arithmetic " - "code for this op)."; + return tensorflow::errors::InvalidArgument( + "Unimplemented: this graph contains an operator of type ", + HelpfulOperatorTypeName(op), + " for which the quantized form is not yet implemented. Sorry, and " + "patches welcome (that's a relatively fun patch to write, mostly " + "providing the actual quantized arithmetic code for this op)."); } for (const auto& input : op.inputs) { diff --git a/tensorflow/lite/toco/toco_convert.cc b/tensorflow/lite/toco/toco_convert.cc index 28e7b10ecd0..2adfc1dd236 100644 --- a/tensorflow/lite/toco/toco_convert.cc +++ b/tensorflow/lite/toco/toco_convert.cc @@ -77,7 +77,7 @@ tensorflow::Status Convert(const string& graph_def_contents, string* output_file_contents) { std::unique_ptr model = Import(toco_flags, model_flags, graph_def_contents); - Transform(toco_flags, model.get()); + TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get())); return Export(toco_flags, *model, toco_flags.allow_custom_ops(), output_file_contents); } diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index 69d7a7a61a5..06f51825fe0 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -236,7 +236,8 @@ std::unique_ptr Import(const TocoFlags& toco_flags, return model; } -void Transform(const TocoFlags& toco_flags, Model* model) { +tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, + Model* model) { const FileFormat output_format = toco_flags.output_format(); const IODataType inference_type = toco_flags.inference_type(); @@ -258,8 +259,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) { // stop optimizations from crossing the input/output boundaries. For example // this will stop BatchNorm fusing if the output node is in between a conv // and BatchNorm layers. - RunGraphTransformations(model, "Removing unused ops", - {new toco::RemoveUnusedOp}); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "Removing unused ops", {new toco::RemoveUnusedOp})); GraphTransformationsSet transformations; MakeGeneralGraphTransformationsSet(&transformations); @@ -307,20 +308,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) { identify_dilated_conv->set_identify_depthwise_conv(false); } transformations.Add(identify_dilated_conv); - RunGraphTransformations(model, "general graph transformations", - transformations); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "general graph transformations", transformations)); if (quantize_output) { if (toco_flags.propagate_fake_quant_num_bits()) { - RunGraphTransformations(model, - "fake quant propagation graph transformations", - {new PropagateFakeQuantNumBits}); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "fake quant propagation graph transformations", + {new PropagateFakeQuantNumBits})); } - RunGraphTransformations(model, "pre-quantization graph transformations", - { - new HardcodeMinMax, - new DropFakeQuant, - }); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "pre-quantization graph transformations", + { + new HardcodeMinMax, + new DropFakeQuant, + })); } // Try to merge bidirectional sequence lstm or rnn if present. @@ -328,8 +330,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { bidirectional_transformations.Add(new RemoveUnusedOp); bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm); bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn); - RunGraphTransformations(model, "Group bidirectional sequence lstm/rnn", - bidirectional_transformations); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "Group bidirectional sequence lstm/rnn", + bidirectional_transformations)); // Fix any issues with IO edges. This must happen after any transform that // may modify the structure of the edges. @@ -357,12 +360,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) { toco_flags.default_int16_ranges_max()); } if (propagate_default_min_max->has_any_ranges_defined()) { - RunGraphTransformations( + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( model, "default min-max range propagation graph transformations", { propagate_default_min_max.release(), new HardcodeMinMax, - }); + })); } CheckIsReadyForQuantization(*model); @@ -372,17 +375,18 @@ void Transform(const TocoFlags& toco_flags, Model* model) { toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel()); ensure_safe_for_int8_kernels->set_has_default_ranges_flag( has_default_ranges_flag); - RunGraphTransformations(model, "quantization graph transformations", - { - new RemoveTrivialQuantizedActivationFunc, - new RemoveTrivialQuantizedMinMax, - new Quantize, - new RemoveFinalDequantizeOp, - ensure_safe_for_int8_kernels, - }); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "quantization graph transformations", + { + new RemoveTrivialQuantizedActivationFunc, + new RemoveTrivialQuantizedMinMax, + new Quantize, + new RemoveFinalDequantizeOp, + ensure_safe_for_int8_kernels, + })); if (SupportsShuffledFCWeights(output_format)) { - RunGraphTransformations(model, "shuffling of FC weights", - {new ShuffleFCWeights}); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "shuffling of FC weights", {new ShuffleFCWeights})); } } else { GraphTransformationsSet dequantization_transformations{new Dequantize}; @@ -392,8 +396,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { dequantization_transformations.Add(new DropFakeQuant); } - RunGraphTransformations(model, "dequantization graph transformations", - dequantization_transformations); + TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus( + model, "dequantization graph transformations", + dequantization_transformations)); } if (output_format == TENSORFLOW_GRAPHDEF) { @@ -425,6 +430,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) { << " billion (note that a multiply-add is counted as 2 ops)."; } model->ops_count = ops_count; + return tensorflow::Status::OK(); } tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, diff --git a/tensorflow/lite/toco/toco_tooling.h b/tensorflow/lite/toco/toco_tooling.h index 742e3769269..36996151949 100644 --- a/tensorflow/lite/toco/toco_tooling.h +++ b/tensorflow/lite/toco/toco_tooling.h @@ -31,7 +31,12 @@ std::unique_ptr Import(const TocoFlags& toco_flags, // Transforms a Model. The resulting Model is ready to be passed // to Export with the exact same toco_flags. -void Transform(const TocoFlags& toco_flags, Model* model); +tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, + Model* model); +inline void Transform(const TocoFlags& toco_flags, Model* model) { + auto s = TransformWithStatus(toco_flags, model); + CHECK(s.ok()) << s.error_message(); +} // Exports the Model, which must be of the 'lowered' form returned by // Transform, to a file of the format given by