Fewer TOCO crashes.

PiperOrigin-RevId: 232730162
This commit is contained in:
A. Unique TensorFlower 2019-02-06 13:05:13 -08:00 committed by TensorFlower Gardener
parent f11085ebf2
commit ec7958da20
6 changed files with 70 additions and 44 deletions

View File

@ -128,7 +128,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
}
bool GraphTransformationsPass(int increment, Model* model,
const GraphTransformationsSet& transformations) {
const GraphTransformationsSet& transformations,
tensorflow::Status* status) {
CHECK(increment == 1 || increment == -1);
bool changed = false;
if (model->operators.empty()) {
@ -142,7 +143,10 @@ bool GraphTransformationsPass(int increment, Model* model,
for (const auto& transformation : transformations) {
CHECK(!changed_now);
CHECK(transformation->Messages().empty());
CHECK(transformation->Run(model, op_index, &changed_now).ok());
*status = transformation->Run(model, op_index, &changed_now);
if (!status->ok()) {
return false;
}
const char* made_a_change_msg =
changed_now ? "made a change" : "did NOT make a change";
const int log_level =
@ -186,18 +190,21 @@ bool GraphTransformationsPass(int increment, Model* model,
} // namespace
void RunGraphTransformations(Model* model, const string& msg,
const GraphTransformationsSet& transformations) {
tensorflow::Status RunGraphTransformationsWithStatus(
Model* model, const string& msg,
const GraphTransformationsSet& transformations) {
PrintModelStats(toco::port::StringF("Before %s", msg), *model);
int pass_index = 0;
tensorflow::Status status;
while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
transformations)) {
transformations, &status)) {
pass_index++;
const auto& label =
toco::port::StringF("After %s pass %d", msg, pass_index);
PrintModelStats(label, *model);
CheckInvariants(*model);
}
return status;
}
} // namespace toco

View File

@ -102,8 +102,16 @@ class GraphTransformationsSet {
// construct GraphTransformation objects by using 'new', pass us
// the resulting raw pointers, and this RunGraphTransformations
// takes care of delete'ing these pointers.
void RunGraphTransformations(Model* model, const string& message,
const GraphTransformationsSet& transformations);
tensorflow::Status RunGraphTransformationsWithStatus(
Model* model, const string& msg,
const GraphTransformationsSet& transformations);
inline void RunGraphTransformations(
Model* model, const string& msg,
const GraphTransformationsSet& transformations) {
auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
CHECK(s.ok()) << s.error_message();
}
#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
class GTName : public GraphTransformation { \

View File

@ -489,12 +489,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
}
}
if (!SupportsQuantization(op)) {
LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
<< HelpfulOperatorTypeName(op)
<< " for which the quantized form is not yet implemented. "
"Sorry, and patches welcome (that's a relatively fun patch "
"to write, mostly providing the actual quantized arithmetic "
"code for this op).";
return tensorflow::errors::InvalidArgument(
"Unimplemented: this graph contains an operator of type ",
HelpfulOperatorTypeName(op),
" for which the quantized form is not yet implemented. Sorry, and "
"patches welcome (that's a relatively fun patch to write, mostly "
"providing the actual quantized arithmetic code for this op).");
}
for (const auto& input : op.inputs) {

View File

@ -77,7 +77,7 @@ tensorflow::Status Convert(const string& graph_def_contents,
string* output_file_contents) {
std::unique_ptr<Model> model =
Import(toco_flags, model_flags, graph_def_contents);
Transform(toco_flags, model.get());
TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get()));
return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
output_file_contents);
}

View File

@ -236,7 +236,8 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
return model;
}
void Transform(const TocoFlags& toco_flags, Model* model) {
tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
Model* model) {
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
@ -258,8 +259,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
// stop optimizations from crossing the input/output boundaries. For example
// this will stop BatchNorm fusing if the output node is in between a conv
// and BatchNorm layers.
RunGraphTransformations(model, "Removing unused ops",
{new toco::RemoveUnusedOp});
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "Removing unused ops", {new toco::RemoveUnusedOp}));
GraphTransformationsSet transformations;
MakeGeneralGraphTransformationsSet(&transformations);
@ -307,20 +308,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
identify_dilated_conv->set_identify_depthwise_conv(false);
}
transformations.Add(identify_dilated_conv);
RunGraphTransformations(model, "general graph transformations",
transformations);
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "general graph transformations", transformations));
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
"fake quant propagation graph transformations",
{new PropagateFakeQuantNumBits});
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "fake quant propagation graph transformations",
{new PropagateFakeQuantNumBits}));
}
RunGraphTransformations(model, "pre-quantization graph transformations",
{
new HardcodeMinMax,
new DropFakeQuant,
});
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "pre-quantization graph transformations",
{
new HardcodeMinMax,
new DropFakeQuant,
}));
}
// Try to merge bidirectional sequence lstm or rnn if present.
@ -328,8 +330,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
bidirectional_transformations.Add(new RemoveUnusedOp);
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
RunGraphTransformations(model, "Group bidirectional sequence lstm/rnn",
bidirectional_transformations);
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "Group bidirectional sequence lstm/rnn",
bidirectional_transformations));
// Fix any issues with IO edges. This must happen after any transform that
// may modify the structure of the edges.
@ -357,12 +360,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
toco_flags.default_int16_ranges_max());
}
if (propagate_default_min_max->has_any_ranges_defined()) {
RunGraphTransformations(
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "default min-max range propagation graph transformations",
{
propagate_default_min_max.release(),
new HardcodeMinMax,
});
}));
}
CheckIsReadyForQuantization(*model);
@ -372,17 +375,18 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
has_default_ranges_flag);
RunGraphTransformations(model, "quantization graph transformations",
{
new RemoveTrivialQuantizedActivationFunc,
new RemoveTrivialQuantizedMinMax,
new Quantize,
new RemoveFinalDequantizeOp,
ensure_safe_for_int8_kernels,
});
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "quantization graph transformations",
{
new RemoveTrivialQuantizedActivationFunc,
new RemoveTrivialQuantizedMinMax,
new Quantize,
new RemoveFinalDequantizeOp,
ensure_safe_for_int8_kernels,
}));
if (SupportsShuffledFCWeights(output_format)) {
RunGraphTransformations(model, "shuffling of FC weights",
{new ShuffleFCWeights});
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "shuffling of FC weights", {new ShuffleFCWeights}));
}
} else {
GraphTransformationsSet dequantization_transformations{new Dequantize};
@ -392,8 +396,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
dequantization_transformations.Add(new DropFakeQuant);
}
RunGraphTransformations(model, "dequantization graph transformations",
dequantization_transformations);
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "dequantization graph transformations",
dequantization_transformations));
}
if (output_format == TENSORFLOW_GRAPHDEF) {
@ -425,6 +430,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
<< " billion (note that a multiply-add is counted as 2 ops).";
}
model->ops_count = ops_count;
return tensorflow::Status::OK();
}
tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,

View File

@ -31,7 +31,12 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
// Transforms a Model. The resulting Model is ready to be passed
// to Export with the exact same toco_flags.
void Transform(const TocoFlags& toco_flags, Model* model);
tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
Model* model);
inline void Transform(const TocoFlags& toco_flags, Model* model) {
auto s = TransformWithStatus(toco_flags, model);
CHECK(s.ok()) << s.error_message();
}
// Exports the Model, which must be of the 'lowered' form returned by
// Transform, to a file of the format given by