Fewer TOCO crashes.
PiperOrigin-RevId: 232730162
This commit is contained in:
parent
f11085ebf2
commit
ec7958da20
@ -128,7 +128,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool GraphTransformationsPass(int increment, Model* model,
|
bool GraphTransformationsPass(int increment, Model* model,
|
||||||
const GraphTransformationsSet& transformations) {
|
const GraphTransformationsSet& transformations,
|
||||||
|
tensorflow::Status* status) {
|
||||||
CHECK(increment == 1 || increment == -1);
|
CHECK(increment == 1 || increment == -1);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
if (model->operators.empty()) {
|
if (model->operators.empty()) {
|
||||||
@ -142,7 +143,10 @@ bool GraphTransformationsPass(int increment, Model* model,
|
|||||||
for (const auto& transformation : transformations) {
|
for (const auto& transformation : transformations) {
|
||||||
CHECK(!changed_now);
|
CHECK(!changed_now);
|
||||||
CHECK(transformation->Messages().empty());
|
CHECK(transformation->Messages().empty());
|
||||||
CHECK(transformation->Run(model, op_index, &changed_now).ok());
|
*status = transformation->Run(model, op_index, &changed_now);
|
||||||
|
if (!status->ok()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
const char* made_a_change_msg =
|
const char* made_a_change_msg =
|
||||||
changed_now ? "made a change" : "did NOT make a change";
|
changed_now ? "made a change" : "did NOT make a change";
|
||||||
const int log_level =
|
const int log_level =
|
||||||
@ -186,18 +190,21 @@ bool GraphTransformationsPass(int increment, Model* model,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void RunGraphTransformations(Model* model, const string& msg,
|
tensorflow::Status RunGraphTransformationsWithStatus(
|
||||||
|
Model* model, const string& msg,
|
||||||
const GraphTransformationsSet& transformations) {
|
const GraphTransformationsSet& transformations) {
|
||||||
PrintModelStats(toco::port::StringF("Before %s", msg), *model);
|
PrintModelStats(toco::port::StringF("Before %s", msg), *model);
|
||||||
int pass_index = 0;
|
int pass_index = 0;
|
||||||
|
tensorflow::Status status;
|
||||||
while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
|
while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
|
||||||
transformations)) {
|
transformations, &status)) {
|
||||||
pass_index++;
|
pass_index++;
|
||||||
const auto& label =
|
const auto& label =
|
||||||
toco::port::StringF("After %s pass %d", msg, pass_index);
|
toco::port::StringF("After %s pass %d", msg, pass_index);
|
||||||
PrintModelStats(label, *model);
|
PrintModelStats(label, *model);
|
||||||
CheckInvariants(*model);
|
CheckInvariants(*model);
|
||||||
}
|
}
|
||||||
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace toco
|
} // namespace toco
|
||||||
|
@ -102,9 +102,17 @@ class GraphTransformationsSet {
|
|||||||
// construct GraphTransformation objects by using 'new', pass us
|
// construct GraphTransformation objects by using 'new', pass us
|
||||||
// the resulting raw pointers, and this RunGraphTransformations
|
// the resulting raw pointers, and this RunGraphTransformations
|
||||||
// takes care of delete'ing these pointers.
|
// takes care of delete'ing these pointers.
|
||||||
void RunGraphTransformations(Model* model, const string& message,
|
tensorflow::Status RunGraphTransformationsWithStatus(
|
||||||
|
Model* model, const string& msg,
|
||||||
const GraphTransformationsSet& transformations);
|
const GraphTransformationsSet& transformations);
|
||||||
|
|
||||||
|
inline void RunGraphTransformations(
|
||||||
|
Model* model, const string& msg,
|
||||||
|
const GraphTransformationsSet& transformations) {
|
||||||
|
auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
|
||||||
|
CHECK(s.ok()) << s.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
|
#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
|
||||||
class GTName : public GraphTransformation { \
|
class GTName : public GraphTransformation { \
|
||||||
public: \
|
public: \
|
||||||
|
@ -489,12 +489,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!SupportsQuantization(op)) {
|
if (!SupportsQuantization(op)) {
|
||||||
LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
|
return tensorflow::errors::InvalidArgument(
|
||||||
<< HelpfulOperatorTypeName(op)
|
"Unimplemented: this graph contains an operator of type ",
|
||||||
<< " for which the quantized form is not yet implemented. "
|
HelpfulOperatorTypeName(op),
|
||||||
"Sorry, and patches welcome (that's a relatively fun patch "
|
" for which the quantized form is not yet implemented. Sorry, and "
|
||||||
"to write, mostly providing the actual quantized arithmetic "
|
"patches welcome (that's a relatively fun patch to write, mostly "
|
||||||
"code for this op).";
|
"providing the actual quantized arithmetic code for this op).");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& input : op.inputs) {
|
for (const auto& input : op.inputs) {
|
||||||
|
@ -77,7 +77,7 @@ tensorflow::Status Convert(const string& graph_def_contents,
|
|||||||
string* output_file_contents) {
|
string* output_file_contents) {
|
||||||
std::unique_ptr<Model> model =
|
std::unique_ptr<Model> model =
|
||||||
Import(toco_flags, model_flags, graph_def_contents);
|
Import(toco_flags, model_flags, graph_def_contents);
|
||||||
Transform(toco_flags, model.get());
|
TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get()));
|
||||||
return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
||||||
output_file_contents);
|
output_file_contents);
|
||||||
}
|
}
|
||||||
|
@ -236,7 +236,8 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
|
|||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Transform(const TocoFlags& toco_flags, Model* model) {
|
tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
|
||||||
|
Model* model) {
|
||||||
const FileFormat output_format = toco_flags.output_format();
|
const FileFormat output_format = toco_flags.output_format();
|
||||||
const IODataType inference_type = toco_flags.inference_type();
|
const IODataType inference_type = toco_flags.inference_type();
|
||||||
|
|
||||||
@ -258,8 +259,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
// stop optimizations from crossing the input/output boundaries. For example
|
// stop optimizations from crossing the input/output boundaries. For example
|
||||||
// this will stop BatchNorm fusing if the output node is in between a conv
|
// this will stop BatchNorm fusing if the output node is in between a conv
|
||||||
// and BatchNorm layers.
|
// and BatchNorm layers.
|
||||||
RunGraphTransformations(model, "Removing unused ops",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
{new toco::RemoveUnusedOp});
|
model, "Removing unused ops", {new toco::RemoveUnusedOp}));
|
||||||
|
|
||||||
GraphTransformationsSet transformations;
|
GraphTransformationsSet transformations;
|
||||||
MakeGeneralGraphTransformationsSet(&transformations);
|
MakeGeneralGraphTransformationsSet(&transformations);
|
||||||
@ -307,20 +308,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
identify_dilated_conv->set_identify_depthwise_conv(false);
|
identify_dilated_conv->set_identify_depthwise_conv(false);
|
||||||
}
|
}
|
||||||
transformations.Add(identify_dilated_conv);
|
transformations.Add(identify_dilated_conv);
|
||||||
RunGraphTransformations(model, "general graph transformations",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
transformations);
|
model, "general graph transformations", transformations));
|
||||||
|
|
||||||
if (quantize_output) {
|
if (quantize_output) {
|
||||||
if (toco_flags.propagate_fake_quant_num_bits()) {
|
if (toco_flags.propagate_fake_quant_num_bits()) {
|
||||||
RunGraphTransformations(model,
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
"fake quant propagation graph transformations",
|
model, "fake quant propagation graph transformations",
|
||||||
{new PropagateFakeQuantNumBits});
|
{new PropagateFakeQuantNumBits}));
|
||||||
}
|
}
|
||||||
RunGraphTransformations(model, "pre-quantization graph transformations",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
|
model, "pre-quantization graph transformations",
|
||||||
{
|
{
|
||||||
new HardcodeMinMax,
|
new HardcodeMinMax,
|
||||||
new DropFakeQuant,
|
new DropFakeQuant,
|
||||||
});
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to merge bidirectional sequence lstm or rnn if present.
|
// Try to merge bidirectional sequence lstm or rnn if present.
|
||||||
@ -328,8 +330,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
bidirectional_transformations.Add(new RemoveUnusedOp);
|
bidirectional_transformations.Add(new RemoveUnusedOp);
|
||||||
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
|
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
|
||||||
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
|
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
|
||||||
RunGraphTransformations(model, "Group bidirectional sequence lstm/rnn",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
bidirectional_transformations);
|
model, "Group bidirectional sequence lstm/rnn",
|
||||||
|
bidirectional_transformations));
|
||||||
|
|
||||||
// Fix any issues with IO edges. This must happen after any transform that
|
// Fix any issues with IO edges. This must happen after any transform that
|
||||||
// may modify the structure of the edges.
|
// may modify the structure of the edges.
|
||||||
@ -357,12 +360,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
toco_flags.default_int16_ranges_max());
|
toco_flags.default_int16_ranges_max());
|
||||||
}
|
}
|
||||||
if (propagate_default_min_max->has_any_ranges_defined()) {
|
if (propagate_default_min_max->has_any_ranges_defined()) {
|
||||||
RunGraphTransformations(
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
model, "default min-max range propagation graph transformations",
|
model, "default min-max range propagation graph transformations",
|
||||||
{
|
{
|
||||||
propagate_default_min_max.release(),
|
propagate_default_min_max.release(),
|
||||||
new HardcodeMinMax,
|
new HardcodeMinMax,
|
||||||
});
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
CheckIsReadyForQuantization(*model);
|
CheckIsReadyForQuantization(*model);
|
||||||
@ -372,17 +375,18 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
|
toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
|
||||||
ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
|
ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
|
||||||
has_default_ranges_flag);
|
has_default_ranges_flag);
|
||||||
RunGraphTransformations(model, "quantization graph transformations",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
|
model, "quantization graph transformations",
|
||||||
{
|
{
|
||||||
new RemoveTrivialQuantizedActivationFunc,
|
new RemoveTrivialQuantizedActivationFunc,
|
||||||
new RemoveTrivialQuantizedMinMax,
|
new RemoveTrivialQuantizedMinMax,
|
||||||
new Quantize,
|
new Quantize,
|
||||||
new RemoveFinalDequantizeOp,
|
new RemoveFinalDequantizeOp,
|
||||||
ensure_safe_for_int8_kernels,
|
ensure_safe_for_int8_kernels,
|
||||||
});
|
}));
|
||||||
if (SupportsShuffledFCWeights(output_format)) {
|
if (SupportsShuffledFCWeights(output_format)) {
|
||||||
RunGraphTransformations(model, "shuffling of FC weights",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
{new ShuffleFCWeights});
|
model, "shuffling of FC weights", {new ShuffleFCWeights}));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
GraphTransformationsSet dequantization_transformations{new Dequantize};
|
GraphTransformationsSet dequantization_transformations{new Dequantize};
|
||||||
@ -392,8 +396,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
dequantization_transformations.Add(new DropFakeQuant);
|
dequantization_transformations.Add(new DropFakeQuant);
|
||||||
}
|
}
|
||||||
|
|
||||||
RunGraphTransformations(model, "dequantization graph transformations",
|
TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
|
||||||
dequantization_transformations);
|
model, "dequantization graph transformations",
|
||||||
|
dequantization_transformations));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (output_format == TENSORFLOW_GRAPHDEF) {
|
if (output_format == TENSORFLOW_GRAPHDEF) {
|
||||||
@ -425,6 +430,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
|
|||||||
<< " billion (note that a multiply-add is counted as 2 ops).";
|
<< " billion (note that a multiply-add is counted as 2 ops).";
|
||||||
}
|
}
|
||||||
model->ops_count = ops_count;
|
model->ops_count = ops_count;
|
||||||
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
|
tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
|
||||||
|
@ -31,7 +31,12 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
|
|||||||
|
|
||||||
// Transforms a Model. The resulting Model is ready to be passed
|
// Transforms a Model. The resulting Model is ready to be passed
|
||||||
// to Export with the exact same toco_flags.
|
// to Export with the exact same toco_flags.
|
||||||
void Transform(const TocoFlags& toco_flags, Model* model);
|
tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
|
||||||
|
Model* model);
|
||||||
|
inline void Transform(const TocoFlags& toco_flags, Model* model) {
|
||||||
|
auto s = TransformWithStatus(toco_flags, model);
|
||||||
|
CHECK(s.ok()) << s.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
// Exports the Model, which must be of the 'lowered' form returned by
|
// Exports the Model, which must be of the 'lowered' form returned by
|
||||||
// Transform, to a file of the format given by
|
// Transform, to a file of the format given by
|
||||||
|
Loading…
Reference in New Issue
Block a user