Fix the deletion of ops and their arrays in toco:

1. Fix DeleteOpAndArraysIfUnused: it was not deleting the op's output arrays.
2. Rename it to just DeleteOpAndArrays (it was not clear what 'unused' referred to).
3. Use that throughout toco graph transformations.
4. Miscellaneous related fixes; also removed DeleteArrayIfUsedOnce.

PiperOrigin-RevId: 253285001
This commit is contained in:
Benoit Jacob 2019-06-14 13:24:56 -07:00 committed by TensorFlower Gardener
parent 119de52436
commit 0e11e67490
57 changed files with 156 additions and 469 deletions

View File

@ -94,10 +94,8 @@ namespace toco {
} }
// Replace the operator in the graph. // Replace the operator in the graph.
const auto reshape_it = model->operators.emplace(expand_it, reshape_op); model->operators.emplace(expand_it, reshape_op);
expand_it = reshape_it + 1; DeleteOpAndArrays(model, expand_op);
CHECK_EQ(expand_it->get(), expand_op);
model->operators.erase(expand_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -92,11 +92,8 @@ namespace toco {
depthwiseconv_op->stride_width = conv_op->stride_width; depthwiseconv_op->stride_width = conv_op->stride_width;
depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0); depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
// Replace the operator in the graph. // Replace the operator in the graph.
const auto depthwiseconv_it =
model->operators.emplace(conv_it, depthwiseconv_op); model->operators.emplace(conv_it, depthwiseconv_op);
conv_it = depthwiseconv_it + 1; DeleteOpAndArrays(model, conv_op);
CHECK_EQ(conv_it->get(), conv_op);
model->operators.erase(conv_it);
// Shuffle the weights. // Shuffle the weights.
const auto& weights_shape = weights_array.shape(); const auto& weights_shape = weights_array.shape();
auto& weights_buffer = auto& weights_buffer =

View File

@ -132,20 +132,16 @@ TransposeOperator* CreateTransposeFromReorderAxes(
// order of the elements does not change. // order of the elements does not change.
auto* reshape_op = auto* reshape_op =
CreateReshapeFromReorderAxes(model, reorder_op, input_shape); CreateReshapeFromReorderAxes(model, reorder_op, input_shape);
const auto reshape_it = model->operators.emplace(reorder_it, reshape_op); model->operators.emplace(reorder_it, reshape_op);
reorder_it = reshape_it + 1;
} else { } else {
// Add Transpose operator into the graph. // Add Transpose operator into the graph.
auto* transpose_op = CreateTransposeFromReorderAxes( auto* transpose_op = CreateTransposeFromReorderAxes(
model, reorder_op, input_shape, input_axes_order, output_axes_order); model, reorder_op, input_shape, input_axes_order, output_axes_order);
const auto transpose_it =
model->operators.emplace(reorder_it, transpose_op); model->operators.emplace(reorder_it, transpose_op);
reorder_it = transpose_it + 1;
} }
// Remove ReorderAxes operator from the graph. // Remove ReorderAxes operator from the graph.
CHECK_EQ(reorder_it->get(), reorder_op); DeleteOpAndArrays(model, reorder_op);
model->operators.erase(reorder_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -77,10 +77,8 @@ namespace toco {
LogName(*reshape_op)); LogName(*reshape_op));
// Replace the operator in the graph. // Replace the operator in the graph.
const auto reshape_it = model->operators.emplace(squeeze_it, reshape_op); model->operators.emplace(squeeze_it, reshape_op);
squeeze_it = reshape_it + 1; DeleteOpAndArrays(model, squeeze_op);
CHECK_EQ(squeeze_it->get(), squeeze_op);
model->operators.erase(squeeze_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/tooling_util.h"
namespace toco { namespace toco {
@ -44,10 +45,8 @@ namespace toco {
add_op->outputs = addn_op->outputs; add_op->outputs = addn_op->outputs;
// Replace the AddN operator in the graph. // Replace the AddN operator in the graph.
const auto add_it = model->operators.emplace(addn_it, add_op); model->operators.emplace(addn_it, add_op);
addn_it = add_it + 1; DeleteOpAndArrays(model, addn_op);
CHECK_EQ(addn_it->get(), addn_op);
model->operators.erase(addn_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -77,10 +77,8 @@ namespace toco {
} }
// Replace the operator in the graph. // Replace the operator in the graph.
const auto reshape_it = model->operators.emplace(pack_it, reshape_op); model->operators.emplace(pack_it, reshape_op);
pack_it = reshape_it + 1; DeleteOpAndArrays(model, pack_op);
CHECK_EQ(pack_it->get(), pack_op);
model->operators.erase(pack_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -86,10 +86,8 @@ namespace toco {
} }
// Replace the operator in the graph. // Replace the operator in the graph.
const auto concat_it = model->operators.emplace(tile_it, concat_op); model->operators.emplace(tile_it, concat_op);
tile_it = concat_it + 1; DeleteOpAndArrays(model, tile_op);
CHECK_EQ(tile_it->get(), tile_op);
model->operators.erase(tile_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -107,10 +107,8 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
} }
// Replace the operator in the graph. // Replace the operator in the graph.
const auto reshape_it = model->operators.emplace(transpose_it, reshape_op); model->operators.emplace(transpose_it, reshape_op);
transpose_it = reshape_it + 1; DeleteOpAndArrays(model, transpose_op);
CHECK_EQ(transpose_it->get(), transpose_op);
model->operators.erase(transpose_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -98,9 +98,8 @@ namespace toco {
} else { } else {
LOG(FATAL) << "Unhandled activation function type"; LOG(FATAL) << "Unhandled activation function type";
} }
model->EraseArray(ac_op->inputs[0]);
op->outputs[0] = ac_op->outputs[0]; op->outputs[0] = ac_op->outputs[0];
model->operators.erase(ac_it); DeleteOpAndArrays(model, ac_op);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -296,13 +296,7 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
model->EraseArray(binary_op->outputs[0]); model->EraseArray(binary_op->outputs[0]);
following_op->inputs[0] = binary_op->inputs[index_of_variable_input]; following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
const auto& old_constant_param_name = DeleteOpAndArrays(model, binary_op);
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -360,13 +360,7 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
preceding_op->outputs[0] = binary_op->outputs[0]; preceding_op->outputs[0] = binary_op->outputs[0];
preceding_op->fused_activation_function = preceding_op->fused_activation_function =
binary_op->fused_activation_function; binary_op->fused_activation_function;
const auto& old_constant_param_name = DeleteOpAndArrays(model, binary_op);
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -336,11 +336,13 @@ void RewireBidirectionalSequenceSequenceOpsConnections(
string cur_fw_input = input_array_name; string cur_fw_input = input_array_name;
string cur_bw_input = input_array_name; string cur_bw_input = input_array_name;
for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) { for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) {
DeleteArrayIfUsedOnce(bidirectional_sequence_ops[i]->inputs[0], model); DeleteArrayIfUnusedOutsideOfOp(bidirectional_sequence_ops[i]->inputs[0],
bidirectional_sequence_ops[i], model);
bidirectional_sequence_ops[i]->inputs[0] = cur_fw_input; bidirectional_sequence_ops[i]->inputs[0] = cur_fw_input;
if (i != 0) { if (i != 0) {
DeleteArrayIfUsedOnce( DeleteArrayIfUnusedOutsideOfOp(
bidirectional_sequence_ops[i]->inputs[aux_input_index], model); bidirectional_sequence_ops[i]->inputs[aux_input_index],
bidirectional_sequence_ops[i], model);
bidirectional_sequence_ops[i]->inputs[aux_input_index] = cur_bw_input; bidirectional_sequence_ops[i]->inputs[aux_input_index] = cur_bw_input;
} }
cur_fw_input = bidirectional_sequence_ops[i]->outputs[0]; cur_fw_input = bidirectional_sequence_ops[i]->outputs[0];
@ -383,24 +385,16 @@ void RewireFinalUnpackOutputs(const UnpackOperator& original_unpack_operator,
(*final_unpack_operator)->outputs[i] = concat_output; (*final_unpack_operator)->outputs[i] = concat_output;
} }
// Remove the concat op. // Remove the concat op.
model->operators.erase(FindOperator(model, *unpack_following_op)); DeleteOpAndArrays(model, unpack_following_op);
} }
} }
} }
void RemoveUnpackOperator(const Operator& unpack_op, Model* model) {
for (const string& output_array_name : unpack_op.outputs) {
DeleteArrayIfUnused(output_array_name, model);
}
model->operators.erase(FindOperator(model, unpack_op));
}
void RemoveUnidirectionalSequenceOps(std::stack<Operator*> uni_sequence_ops, void RemoveUnidirectionalSequenceOps(std::stack<Operator*> uni_sequence_ops,
Model* model) { Model* model) {
while (!uni_sequence_ops.empty()) { while (!uni_sequence_ops.empty()) {
Operator* uni_sequence_op = uni_sequence_ops.top(); Operator* uni_sequence_op = uni_sequence_ops.top();
DeleteArrayIfUnused(uni_sequence_op->outputs[0], model); DeleteOpAndArrays(model, uni_sequence_op);
model->operators.erase(FindOperator(model, *uni_sequence_op));
uni_sequence_ops.pop(); uni_sequence_ops.pop();
} }
} }
@ -471,14 +465,9 @@ template <typename T>
// Delete unused ops. // Delete unused ops.
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_ops, model); RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_ops, model);
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_ops, model); RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_ops, model);
DeleteOpAndArrays(model, final_concat_op);
DeleteArrayIfUnused(final_concat_op->inputs[0], model);
DeleteArrayIfUnused(final_concat_op->inputs[1], model);
model->operators.erase(FindOperator(model, *final_concat_op));
// Only keep the fw lstm's input. // Only keep the fw lstm's input.
DeleteArrayIfUnused(first_bw_sequence_input->outputs[0], model); DeleteOpAndArrays(model, first_bw_sequence_input);
model->operators.erase(FindOperator(model, *first_bw_sequence_input));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }
@ -552,13 +541,12 @@ template <typename T>
model->operators.emplace(op_it, unpack_operator); model->operators.emplace(op_it, unpack_operator);
// Delete unused ops. // Delete unused ops.
RemoveUnpackOperator(*fw_lstm_output, model); DeleteOpAndArrays(model, fw_lstm_output);
RemoveUnpackOperator(*bw_lstm_output, model); DeleteOpAndArrays(model, bw_lstm_output);
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_lstm_ops, model); RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_lstm_ops, model);
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_lstm_ops, model); RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_lstm_ops, model);
// Only keep the fw lstm's pack input. // Only keep the fw lstm's pack input.
DeleteArrayIfUnused(first_bw_lstm_input->outputs[0], model); DeleteOpAndArrays(model, first_bw_lstm_input);
model->operators.erase(FindOperator(model, *first_bw_lstm_input));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }
@ -628,13 +616,12 @@ template <typename T>
model->operators.emplace(op_it, unpack_operator); model->operators.emplace(op_it, unpack_operator);
// Delete unused ops. // Delete unused ops.
RemoveUnpackOperator(*fw_rnn_output, model); DeleteOpAndArrays(model, fw_rnn_output);
RemoveUnpackOperator(*bw_rnn_output, model); DeleteOpAndArrays(model, bw_rnn_output);
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_rnn_ops, model); RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_rnn_ops, model);
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_rnn_ops, model); RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_rnn_ops, model);
// Only keep the fw rnn's pack input. // Only keep the fw rnn's pack input.
DeleteArrayIfUnused(first_bw_rnn_input->outputs[0], model); DeleteOpAndArrays(model, first_bw_rnn_input);
model->operators.erase(FindOperator(model, *first_bw_rnn_input));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -141,30 +141,12 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
// 3. DELETE LEFTOVER OPERATORS // 3. DELETE LEFTOVER OPERATORS
// *************************************************************************** // ***************************************************************************
// Order is important. Delete the output array first, then the op, then it's DeleteOpAndArrays(model, bts_op);
// redundant inputs. DeleteOpAndArrays(model, stb_op);
// BatchToSpace Op
DeleteArrayIfUnused(bts_op->outputs[0], model);
std::vector<string> bts_op_inputs = bts_op->inputs;
model->operators.erase(FindOp(*model, bts_op));
DeleteArrayIfUnused(bts_op_inputs[1], model);
DeleteArrayIfUnused(bts_op_inputs[2], model);
// Pad Op if present
if (has_pad_op) { if (has_pad_op) {
DeleteArrayIfUnused(pad_op->outputs[0], model); DeleteOpAndArrays(model, pad_op);
std::vector<string> pad_op_inputs = pad_op->inputs;
model->operators.erase(FindOp(*model, pad_op));
DeleteArrayIfUnused(pad_op_inputs[1], model);
} }
// SpaceToBatch Op
DeleteArrayIfUnused(stb_op->outputs[0], model);
std::vector<string> stb_op_inputs = stb_op->inputs;
model->operators.erase(FindOp(*model, stb_op));
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
return true; return true;
} }

View File

@ -136,13 +136,13 @@ namespace toco {
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op)); AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
// Erase the subgraph that is now replaced by L2Normalization // Erase the subgraph that is now replaced by L2Normalization
model->operators.erase(FindOp(*model, square_op)); DeleteOpAndArrays(model, square_op);
DeleteOpAndArraysIfUnused(model, sum_op); DeleteOpAndArrays(model, sum_op);
if (add_op) { if (add_op) {
DeleteOpAndArraysIfUnused(model, add_op); DeleteOpAndArrays(model, add_op);
} }
DeleteOpAndArraysIfUnused(model, sqrt_or_rsqrt_op); DeleteOpAndArrays(model, sqrt_or_rsqrt_op);
DeleteOpAndArraysIfUnused(model, div_or_mul_op); DeleteOpAndArrays(model, div_or_mul_op);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -86,14 +86,9 @@ namespace toco {
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op)); AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
// Erase intermediate arrays, keeping input to square op. DeleteOpAndArrays(model, square_op);
model->EraseArray(avpool_op->inputs[0]); DeleteOpAndArrays(model, avpool_op);
model->EraseArray(sqrt_op->inputs[0]); DeleteOpAndArrays(model, sqrt_op);
// Erase three operators being replaced.
model->operators.erase(FindOp(*model, square_op));
model->operators.erase(FindOp(*model, avpool_op));
model->operators.erase(FindOp(*model, sqrt_op));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -290,36 +290,19 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
concat_temp_array_name, activ_temp_array_name, concat_temp_array_name, activ_temp_array_name,
LogName(*lstm_cell_op)); LogName(*lstm_cell_op));
// Delete arrays and operators replaced by the LSTM cell operator. Order is DeleteOpAndArrays(model, final_output_mul);
// important - DeleteArrayIfUnused() only succeeds if dependent operators DeleteOpAndArrays(model, state_output_tanh);
// have been removed first. Start at the output and work towards the input. DeleteOpAndArrays(model, fc_output_sig);
model->operators.erase(FindOperator(model, *final_output_mul)); DeleteOpAndArrays(model, state_combine_add);
DeleteArrayIfUnused(state_output_tanh->outputs[0], model); DeleteOpAndArrays(model, state_forget_mul);
DeleteArrayIfUnused(fc_output_sig->outputs[0], model); DeleteOpAndArrays(model, state_remember_mul);
model->operators.erase(FindOperator(model, *state_output_tanh)); DeleteOpAndArrays(model, state_forget_sig);
model->operators.erase(FindOperator(model, *fc_output_sig)); DeleteOpAndArrays(model, state_info_tanh);
model->operators.erase(FindOperator(model, *state_combine_add)); DeleteOpAndArrays(model, state_remember_sig);
DeleteArrayIfUnused(state_forget_mul->outputs[0], model); DeleteOpAndArrays(model, fc_output_split);
DeleteArrayIfUnused(state_remember_mul->outputs[0], model); DeleteOpAndArrays(model, fully_connected);
model->operators.erase(FindOperator(model, *state_forget_mul)); DeleteOpAndArrays(model, concat_inputs);
model->operators.erase(FindOperator(model, *state_remember_mul));
DeleteArrayIfUnused(state_forget_sig->outputs[0], model);
DeleteArrayIfUnused(state_info_tanh->outputs[0], model);
DeleteArrayIfUnused(state_remember_sig->outputs[0], model);
model->operators.erase(FindOperator(model, *state_forget_sig));
model->operators.erase(FindOperator(model, *state_info_tanh));
model->operators.erase(FindOperator(model, *state_remember_sig));
DeleteArrayIfUnused(fc_output_split->outputs[0], model);
DeleteArrayIfUnused(fc_output_split->outputs[1], model);
DeleteArrayIfUnused(fc_output_split->outputs[2], model);
DeleteArrayIfUnused(fc_output_split->outputs[3], model);
string dims_array = fc_output_split->inputs[0];
model->operators.erase(FindOperator(model, *fc_output_split));
DeleteArrayIfUnused(dims_array, model);
DeleteArrayIfUnused(fully_connected->outputs[0], model);
model->operators.erase(FindOperator(model, *fully_connected));
DeleteArrayIfUnused(concat_inputs->outputs[0], model);
model->operators.erase(FindOperator(model, *concat_inputs));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -169,23 +169,7 @@ namespace toco {
model->operators.emplace(op_it, std::move(lstm_cell_op)); model->operators.emplace(op_it, std::move(lstm_cell_op));
AddMessageF("Creating compact LstmCell replacing previous lstm cell"); AddMessageF("Creating compact LstmCell replacing previous lstm cell");
// Delete arrays and operators replaced by the LSTM cell operator. Order is DeleteOpAndArrays(model, src_op);
// important - DeleteArrayIfUnused() only succeeds if dependent operators
// have been removed first. Start at the output and work towards the input.
// Erase curr lstm op being replaced.
DeleteArrayIfUnused(src_op->inputs[kInputToInputWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kInputToForgetWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kInputToCellWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kInputToOutputWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kRecurrentToInputWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kRecurrentToForgetWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kRecurrentToCellWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kRecurrentToOutputWeightsTensor], model);
DeleteArrayIfUnused(src_op->inputs[kInputGateBiasTensor], model);
DeleteArrayIfUnused(src_op->inputs[kForgetGateBiasTensor], model);
DeleteArrayIfUnused(src_op->inputs[kCellGateBiasTensor], model);
DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model);
model->operators.erase(FindOp(*model, src_op));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -163,13 +163,7 @@ namespace toco {
model->operators.emplace(op_it, std::move(lstm_cell_op)); model->operators.emplace(op_it, std::move(lstm_cell_op));
AddMessageF("Creating extended LstmCell replacing previous lstm cell"); AddMessageF("Creating extended LstmCell replacing previous lstm cell");
// Delete arrays and operators replaced by the LSTM cell operator. Order is DeleteOpAndArrays(model, curr_op);
// important - DeleteArrayIfUnused() only succeeds if dependent operators
// have been removed first. Start at the output and work towards the input.
// Erase curr lstm op being replaced.
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
model->operators.erase(FindOp(*model, curr_op));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -122,14 +122,10 @@ namespace toco {
model->operators.emplace(add_op_it, prelu_op); model->operators.emplace(add_op_it, prelu_op);
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op)); AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));
DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model); DeleteArrayIfUnusedOutsideOfOp(neg_alpha_tensor_name, neg_neg_alpha_op,
DeleteArrayIfUsedOnce(add_op->inputs[0], model); model);
DeleteArrayIfUsedOnce(add_op->inputs[1], model); DeleteArrayIfUnusedOutsideOfOp(mul_op->inputs[1], mul_op, model);
DeleteArrayIfUsedOnce(mul_op->inputs[1], model); DeleteOpAndArrays(model, add_op);
// Remove the existing Add op that outputs the final result. If the other
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -99,19 +99,13 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
// Create and emplace Relu1 node. // Create and emplace Relu1 node.
auto* relu1_op = new Relu1Operator; auto* relu1_op = new Relu1Operator;
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]}; relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]};
relu1_op->outputs = op_1->outputs; relu1_op->outputs = op_1->outputs;
model->operators.emplace(op_it, relu1_op); model->operators.emplace(op_it, relu1_op);
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); DeleteOpAndArrays(model, op_0);
DeleteOpAndArrays(model, op_1);
// Erase op scalar inputs & operators. Note that we preserve the non-scalar
// input to the first op as that's been redirected to the relu1_op.
DeleteArrayIfUsedOnce(op_0->inputs[op_0_scalar_input_index], model);
DeleteArrayIfUsedOnce(op_1->inputs[0], model);
DeleteArrayIfUsedOnce(op_1->inputs[1], model);
model->operators.erase(FindOperator(model, op_0));
model->operators.erase(FindOperator(model, op_1));
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -546,10 +546,7 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
// Check if the output of that Dequantize op was not used by any // Check if the output of that Dequantize op was not used by any
// other operator. We will then erase that Dequantize op. // other operator. We will then erase that Dequantize op.
if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
if (IsDiscardableArray(*model, dequantize_op->outputs[0])) { if (!IsDiscardableArray(*model, dequantize_op->outputs[0])) {
// Usual case: we can just discard the dequantize output.
model->EraseArray(dequantize_op->outputs[0]);
} else {
// The dequantize output is not discardable. Special care needed. // The dequantize output is not discardable. Special care needed.
// If any of the model's output_arrays was pointing to the // If any of the model's output_arrays was pointing to the
// Dequantize op's output, let it point to the Dequantize op's // Dequantize op's output, let it point to the Dequantize op's
@ -583,13 +580,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
model->flags.output_arrays(i), model->flags.output_arrays(i),
dequantize_op->inputs[0]); dequantize_op->inputs[0]);
model->flags.set_output_arrays(i, dequantize_op->inputs[0]); model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
model->EraseArray(dequantize_op->outputs[0]);
} }
break; break;
} }
} }
} }
model->operators.erase(dequantize_it); DeleteOpAndArrays(model, dequantize_op);
} }
changed = true; changed = true;
} else { } else {

View File

@ -54,8 +54,7 @@ namespace toco {
// Remove the node and its output array. // Remove the node and its output array.
AddMessageF("Removed final %s", LogName(*dequantize_op)); AddMessageF("Removed final %s", LogName(*dequantize_op));
model->EraseArray(output); DeleteOpAndArrays(model, dequantize_op);
model->operators.erase(dequantize_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -59,13 +59,10 @@ namespace toco {
} }
// Drop trivial inputs. // Drop trivial inputs.
for (const string& input : trivial_inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
model->EraseArray(input);
}
}
concat_op->inputs = nontrivial_inputs; concat_op->inputs = nontrivial_inputs;
for (const string& input : trivial_inputs) {
DeleteArrayIfUnusedOutsideOfOp(input, concat_op, model);
}
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -89,17 +89,6 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
const string main_input_name = passthru_op->inputs[main_input_array_index]; const string main_input_name = passthru_op->inputs[main_input_array_index];
const string output_name = passthru_op->outputs[0]; const string output_name = passthru_op->outputs[0];
// Build the list of all input and output arrays of the passthrough node
// that we are considering removing. Any of these arrays is a candidate
// for being removed as well, if nothing else references it. Doing that
// arrays-removal together with the passthrough-node-removal proved too
// error-prone.
std::vector<string> removal_candidates;
for (const string& input : passthru_op->inputs) {
removal_candidates.push_back(input);
}
removal_candidates.push_back(output_name);
if (IsDiscardableArray(*model, output_name)) { if (IsDiscardableArray(*model, output_name)) {
transformation->AddMessageF( transformation->AddMessageF(
"Removing %s, keeping its non-constant input array %s and removing %s", "Removing %s, keeping its non-constant input array %s and removing %s",
@ -135,38 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
} }
// Remove the pass-through node. // Remove the pass-through node.
CHECK_EQ(passthru_it->get(), passthru_op); DeleteOpAndArrays(model, passthru_op);
model->operators.erase(passthru_it);
// Remove any array that is no longer used.
for (const string& removal_candidate : removal_candidates) {
bool is_referenced = false;
for (const auto& array : model->flags.input_arrays()) {
if (array.name() == removal_candidate) {
is_referenced = true;
}
}
for (const auto& array_name : model->flags.output_arrays()) {
if (array_name == removal_candidate) {
is_referenced = true;
}
}
for (const auto& op : model->operators) {
for (const string& input : op->inputs) {
if (input == removal_candidate) {
is_referenced = true;
}
}
for (const string& output : op->outputs) {
if (output == removal_candidate) {
is_referenced = true;
}
}
}
if (!is_referenced) {
model->EraseArray(removal_candidate);
}
}
return true; return true;
} }

View File

@ -86,27 +86,7 @@ namespace toco {
AddMessageF("Discarding %s because none of its outputs is used.", AddMessageF("Discarding %s because none of its outputs is used.",
LogName(*op)); LogName(*op));
DeleteOpAndArrays(model, op);
// At that point we know that none of the outputs is used, so we will
// definitely remove the node and all its outputs.
// Remove any input array that not the output of another op, and only used by
// this op.
for (const auto& input : op->inputs) {
if (!GetOpWithOutput(*model, input)) {
DeleteArrayIfUsedOnce(input, model);
}
}
// Remove the node and its now-unused output arrays.
for (const auto& output : op->outputs) {
// If the output array is the model's input array, don't remove that.
// That's the case when cropping a model at a given --input_array.
if (IsDiscardableArray(*model, output)) {
model->EraseArray(output);
}
}
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -122,7 +122,7 @@ bool IsMoveOperator(OperatorType optype) {
element_op->inputs[0] = input_name; element_op->inputs[0] = input_name;
element_op->outputs[0] = new_intermediate_name; element_op->outputs[0] = new_intermediate_name;
model->EraseArray(intermediate_name); DeleteArrayIfUnused(intermediate_name, model);
move_op->inputs[0] = new_intermediate_name; move_op->inputs[0] = new_intermediate_name;
move_op->outputs[0] = output_name; move_op->outputs[0] = output_name;
} else { } else {

View File

@ -190,7 +190,7 @@ std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
transpose_op->outputs[0] = new_intermediate_name; transpose_op->outputs[0] = new_intermediate_name;
reshape_op->inputs[0] = new_intermediate_name; reshape_op->inputs[0] = new_intermediate_name;
reshape_op->outputs[0] = output_name; reshape_op->outputs[0] = output_name;
model->EraseArray(intermediate_name); DeleteArrayIfUnused(intermediate_name, model);
} else { } else {
// The intermediate array is now the output array. // The intermediate array is now the output array.
for (int i = 0; i < model->operators.size(); i++) { for (int i = 0; i < model->operators.size(); i++) {

View File

@ -136,14 +136,7 @@ namespace toco {
offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i]; offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
} }
// Remove the old param arrays DeleteOpAndArrays(model, bn_op);
DeleteArrayIfUsedOnce(bn_op->inputs[1], model);
DeleteArrayIfUsedOnce(bn_op->inputs[2], model);
DeleteArrayIfUsedOnce(bn_op->inputs[3], model);
// Remove the old operator
DCHECK_EQ(bn_it->get(), bn_op);
model->operators.erase(bn_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();

View File

@ -247,16 +247,7 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
// Do the actual constants propagation // Do the actual constants propagation
EvaluateBinaryOperatorOnConstantInputs(model, binary_op); EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
// Remove the binary operator and its inputs DeleteOpAndArrays(model, binary_op);
if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
model->EraseArray(binary_op->inputs[0]);
}
if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
model->EraseArray(binary_op->inputs[1]);
}
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
model->operators.erase(binary_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -206,16 +206,7 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
LOG(FATAL) << "ArrayDataType not supported"; LOG(FATAL) << "ArrayDataType not supported";
} }
// Remove all the resolved arrays. DeleteOpAndArrays(model, concat_op);
for (const string& input_name : concat_op->inputs) {
// Check to prevent removal of shared tensors.
if (CountOpsWithInput(*model, input_name) == 1) {
model->EraseArray(input_name);
}
}
// Remove concatenate operator.
model->operators.erase(concat_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -132,13 +132,7 @@ void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
tflite::FakeQuantizeArray(scale, nudged_min, nudged_max, tflite::FakeQuantizeArray(scale, nudged_min, nudged_max,
input_buffer.data.data(), output_buffer.data.data(), input_buffer.data.data(), output_buffer.data.data(),
size); size);
DeleteOpAndArrays(model, fakequant_op);
if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -109,19 +109,7 @@ bool ComputeFillArray(Model* model, FillOperator* op) {
break; break;
} }
// Erase input arrays if no longer used DeleteOpAndArrays(model, op);
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
model->EraseArray(op->inputs[1]);
}
// Erase the operator
model->operators.erase(fill_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -142,12 +142,7 @@ inline void Gather(const Array& input_array, const Array& coords_array,
break; break;
} }
// Erase input arrays if no longer used after we remove the op. DeleteOpAndArrays(model, op);
DeleteArrayIfUsedOnce(op->inputs[0], model);
DeleteArrayIfUsedOnce(op->inputs[1], model);
// Erase the operator.
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -110,13 +110,7 @@ void Pack(Model* model, PackOperator const& op) {
break; break;
} }
// Erase input arrays if no longer used DeleteOpAndArrays(model, op);
for (const auto& input : op->inputs) {
toco::DeleteArrayIfUsedOnce(input, model);
}
// Erase the operator
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -107,12 +107,7 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
break; break;
} }
// Erase input arrays if no longer used DeleteOpAndArrays(model, op);
toco::DeleteArrayIfUsedOnce(op->inputs[0], model);
// Erase the operator
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -105,23 +105,7 @@ void FillRangeOutput(const Array& start_array, const Array& limit_array,
delta_array, &output_array); delta_array, &output_array);
} }
// Delete the input array if no longer used DeleteOpAndArrays(model, op);
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
model->EraseArray(op->inputs[1]);
}
if (IsDiscardableArray(*model, op->inputs[2]) &&
CountOpsWithInput(*model, op->inputs[2]) == 1) {
model->EraseArray(op->inputs[2]);
}
// Delete the operator
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -108,16 +108,7 @@ namespace toco {
CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array); CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
// Erase input arrays if no longer used. DeleteOpAndArrays(model, op);
for (const auto& input : op->inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
model->EraseArray(input);
}
}
// Erase the operator.
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -61,13 +61,7 @@ namespace toco {
output_array.mutable_shape()->ReplaceDims( output_array.mutable_shape()->ReplaceDims(
{static_cast<int>(output_buffer.data.size())}); {static_cast<int>(output_buffer.data.size())});
// Delete the input array if no longer used DeleteOpAndArrays(model, op);
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
model->EraseArray(op->inputs[0]);
}
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -158,15 +158,7 @@ bool Slice(SliceOperator const& op, Array const& input_array,
break; break;
} }
// Erase input array if no longer used. DeleteOpAndArrays(model, op);
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
model->EraseArray(op->inputs[0]);
}
// Erase the operator
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -163,8 +163,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
break; break;
} }
DeleteOpAndArraysIfUnused(model, it->get()); DeleteOpAndArrays(model, op);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -160,12 +160,7 @@ inline void Tile(const Array& input_array, const Array& multiples_array,
break; break;
} }
// Erase input arrays if no longer used after we remove the op. DeleteOpAndArrays(model, op);
DeleteArrayIfUsedOnce(op->inputs[0], model);
DeleteArrayIfUsedOnce(op->inputs[1], model);
// Erase the operator.
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -171,16 +171,7 @@ void Transpose(Model* model, const Array& input_array,
AddMessageF("Resolving constant transpose of %s", LogName(*op)); AddMessageF("Resolving constant transpose of %s", LogName(*op));
// Erase input arrays if no longer used. DeleteOpAndArrays(model, op);
for (const auto& input : op->inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
model->EraseArray(input);
}
}
// Erase the operator.
model->operators.erase(it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -347,14 +347,8 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
} else { } else {
LOG(FATAL) << "should not get here."; LOG(FATAL) << "should not get here.";
} }
for (const auto& input : unary_op->inputs) {
if (CountOpsWithInput(*model, input) == 1) { DeleteOpAndArrays(model, unary_op);
model->EraseArray(input);
}
}
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*unary_op));
model->operators.erase(unary_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -74,7 +74,8 @@ namespace toco {
// values, anymore. Delete them unless they are used by something // values, anymore. Delete them unless they are used by something
// else. // else.
for (int i = 1; i <= 2; i++) { for (int i = 1; i <= 2; i++) {
DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model); DeleteArrayIfUnusedOutsideOfOp(fakequant_op->inputs[i], fakequant_op,
model);
} }
fakequant_op->inputs.resize(1); fakequant_op->inputs.resize(1);
*modified = true; *modified = true;

View File

@ -49,7 +49,7 @@ namespace toco {
op->axis = {axis_data[0]}; op->axis = {axis_data[0]};
// Drop the axis array as we no longer need it. // Drop the axis array as we no longer need it.
DeleteArrayIfUsedOnce(op->inputs[2], model); DeleteArrayIfUnusedOutsideOfOp(op->inputs[2], op, model);
op->inputs.resize(2); op->inputs.resize(2);
*modified = true; *modified = true;

View File

@ -154,13 +154,7 @@ void FillArrayWithZeros(Array* array) {
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }
// Erase input arrays to the multiply if no longer used DeleteOpAndArrays(model, mul_op);
DeleteArrayIfUsedOnce(mul_op->inputs[0], model);
DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
// Erase the multiply operator.
model->operators.erase(mul_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -119,7 +119,7 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
AddMessageF("Reordered axes for array %s", input_array_name); AddMessageF("Reordered axes for array %s", input_array_name);
DeleteOpAndArraysIfUnused(model, op); DeleteOpAndArrays(model, op);
RenameArray(model, output_array_name, input_array_name); RenameArray(model, output_array_name, input_array_name);
*modified = true; *modified = true;

View File

@ -72,16 +72,8 @@ namespace toco {
concatenation_op->outputs = {tf_concat_op->outputs[0]}; concatenation_op->outputs = {tf_concat_op->outputs[0]};
auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op); auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op);
CHECK_EQ(depth_concat_it->get(), concatenation_op); CHECK_EQ(depth_concat_it->get(), concatenation_op);
// Update invalidated iterator
concat_it = depth_concat_it + 1;
CHECK_EQ(concat_it->get(), tf_concat_op);
// Remove the axis array if it is not used by anything else. DeleteOpAndArrays(model, tf_concat_op);
if (CountOpsWithInput(*model, axis_name) == 1) {
model->EraseArray(axis_name);
}
// Remove the TensorFlowConcat op
model->operators.erase(concat_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -129,6 +129,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
// Construct the new FullyConnectedOperator. // Construct the new FullyConnectedOperator.
auto* fc_op = new FullyConnectedOperator; auto* fc_op = new FullyConnectedOperator;
fc_op->inputs = {input_lhs, input_rhs};
fc_op->outputs = matmul_op->outputs; fc_op->outputs = matmul_op->outputs;
// Insert the newly constructed FullyConnectedOperator. // Insert the newly constructed FullyConnectedOperator.
@ -173,14 +174,10 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
} }
CHECK_EQ(previous_op->inputs.size(), 2); CHECK_EQ(previous_op->inputs.size(), 2);
input_lhs = previous_op->inputs[0]; input_lhs = previous_op->inputs[0];
fc_op->inputs = {input_lhs, input_rhs};
// Only remove Reshape node if no other node uses its output. // Only remove Reshape node if no other node uses its output.
if (CountOpsWithInput(*model, previous_op_output) == 1) { if (CountOpsWithInput(*model, previous_op_output) == 1) {
const auto& previous_op_shape = previous_op->inputs[1]; DeleteOpAndArrays(model, previous_op);
if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
!GetOpWithOutput(*model, previous_op_shape)) {
model->EraseArray(previous_op_shape);
}
model->operators.erase(previous_op_it);
} }
// We may have just invalidated matmul_it, so let's refresh it now. // We may have just invalidated matmul_it, so let's refresh it now.
@ -197,7 +194,6 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
LogName(*matmul_op)); LogName(*matmul_op));
} }
fc_op->inputs = {input_lhs, input_rhs};
// erase the MatMul operator // erase the MatMul operator
model->operators.erase(matmul_it); model->operators.erase(matmul_it);

View File

@ -56,10 +56,7 @@ namespace toco {
} }
} }
// Remove the node and its output array. DeleteOpAndArrays(model, merge_op);
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
model->EraseArray(merge_op->outputs[0]);
model->operators.erase(merge_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -121,7 +121,7 @@ namespace toco {
} }
// Remove the switch node itself. // Remove the switch node itself.
AddMessageF("Removing already-resolved %s", LogName(*switch_op)); AddMessageF("Removing already-resolved %s", LogName(*switch_op));
model->operators.erase(switch_it); DeleteOpAndArrays(model, switch_op);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -107,6 +107,8 @@ class ResolveConstantConcatenationTest : public ::testing::Test {
// together with 4 arrays as its inputs. // together with 4 arrays as its inputs.
// It receives the dimension of concatenation as input. // It receives the dimension of concatenation as input.
void PrepareModel(Model* model, int axis) { void PrepareModel(Model* model, int axis) {
const string output_name("concat_op_output");
model->flags.add_output_arrays(output_name);
std::vector<string> concat_input_names = {"array0", "array1", "array2", std::vector<string> concat_input_names = {"array0", "array1", "array2",
"array3"}; "array3"};
@ -141,7 +143,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test {
auto* concatenation_op = new ConcatenationOperator; auto* concatenation_op = new ConcatenationOperator;
concatenation_op->axis = axis; concatenation_op->axis = axis;
concatenation_op->inputs = concat_input_names; concatenation_op->inputs = concat_input_names;
concatenation_op->outputs = {"concat_op_outputs"}; concatenation_op->outputs = {output_name};
Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]); Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
out_array.data_type = ArrayDataType::kFloat; out_array.data_type = ArrayDataType::kFloat;
Shape* out_array_shape = out_array.mutable_shape(); Shape* out_array_shape = out_array.mutable_shape();
@ -172,8 +174,8 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
.ok()); .ok());
EXPECT_THAT(model.GetArrayMap().size(), 1); EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second; const auto& concatenated_array = model.GetArray(model.flags.output_arrays(0));
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, EXPECT_THAT(concatenated_array.GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear( ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25., 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,

View File

@ -31,9 +31,11 @@ void RunResolveSum(const std::vector<float>& input,
const std::vector<int>& output_shape, const std::vector<int>& output_shape,
const std::vector<float>& expected_output) { const std::vector<float>& expected_output) {
Model model; Model model;
const std::string output_name("output");
model.flags.add_output_arrays(output_name);
Array& input0 = model.GetOrCreateArray("input0"); Array& input0 = model.GetOrCreateArray("input0");
Array& input1 = model.GetOrCreateArray("input1"); Array& input1 = model.GetOrCreateArray("input1");
Array& output = model.GetOrCreateArray("output"); Array& output = model.GetOrCreateArray(output_name);
*input0.mutable_shape()->mutable_dims() = input_shape; *input0.mutable_shape()->mutable_dims() = input_shape;
input0.data_type = ArrayDataType::kFloat; input0.data_type = ArrayDataType::kFloat;
@ -48,7 +50,7 @@ void RunResolveSum(const std::vector<float>& input,
auto sum_op = absl::make_unique<TensorFlowSumOperator>(); auto sum_op = absl::make_unique<TensorFlowSumOperator>();
sum_op->keep_dims = true; sum_op->keep_dims = true;
sum_op->inputs = {"input0", "input1"}; sum_op->inputs = {"input0", "input1"};
sum_op->outputs = {"output"}; sum_op->outputs = {output_name};
model.operators.push_back(std::move(sum_op)); model.operators.push_back(std::move(sum_op));
bool modified; bool modified;
ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok()); ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok());

View File

@ -227,18 +227,18 @@ namespace toco {
// Ensure the stitch output array is dead, as we don't want whatever was in it // Ensure the stitch output array is dead, as we don't want whatever was in it
// previously now that we've redefined it. It'll be recreated when needed. // previously now that we've redefined it. It'll be recreated when needed.
model->EraseArray(stitch_op->outputs[0]); model->EraseArray(merged_gather_op->outputs[0]);
model->GetOrCreateArray(merged_gather_op->outputs[0]); model->GetOrCreateArray(merged_gather_op->outputs[0]);
// Erase all the original ops. // Erase all the original ops.
DeleteOpAndArraysIfUnused(model, div_op); DeleteOpAndArrays(model, div_op);
DeleteOpAndArraysIfUnused(model, mod_op); DeleteOpAndArrays(model, mod_op);
for (auto* gather_op : gather_ops) { for (auto* gather_op : gather_ops) {
DeleteOpAndArraysIfUnused(model, gather_op); DeleteOpAndArrays(model, gather_op);
} }
DeleteOpAndArraysIfUnused(model, indices_partition_op); DeleteOpAndArrays(model, indices_partition_op);
DeleteOpAndArraysIfUnused(model, data_partition_op); DeleteOpAndArrays(model, data_partition_op);
DeleteOpAndArraysIfUnused(model, stitch_op); DeleteOpAndArrays(model, stitch_op);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -188,9 +188,8 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
auto* matmul_op = new TensorFlowMatMulOperator; auto* matmul_op = new TensorFlowMatMulOperator;
matmul_op->inputs = {input_lhs, input_rhs}; matmul_op->inputs = {input_lhs, input_rhs};
matmul_op->outputs = batch_op->outputs; matmul_op->outputs = batch_op->outputs;
tail_it = model->operators.emplace(tail_it, matmul_op) + 1; model->operators.emplace(tail_it, matmul_op);
CHECK_EQ(tail_it->get(), batch_op); DeleteOpAndArrays(model, batch_op);
model->operators.erase(tail_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }
@ -254,16 +253,7 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
reshape_result_op->outputs = {batch_op->outputs[0]}; reshape_result_op->outputs = {batch_op->outputs[0]};
model->operators.emplace(tail_it, reshape_result_op); model->operators.emplace(tail_it, reshape_result_op);
// Remove the old batch matmul now that we've unrolled. DeleteOpAndArrays(model, batch_op);
batch_op_it = model->operators.begin();
for (; batch_op_it != model->operators.end(); ++batch_op_it) {
if (batch_op_it->get() == batch_op) {
break;
}
}
CHECK(batch_op_it != model->operators.end());
CHECK(batch_op_it->get() == batch_op);
model->operators.erase(batch_op_it);
*modified = true; *modified = true;
return ::tensorflow::Status::OK(); return ::tensorflow::Status::OK();
} }

View File

@ -165,19 +165,33 @@ bool DeleteArrayIfUnused(const string& array_name, Model* model) {
return false; return false;
} }
bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) { bool DeleteArrayIfUnusedOutsideOfOp(const string& array_name,
if (IsDiscardableArray(*model, array_name) && const Operator* op, Model* model) {
CountOpsWithInput(*model, array_name) == 1 && if (!IsDiscardableArray(*model, array_name)) {
GetOpWithOutput(*model, array_name) == nullptr) { return false;
}
if (CountOpsWithInput(*model, array_name) > 1) {
return false;
}
const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
if (op_having_this_as_input && op_having_this_as_input != op) {
return false;
}
const Operator* op_having_this_as_output =
GetOpWithOutput(*model, array_name);
if (op_having_this_as_output && op_having_this_as_output != op) {
return false;
}
model->EraseArray(array_name); model->EraseArray(array_name);
return true; return true;
}
return false;
} }
void DeleteOpAndArraysIfUnused(Model* model, const Operator* op) { void DeleteOpAndArrays(Model* model, const Operator* op) {
for (const string& array_name : op->inputs) { for (const string& array_name : op->inputs) {
DeleteArrayIfUsedOnce(array_name, model); DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
}
for (const string& array_name : op->outputs) {
DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
} }
auto op_it = FindOp(*model, op); auto op_it = FindOp(*model, op);
CHECK(op_it != model->operators.end()); CHECK(op_it != model->operators.end());

View File

@ -68,11 +68,10 @@ int CountTrueOutputs(const Model& model, const Operator& op);
int CountOpsWithInput(const Model& model, const string& array_name); int CountOpsWithInput(const Model& model, const string& array_name);
bool DeleteArrayIfUnused(const string& array_name, Model* model); bool DeleteArrayIfUnused(const string& array_name, Model* model);
bool DeleteArrayIfUsedOnce(const string& array_name, Model* model);
// Deletes the op and any of its input and output arrays if they are unused // Deletes the op and any of its input and output arrays if they are unused
// after the op has been deleted. // after the op has been deleted.
void DeleteOpAndArraysIfUnused(Model* model, const Operator* op); void DeleteOpAndArrays(Model* model, const Operator* op);
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput( std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
const Model& model, const string& array_name); const Model& model, const string& array_name);
@ -361,6 +360,11 @@ void UndoWeightsShuffling(Model* model);
// Copies minmax, quantization_params, and narrow_range. // Copies minmax, quantization_params, and narrow_range.
void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst); void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);
// Delete Array if it's discardable and not referenced as input or output array
// by any other op than the specified op.
bool DeleteArrayIfUnusedOutsideOfOp(const string& array_name,
const Operator* op, Model* model);
} // namespace toco } // namespace toco
#endif // TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_ #endif // TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_