Fewer TOCO crashes.
PiperOrigin-RevId: 232730162
commit ec7958da20
parent f11085ebf2
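The diff below replaces CHECK-based aborts in the TOCO conversion pipeline with tensorflow::Status propagation: each crashing entry point gains a status-returning variant, and a thin inline wrapper keeps the old crash-on-error signature for existing callers. A minimal, self-contained sketch of that pattern follows; the Status stand-in and the DoTransform* names are illustrative only, not the TOCO API.

#include <cstdlib>
#include <iostream>
#include <string>
#include <utility>

// Simplified stand-in for tensorflow::Status, for illustration only.
class Status {
 public:
  Status() = default;  // default-constructed Status is OK
  explicit Status(std::string msg) : msg_(std::move(msg)) {}
  bool ok() const { return msg_.empty(); }
  const std::string& error_message() const { return msg_; }

 private:
  std::string msg_;
};

// New-style entry point: reports failure to the caller.
Status DoTransformWithStatus(bool fail) {
  if (fail) return Status("unsupported operator");
  return Status();  // OK
}

// Old-style entry point, kept as a thin wrapper that dies on error.
void DoTransform(bool fail) {
  Status s = DoTransformWithStatus(fail);
  if (!s.ok()) {
    std::cerr << "fatal: " << s.error_message() << std::endl;
    std::abort();
  }
}

int main() {
  // Callers that can recover use the status-returning variant...
  Status s = DoTransformWithStatus(/*fail=*/true);
  if (!s.ok()) std::cerr << "recoverable: " << s.error_message() << "\n";
  // ...while legacy callers keep the old crash-on-error behavior.
  DoTransform(/*fail=*/false);
  return 0;
}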
@@ -128,7 +128,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
 }
 
 bool GraphTransformationsPass(int increment, Model* model,
-                              const GraphTransformationsSet& transformations) {
+                              const GraphTransformationsSet& transformations,
+                              tensorflow::Status* status) {
   CHECK(increment == 1 || increment == -1);
   bool changed = false;
   if (model->operators.empty()) {
@@ -142,7 +143,10 @@ bool GraphTransformationsPass(int increment, Model* model,
     for (const auto& transformation : transformations) {
      CHECK(!changed_now);
      CHECK(transformation->Messages().empty());
-      CHECK(transformation->Run(model, op_index, &changed_now).ok());
+      *status = transformation->Run(model, op_index, &changed_now);
+      if (!status->ok()) {
+        return false;
+      }
      const char* made_a_change_msg =
          changed_now ? "made a change" : "did NOT make a change";
      const int log_level =
@@ -186,18 +190,21 @@ bool GraphTransformationsPass(int increment, Model* model,
 
 }  // namespace
 
-void RunGraphTransformations(Model* model, const string& msg,
-                             const GraphTransformationsSet& transformations) {
+tensorflow::Status RunGraphTransformationsWithStatus(
+    Model* model, const string& msg,
+    const GraphTransformationsSet& transformations) {
   PrintModelStats(toco::port::StringF("Before %s", msg), *model);
   int pass_index = 0;
+  tensorflow::Status status;
   while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
-                                  transformations)) {
+                                  transformations, &status)) {
     pass_index++;
     const auto& label =
         toco::port::StringF("After %s pass %d", msg, pass_index);
     PrintModelStats(label, *model);
     CheckInvariants(*model);
   }
+  return status;
 }
 
 }  // namespace toco
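In GraphTransformationsPass, the bool return value still drives the fixed-point loop while errors travel through the new Status out-parameter: a failing transformation makes the pass return false, the while loop in RunGraphTransformationsWithStatus stops, and the non-OK status is handed back to the caller. A self-contained sketch of that control flow, with a simplified Status type and a placeholder transformation standing in for the real GraphTransformation objects:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-ins; the real code uses tensorflow::Status and
// GraphTransformation objects.
struct Status {
  std::string msg;
  bool ok() const { return msg.empty(); }
};

using Transformation = std::function<Status(bool* changed)>;

// Returns true if any transformation changed the model; on error, stores the
// failure in *status and returns false so that the outer loop stops.
bool Pass(const std::vector<Transformation>& transformations, Status* status) {
  bool changed = false;
  for (const auto& transformation : transformations) {
    bool changed_now = false;
    *status = transformation(&changed_now);
    if (!status->ok()) {
      return false;
    }
    changed = changed || changed_now;
  }
  return changed;
}

Status RunToFixedPoint(const std::vector<Transformation>& transformations) {
  Status status;
  while (Pass(transformations, &status)) {
    // Keep iterating while some transformation still makes progress.
  }
  return status;  // OK unless a pass failed
}

int main() {
  int budget = 3;
  std::vector<Transformation> transformations = {
      [&budget](bool* changed) {
        *changed = (budget-- > 0);  // pretend to simplify the graph 3 times
        return Status{};
      }};
  const Status s = RunToFixedPoint(transformations);
  std::cout << (s.ok() ? "converged" : s.msg) << std::endl;
  return 0;
}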
@@ -102,8 +102,16 @@ class GraphTransformationsSet {
 // construct GraphTransformation objects by using 'new', pass us
 // the resulting raw pointers, and this RunGraphTransformations
 // takes care of delete'ing these pointers.
-void RunGraphTransformations(Model* model, const string& message,
-                             const GraphTransformationsSet& transformations);
+tensorflow::Status RunGraphTransformationsWithStatus(
+    Model* model, const string& msg,
+    const GraphTransformationsSet& transformations);
+
+inline void RunGraphTransformations(
+    Model* model, const string& msg,
+    const GraphTransformationsSet& transformations) {
+  auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
+  CHECK(s.ok()) << s.error_message();
+}
 
 #define DECLARE_GRAPH_TRANSFORMATION(GTName) \
   class GTName : public GraphTransformation { \
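With both declarations in place, call sites can either propagate failures or keep the legacy abort-on-error behavior through the inline wrapper. A hypothetical usage fragment (it assumes the graph_transformations header above is included and a populated Model; OptimizeModel and OptimizeModelOrDie are made-up names):

// Propagating variant: hand the error up to the caller.
tensorflow::Status OptimizeModel(toco::Model* model) {
  toco::GraphTransformationsSet transformations;
  transformations.Add(new toco::RemoveUnusedOp);
  return toco::RunGraphTransformationsWithStatus(model, "cleanup",
                                                 transformations);
}

// Legacy variant: the inline wrapper CHECK-fails if anything goes wrong.
void OptimizeModelOrDie(toco::Model* model) {
  toco::GraphTransformationsSet transformations;
  transformations.Add(new toco::RemoveUnusedOp);
  toco::RunGraphTransformations(model, "cleanup", transformations);
}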
@@ -489,12 +489,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
     }
   }
   if (!SupportsQuantization(op)) {
-    LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
-               << HelpfulOperatorTypeName(op)
-               << " for which the quantized form is not yet implemented. "
-                  "Sorry, and patches welcome (that's a relatively fun patch "
-                  "to write, mostly providing the actual quantized arithmetic "
-                  "code for this op).";
+    return tensorflow::errors::InvalidArgument(
+        "Unimplemented: this graph contains an operator of type ",
+        HelpfulOperatorTypeName(op),
+        " for which the quantized form is not yet implemented. Sorry, and "
+        "patches welcome (that's a relatively fun patch to write, mostly "
+        "providing the actual quantized arithmetic code for this op).");
   }
 
   for (const auto& input : op.inputs) {
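tensorflow::errors::InvalidArgument accepts a variadic list of message pieces and concatenates them into the status message, which is why the long diagnostic above is passed as comma-separated chunks rather than streamed with <<. A small sketch of building and inspecting such a status; RejectOp and the operator name are made up, and the includes are the TensorFlow core headers of this era:

#include <iostream>
#include <string>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Made-up helper: rejects an operator by name with an InvalidArgument status.
tensorflow::Status RejectOp(const std::string& op_name) {
  return tensorflow::errors::InvalidArgument(
      "Unimplemented: this graph contains an operator of type ", op_name,
      " for which the quantized form is not yet implemented.");
}

int main() {
  tensorflow::Status s = RejectOp("MyCustomOp");
  if (!s.ok()) {
    // The concatenated message is reported instead of aborting the process.
    std::cout << s.error_message() << std::endl;
  }
  return 0;
}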
@@ -77,7 +77,7 @@ tensorflow::Status Convert(const string& graph_def_contents,
                            string* output_file_contents) {
   std::unique_ptr<Model> model =
       Import(toco_flags, model_flags, graph_def_contents);
-  Transform(toco_flags, model.get());
+  TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get()));
   return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
                 output_file_contents);
 }
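TF_RETURN_IF_ERROR evaluates its argument and, when the resulting status is not OK, returns it from the enclosing function, so Convert stays a flat sequence of steps with early exits. The sketch below contrasts the macro with the roughly equivalent hand-written form; SomeStep is a placeholder, not a TOCO function:

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Placeholder step used only to show the macro's control flow.
tensorflow::Status SomeStep() { return tensorflow::Status::OK(); }

tensorflow::Status WithMacro() {
  TF_RETURN_IF_ERROR(SomeStep());  // early-returns any non-OK status
  return tensorflow::Status::OK();
}

// Roughly what the macro expands to.
tensorflow::Status HandWritten() {
  tensorflow::Status s = SomeStep();
  if (!s.ok()) return s;
  return tensorflow::Status::OK();
}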
@@ -236,7 +236,8 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
   return model;
 }
 
-void Transform(const TocoFlags& toco_flags, Model* model) {
+tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
+                                       Model* model) {
   const FileFormat output_format = toco_flags.output_format();
   const IODataType inference_type = toco_flags.inference_type();
 
@@ -258,8 +259,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   // stop optimizations from crossing the input/output boundaries. For example
   // this will stop BatchNorm fusing if the output node is in between a conv
   // and BatchNorm layers.
-  RunGraphTransformations(model, "Removing unused ops",
-                          {new toco::RemoveUnusedOp});
+  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+      model, "Removing unused ops", {new toco::RemoveUnusedOp}));
 
   GraphTransformationsSet transformations;
   MakeGeneralGraphTransformationsSet(&transformations);
@@ -307,20 +308,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     identify_dilated_conv->set_identify_depthwise_conv(false);
   }
   transformations.Add(identify_dilated_conv);
-  RunGraphTransformations(model, "general graph transformations",
-                          transformations);
+  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+      model, "general graph transformations", transformations));
 
   if (quantize_output) {
     if (toco_flags.propagate_fake_quant_num_bits()) {
-      RunGraphTransformations(model,
-                              "fake quant propagation graph transformations",
-                              {new PropagateFakeQuantNumBits});
+      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+          model, "fake quant propagation graph transformations",
+          {new PropagateFakeQuantNumBits}));
     }
-    RunGraphTransformations(model, "pre-quantization graph transformations",
-                            {
-                                new HardcodeMinMax,
-                                new DropFakeQuant,
-                            });
+    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+        model, "pre-quantization graph transformations",
+        {
+            new HardcodeMinMax,
+            new DropFakeQuant,
+        }));
   }
 
   // Try to merge bidirectional sequence lstm or rnn if present.
@@ -328,8 +330,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   bidirectional_transformations.Add(new RemoveUnusedOp);
   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
   bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
-  RunGraphTransformations(model, "Group bidirectional sequence lstm/rnn",
-                          bidirectional_transformations);
+  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+      model, "Group bidirectional sequence lstm/rnn",
+      bidirectional_transformations));
 
   // Fix any issues with IO edges. This must happen after any transform that
   // may modify the structure of the edges.
@@ -357,12 +360,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
           toco_flags.default_int16_ranges_max());
     }
     if (propagate_default_min_max->has_any_ranges_defined()) {
-      RunGraphTransformations(
+      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
           model, "default min-max range propagation graph transformations",
           {
               propagate_default_min_max.release(),
               new HardcodeMinMax,
-          });
+          }));
     }
 
     CheckIsReadyForQuantization(*model);
@@ -372,17 +375,18 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
         toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
     ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
         has_default_ranges_flag);
-    RunGraphTransformations(model, "quantization graph transformations",
-                            {
-                                new RemoveTrivialQuantizedActivationFunc,
-                                new RemoveTrivialQuantizedMinMax,
-                                new Quantize,
-                                new RemoveFinalDequantizeOp,
-                                ensure_safe_for_int8_kernels,
-                            });
+    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+        model, "quantization graph transformations",
+        {
+            new RemoveTrivialQuantizedActivationFunc,
+            new RemoveTrivialQuantizedMinMax,
+            new Quantize,
+            new RemoveFinalDequantizeOp,
+            ensure_safe_for_int8_kernels,
+        }));
     if (SupportsShuffledFCWeights(output_format)) {
-      RunGraphTransformations(model, "shuffling of FC weights",
-                              {new ShuffleFCWeights});
+      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+          model, "shuffling of FC weights", {new ShuffleFCWeights}));
     }
   } else {
     GraphTransformationsSet dequantization_transformations{new Dequantize};
@@ -392,8 +396,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
       dequantization_transformations.Add(new DropFakeQuant);
     }
 
-    RunGraphTransformations(model, "dequantization graph transformations",
-                            dequantization_transformations);
+    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+        model, "dequantization graph transformations",
+        dequantization_transformations));
   }
 
   if (output_format == TENSORFLOW_GRAPHDEF) {
@@ -425,6 +430,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
               << " billion (note that a multiply-add is counted as 2 ops).";
   }
   model->ops_count = ops_count;
+  return tensorflow::Status::OK();
 }
 
 tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
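With TransformWithStatus returning tensorflow::Status::OK() on success, a driver can surface conversion failures as ordinary errors instead of CHECK-crashes. A hypothetical caller fragment; the toco_flags and model setup are elided because they depend on the rest of the TOCO API:

// Hypothetical fragment from a driver's main(); toco_flags and model are
// assumed to have been built with the usual TOCO APIs (elided here).
tensorflow::Status status = toco::TransformWithStatus(toco_flags, model.get());
if (!status.ok()) {
  std::cerr << "TOCO transformation failed: " << status.error_message()
            << std::endl;
  return 1;  // report the failure without aborting the whole process
}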
@@ -31,7 +31,12 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
 
 // Transforms a Model. The resulting Model is ready to be passed
 // to Export with the exact same toco_flags.
-void Transform(const TocoFlags& toco_flags, Model* model);
+tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
+                                       Model* model);
+inline void Transform(const TocoFlags& toco_flags, Model* model) {
+  auto s = TransformWithStatus(toco_flags, model);
+  CHECK(s.ok()) << s.error_message();
+}
 
 // Exports the Model, which must be of the 'lowered' form returned by
 // Transform, to a file of the format given by