Fix the deletion of ops and their arrays in toco:
1. Fix DeleteOpAndArraysIfUnused, it was not deleting the op's output arrays. 2. Rename it to just DeleteOpAndArrays (was not clear what 'unused' referred to). 3. Use that throughout toco graph transformations. 4. Some extra fixes, and removed DeleteArrayIfUsedOnce. PiperOrigin-RevId: 253285001
This commit is contained in:
parent
119de52436
commit
0e11e67490
@ -94,10 +94,8 @@ namespace toco {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Replace the operator in the graph.
|
// Replace the operator in the graph.
|
||||||
const auto reshape_it = model->operators.emplace(expand_it, reshape_op);
|
model->operators.emplace(expand_it, reshape_op);
|
||||||
expand_it = reshape_it + 1;
|
DeleteOpAndArrays(model, expand_op);
|
||||||
CHECK_EQ(expand_it->get(), expand_op);
|
|
||||||
model->operators.erase(expand_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -92,11 +92,8 @@ namespace toco {
|
|||||||
depthwiseconv_op->stride_width = conv_op->stride_width;
|
depthwiseconv_op->stride_width = conv_op->stride_width;
|
||||||
depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
|
depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
|
||||||
// Replace the operator in the graph.
|
// Replace the operator in the graph.
|
||||||
const auto depthwiseconv_it =
|
|
||||||
model->operators.emplace(conv_it, depthwiseconv_op);
|
model->operators.emplace(conv_it, depthwiseconv_op);
|
||||||
conv_it = depthwiseconv_it + 1;
|
DeleteOpAndArrays(model, conv_op);
|
||||||
CHECK_EQ(conv_it->get(), conv_op);
|
|
||||||
model->operators.erase(conv_it);
|
|
||||||
// Shuffle the weights.
|
// Shuffle the weights.
|
||||||
const auto& weights_shape = weights_array.shape();
|
const auto& weights_shape = weights_array.shape();
|
||||||
auto& weights_buffer =
|
auto& weights_buffer =
|
||||||
|
@ -132,20 +132,16 @@ TransposeOperator* CreateTransposeFromReorderAxes(
|
|||||||
// order of the elements does not change.
|
// order of the elements does not change.
|
||||||
auto* reshape_op =
|
auto* reshape_op =
|
||||||
CreateReshapeFromReorderAxes(model, reorder_op, input_shape);
|
CreateReshapeFromReorderAxes(model, reorder_op, input_shape);
|
||||||
const auto reshape_it = model->operators.emplace(reorder_it, reshape_op);
|
model->operators.emplace(reorder_it, reshape_op);
|
||||||
reorder_it = reshape_it + 1;
|
|
||||||
} else {
|
} else {
|
||||||
// Add Transpose operator into the graph.
|
// Add Transpose operator into the graph.
|
||||||
auto* transpose_op = CreateTransposeFromReorderAxes(
|
auto* transpose_op = CreateTransposeFromReorderAxes(
|
||||||
model, reorder_op, input_shape, input_axes_order, output_axes_order);
|
model, reorder_op, input_shape, input_axes_order, output_axes_order);
|
||||||
const auto transpose_it =
|
|
||||||
model->operators.emplace(reorder_it, transpose_op);
|
model->operators.emplace(reorder_it, transpose_op);
|
||||||
reorder_it = transpose_it + 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove ReorderAxes operator from the graph.
|
// Remove ReorderAxes operator from the graph.
|
||||||
CHECK_EQ(reorder_it->get(), reorder_op);
|
DeleteOpAndArrays(model, reorder_op);
|
||||||
model->operators.erase(reorder_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -77,10 +77,8 @@ namespace toco {
|
|||||||
LogName(*reshape_op));
|
LogName(*reshape_op));
|
||||||
|
|
||||||
// Replace the operator in the graph.
|
// Replace the operator in the graph.
|
||||||
const auto reshape_it = model->operators.emplace(squeeze_it, reshape_op);
|
model->operators.emplace(squeeze_it, reshape_op);
|
||||||
squeeze_it = reshape_it + 1;
|
DeleteOpAndArrays(model, squeeze_op);
|
||||||
CHECK_EQ(squeeze_it->get(), squeeze_op);
|
|
||||||
model->operators.erase(squeeze_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
||||||
#include "tensorflow/lite/toco/model.h"
|
#include "tensorflow/lite/toco/model.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/lite/toco/tooling_util.h"
|
||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
|
|
||||||
@ -44,10 +45,8 @@ namespace toco {
|
|||||||
add_op->outputs = addn_op->outputs;
|
add_op->outputs = addn_op->outputs;
|
||||||
|
|
||||||
// Replace the AddN operator in the graph.
|
// Replace the AddN operator in the graph.
|
||||||
const auto add_it = model->operators.emplace(addn_it, add_op);
|
model->operators.emplace(addn_it, add_op);
|
||||||
addn_it = add_it + 1;
|
DeleteOpAndArrays(model, addn_op);
|
||||||
CHECK_EQ(addn_it->get(), addn_op);
|
|
||||||
model->operators.erase(addn_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -77,10 +77,8 @@ namespace toco {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Replace the operator in the graph.
|
// Replace the operator in the graph.
|
||||||
const auto reshape_it = model->operators.emplace(pack_it, reshape_op);
|
model->operators.emplace(pack_it, reshape_op);
|
||||||
pack_it = reshape_it + 1;
|
DeleteOpAndArrays(model, pack_op);
|
||||||
CHECK_EQ(pack_it->get(), pack_op);
|
|
||||||
model->operators.erase(pack_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -86,10 +86,8 @@ namespace toco {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Replace the operator in the graph.
|
// Replace the operator in the graph.
|
||||||
const auto concat_it = model->operators.emplace(tile_it, concat_op);
|
model->operators.emplace(tile_it, concat_op);
|
||||||
tile_it = concat_it + 1;
|
DeleteOpAndArrays(model, tile_op);
|
||||||
CHECK_EQ(tile_it->get(), tile_op);
|
|
||||||
model->operators.erase(tile_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -107,10 +107,8 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Replace the operator in the graph.
|
// Replace the operator in the graph.
|
||||||
const auto reshape_it = model->operators.emplace(transpose_it, reshape_op);
|
model->operators.emplace(transpose_it, reshape_op);
|
||||||
transpose_it = reshape_it + 1;
|
DeleteOpAndArrays(model, transpose_op);
|
||||||
CHECK_EQ(transpose_it->get(), transpose_op);
|
|
||||||
model->operators.erase(transpose_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -98,9 +98,8 @@ namespace toco {
|
|||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unhandled activation function type";
|
LOG(FATAL) << "Unhandled activation function type";
|
||||||
}
|
}
|
||||||
model->EraseArray(ac_op->inputs[0]);
|
|
||||||
op->outputs[0] = ac_op->outputs[0];
|
op->outputs[0] = ac_op->outputs[0];
|
||||||
model->operators.erase(ac_it);
|
DeleteOpAndArrays(model, ac_op);
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -296,13 +296,7 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
|
|||||||
model->EraseArray(binary_op->outputs[0]);
|
model->EraseArray(binary_op->outputs[0]);
|
||||||
|
|
||||||
following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
|
following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
|
||||||
const auto& old_constant_param_name =
|
DeleteOpAndArrays(model, binary_op);
|
||||||
binary_op->inputs[index_of_constant_input];
|
|
||||||
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
|
|
||||||
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
|
|
||||||
model->EraseArray(old_constant_param_name);
|
|
||||||
}
|
|
||||||
model->operators.erase(binary_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -360,13 +360,7 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
|
|||||||
preceding_op->outputs[0] = binary_op->outputs[0];
|
preceding_op->outputs[0] = binary_op->outputs[0];
|
||||||
preceding_op->fused_activation_function =
|
preceding_op->fused_activation_function =
|
||||||
binary_op->fused_activation_function;
|
binary_op->fused_activation_function;
|
||||||
const auto& old_constant_param_name =
|
DeleteOpAndArrays(model, binary_op);
|
||||||
binary_op->inputs[index_of_constant_input];
|
|
||||||
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
|
|
||||||
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
|
|
||||||
model->EraseArray(old_constant_param_name);
|
|
||||||
}
|
|
||||||
model->operators.erase(binary_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -336,11 +336,13 @@ void RewireBidirectionalSequenceSequenceOpsConnections(
|
|||||||
string cur_fw_input = input_array_name;
|
string cur_fw_input = input_array_name;
|
||||||
string cur_bw_input = input_array_name;
|
string cur_bw_input = input_array_name;
|
||||||
for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) {
|
for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) {
|
||||||
DeleteArrayIfUsedOnce(bidirectional_sequence_ops[i]->inputs[0], model);
|
DeleteArrayIfUnusedOutsideOfOp(bidirectional_sequence_ops[i]->inputs[0],
|
||||||
|
bidirectional_sequence_ops[i], model);
|
||||||
bidirectional_sequence_ops[i]->inputs[0] = cur_fw_input;
|
bidirectional_sequence_ops[i]->inputs[0] = cur_fw_input;
|
||||||
if (i != 0) {
|
if (i != 0) {
|
||||||
DeleteArrayIfUsedOnce(
|
DeleteArrayIfUnusedOutsideOfOp(
|
||||||
bidirectional_sequence_ops[i]->inputs[aux_input_index], model);
|
bidirectional_sequence_ops[i]->inputs[aux_input_index],
|
||||||
|
bidirectional_sequence_ops[i], model);
|
||||||
bidirectional_sequence_ops[i]->inputs[aux_input_index] = cur_bw_input;
|
bidirectional_sequence_ops[i]->inputs[aux_input_index] = cur_bw_input;
|
||||||
}
|
}
|
||||||
cur_fw_input = bidirectional_sequence_ops[i]->outputs[0];
|
cur_fw_input = bidirectional_sequence_ops[i]->outputs[0];
|
||||||
@ -383,24 +385,16 @@ void RewireFinalUnpackOutputs(const UnpackOperator& original_unpack_operator,
|
|||||||
(*final_unpack_operator)->outputs[i] = concat_output;
|
(*final_unpack_operator)->outputs[i] = concat_output;
|
||||||
}
|
}
|
||||||
// Remove the concat op.
|
// Remove the concat op.
|
||||||
model->operators.erase(FindOperator(model, *unpack_following_op));
|
DeleteOpAndArrays(model, unpack_following_op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveUnpackOperator(const Operator& unpack_op, Model* model) {
|
|
||||||
for (const string& output_array_name : unpack_op.outputs) {
|
|
||||||
DeleteArrayIfUnused(output_array_name, model);
|
|
||||||
}
|
|
||||||
model->operators.erase(FindOperator(model, unpack_op));
|
|
||||||
}
|
|
||||||
|
|
||||||
void RemoveUnidirectionalSequenceOps(std::stack<Operator*> uni_sequence_ops,
|
void RemoveUnidirectionalSequenceOps(std::stack<Operator*> uni_sequence_ops,
|
||||||
Model* model) {
|
Model* model) {
|
||||||
while (!uni_sequence_ops.empty()) {
|
while (!uni_sequence_ops.empty()) {
|
||||||
Operator* uni_sequence_op = uni_sequence_ops.top();
|
Operator* uni_sequence_op = uni_sequence_ops.top();
|
||||||
DeleteArrayIfUnused(uni_sequence_op->outputs[0], model);
|
DeleteOpAndArrays(model, uni_sequence_op);
|
||||||
model->operators.erase(FindOperator(model, *uni_sequence_op));
|
|
||||||
uni_sequence_ops.pop();
|
uni_sequence_ops.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -471,14 +465,9 @@ template <typename T>
|
|||||||
// Delete unused ops.
|
// Delete unused ops.
|
||||||
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_ops, model);
|
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_ops, model);
|
||||||
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_ops, model);
|
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_ops, model);
|
||||||
|
DeleteOpAndArrays(model, final_concat_op);
|
||||||
DeleteArrayIfUnused(final_concat_op->inputs[0], model);
|
|
||||||
DeleteArrayIfUnused(final_concat_op->inputs[1], model);
|
|
||||||
model->operators.erase(FindOperator(model, *final_concat_op));
|
|
||||||
|
|
||||||
// Only keep the fw lstm's input.
|
// Only keep the fw lstm's input.
|
||||||
DeleteArrayIfUnused(first_bw_sequence_input->outputs[0], model);
|
DeleteOpAndArrays(model, first_bw_sequence_input);
|
||||||
model->operators.erase(FindOperator(model, *first_bw_sequence_input));
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
@ -552,13 +541,12 @@ template <typename T>
|
|||||||
model->operators.emplace(op_it, unpack_operator);
|
model->operators.emplace(op_it, unpack_operator);
|
||||||
|
|
||||||
// Delete unused ops.
|
// Delete unused ops.
|
||||||
RemoveUnpackOperator(*fw_lstm_output, model);
|
DeleteOpAndArrays(model, fw_lstm_output);
|
||||||
RemoveUnpackOperator(*bw_lstm_output, model);
|
DeleteOpAndArrays(model, bw_lstm_output);
|
||||||
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_lstm_ops, model);
|
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_lstm_ops, model);
|
||||||
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_lstm_ops, model);
|
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_lstm_ops, model);
|
||||||
// Only keep the fw lstm's pack input.
|
// Only keep the fw lstm's pack input.
|
||||||
DeleteArrayIfUnused(first_bw_lstm_input->outputs[0], model);
|
DeleteOpAndArrays(model, first_bw_lstm_input);
|
||||||
model->operators.erase(FindOperator(model, *first_bw_lstm_input));
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
@ -628,13 +616,12 @@ template <typename T>
|
|||||||
model->operators.emplace(op_it, unpack_operator);
|
model->operators.emplace(op_it, unpack_operator);
|
||||||
|
|
||||||
// Delete unused ops.
|
// Delete unused ops.
|
||||||
RemoveUnpackOperator(*fw_rnn_output, model);
|
DeleteOpAndArrays(model, fw_rnn_output);
|
||||||
RemoveUnpackOperator(*bw_rnn_output, model);
|
DeleteOpAndArrays(model, bw_rnn_output);
|
||||||
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_rnn_ops, model);
|
RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_rnn_ops, model);
|
||||||
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_rnn_ops, model);
|
RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_rnn_ops, model);
|
||||||
// Only keep the fw rnn's pack input.
|
// Only keep the fw rnn's pack input.
|
||||||
DeleteArrayIfUnused(first_bw_rnn_input->outputs[0], model);
|
DeleteOpAndArrays(model, first_bw_rnn_input);
|
||||||
model->operators.erase(FindOperator(model, *first_bw_rnn_input));
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -141,30 +141,12 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
|
|||||||
|
|
||||||
// 3. DELETE LEFTOVER OPERATORS
|
// 3. DELETE LEFTOVER OPERATORS
|
||||||
// ***************************************************************************
|
// ***************************************************************************
|
||||||
// Order is important. Delete the output array first, then the op, then it's
|
DeleteOpAndArrays(model, bts_op);
|
||||||
// redundant inputs.
|
DeleteOpAndArrays(model, stb_op);
|
||||||
// BatchToSpace Op
|
|
||||||
DeleteArrayIfUnused(bts_op->outputs[0], model);
|
|
||||||
std::vector<string> bts_op_inputs = bts_op->inputs;
|
|
||||||
model->operators.erase(FindOp(*model, bts_op));
|
|
||||||
DeleteArrayIfUnused(bts_op_inputs[1], model);
|
|
||||||
DeleteArrayIfUnused(bts_op_inputs[2], model);
|
|
||||||
|
|
||||||
// Pad Op if present
|
|
||||||
if (has_pad_op) {
|
if (has_pad_op) {
|
||||||
DeleteArrayIfUnused(pad_op->outputs[0], model);
|
DeleteOpAndArrays(model, pad_op);
|
||||||
std::vector<string> pad_op_inputs = pad_op->inputs;
|
|
||||||
model->operators.erase(FindOp(*model, pad_op));
|
|
||||||
DeleteArrayIfUnused(pad_op_inputs[1], model);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SpaceToBatch Op
|
|
||||||
DeleteArrayIfUnused(stb_op->outputs[0], model);
|
|
||||||
std::vector<string> stb_op_inputs = stb_op->inputs;
|
|
||||||
model->operators.erase(FindOp(*model, stb_op));
|
|
||||||
DeleteArrayIfUnused(stb_op_inputs[1], model);
|
|
||||||
DeleteArrayIfUnused(stb_op_inputs[2], model);
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,13 +136,13 @@ namespace toco {
|
|||||||
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
|
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
|
||||||
|
|
||||||
// Erase the subgraph that is now replaced by L2Normalization
|
// Erase the subgraph that is now replaced by L2Normalization
|
||||||
model->operators.erase(FindOp(*model, square_op));
|
DeleteOpAndArrays(model, square_op);
|
||||||
DeleteOpAndArraysIfUnused(model, sum_op);
|
DeleteOpAndArrays(model, sum_op);
|
||||||
if (add_op) {
|
if (add_op) {
|
||||||
DeleteOpAndArraysIfUnused(model, add_op);
|
DeleteOpAndArrays(model, add_op);
|
||||||
}
|
}
|
||||||
DeleteOpAndArraysIfUnused(model, sqrt_or_rsqrt_op);
|
DeleteOpAndArrays(model, sqrt_or_rsqrt_op);
|
||||||
DeleteOpAndArraysIfUnused(model, div_or_mul_op);
|
DeleteOpAndArrays(model, div_or_mul_op);
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -86,14 +86,9 @@ namespace toco {
|
|||||||
|
|
||||||
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
|
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
|
||||||
|
|
||||||
// Erase intermediate arrays, keeping input to square op.
|
DeleteOpAndArrays(model, square_op);
|
||||||
model->EraseArray(avpool_op->inputs[0]);
|
DeleteOpAndArrays(model, avpool_op);
|
||||||
model->EraseArray(sqrt_op->inputs[0]);
|
DeleteOpAndArrays(model, sqrt_op);
|
||||||
|
|
||||||
// Erase three operators being replaced.
|
|
||||||
model->operators.erase(FindOp(*model, square_op));
|
|
||||||
model->operators.erase(FindOp(*model, avpool_op));
|
|
||||||
model->operators.erase(FindOp(*model, sqrt_op));
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -290,36 +290,19 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
|
|||||||
concat_temp_array_name, activ_temp_array_name,
|
concat_temp_array_name, activ_temp_array_name,
|
||||||
LogName(*lstm_cell_op));
|
LogName(*lstm_cell_op));
|
||||||
|
|
||||||
// Delete arrays and operators replaced by the LSTM cell operator. Order is
|
DeleteOpAndArrays(model, final_output_mul);
|
||||||
// important - DeleteArrayIfUnused() only succeeds if dependent operators
|
DeleteOpAndArrays(model, state_output_tanh);
|
||||||
// have been removed first. Start at the output and work towards the input.
|
DeleteOpAndArrays(model, fc_output_sig);
|
||||||
model->operators.erase(FindOperator(model, *final_output_mul));
|
DeleteOpAndArrays(model, state_combine_add);
|
||||||
DeleteArrayIfUnused(state_output_tanh->outputs[0], model);
|
DeleteOpAndArrays(model, state_forget_mul);
|
||||||
DeleteArrayIfUnused(fc_output_sig->outputs[0], model);
|
DeleteOpAndArrays(model, state_remember_mul);
|
||||||
model->operators.erase(FindOperator(model, *state_output_tanh));
|
DeleteOpAndArrays(model, state_forget_sig);
|
||||||
model->operators.erase(FindOperator(model, *fc_output_sig));
|
DeleteOpAndArrays(model, state_info_tanh);
|
||||||
model->operators.erase(FindOperator(model, *state_combine_add));
|
DeleteOpAndArrays(model, state_remember_sig);
|
||||||
DeleteArrayIfUnused(state_forget_mul->outputs[0], model);
|
DeleteOpAndArrays(model, fc_output_split);
|
||||||
DeleteArrayIfUnused(state_remember_mul->outputs[0], model);
|
DeleteOpAndArrays(model, fully_connected);
|
||||||
model->operators.erase(FindOperator(model, *state_forget_mul));
|
DeleteOpAndArrays(model, concat_inputs);
|
||||||
model->operators.erase(FindOperator(model, *state_remember_mul));
|
|
||||||
DeleteArrayIfUnused(state_forget_sig->outputs[0], model);
|
|
||||||
DeleteArrayIfUnused(state_info_tanh->outputs[0], model);
|
|
||||||
DeleteArrayIfUnused(state_remember_sig->outputs[0], model);
|
|
||||||
model->operators.erase(FindOperator(model, *state_forget_sig));
|
|
||||||
model->operators.erase(FindOperator(model, *state_info_tanh));
|
|
||||||
model->operators.erase(FindOperator(model, *state_remember_sig));
|
|
||||||
DeleteArrayIfUnused(fc_output_split->outputs[0], model);
|
|
||||||
DeleteArrayIfUnused(fc_output_split->outputs[1], model);
|
|
||||||
DeleteArrayIfUnused(fc_output_split->outputs[2], model);
|
|
||||||
DeleteArrayIfUnused(fc_output_split->outputs[3], model);
|
|
||||||
string dims_array = fc_output_split->inputs[0];
|
|
||||||
model->operators.erase(FindOperator(model, *fc_output_split));
|
|
||||||
DeleteArrayIfUnused(dims_array, model);
|
|
||||||
DeleteArrayIfUnused(fully_connected->outputs[0], model);
|
|
||||||
model->operators.erase(FindOperator(model, *fully_connected));
|
|
||||||
DeleteArrayIfUnused(concat_inputs->outputs[0], model);
|
|
||||||
model->operators.erase(FindOperator(model, *concat_inputs));
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -169,23 +169,7 @@ namespace toco {
|
|||||||
model->operators.emplace(op_it, std::move(lstm_cell_op));
|
model->operators.emplace(op_it, std::move(lstm_cell_op));
|
||||||
AddMessageF("Creating compact LstmCell replacing previous lstm cell");
|
AddMessageF("Creating compact LstmCell replacing previous lstm cell");
|
||||||
|
|
||||||
// Delete arrays and operators replaced by the LSTM cell operator. Order is
|
DeleteOpAndArrays(model, src_op);
|
||||||
// important - DeleteArrayIfUnused() only succeeds if dependent operators
|
|
||||||
// have been removed first. Start at the output and work towards the input.
|
|
||||||
// Erase curr lstm op being replaced.
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kInputToInputWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kInputToForgetWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kInputToCellWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kInputToOutputWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kRecurrentToInputWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kRecurrentToForgetWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kRecurrentToCellWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kRecurrentToOutputWeightsTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kInputGateBiasTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kForgetGateBiasTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kCellGateBiasTensor], model);
|
|
||||||
DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model);
|
|
||||||
model->operators.erase(FindOp(*model, src_op));
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -163,13 +163,7 @@ namespace toco {
|
|||||||
model->operators.emplace(op_it, std::move(lstm_cell_op));
|
model->operators.emplace(op_it, std::move(lstm_cell_op));
|
||||||
AddMessageF("Creating extended LstmCell replacing previous lstm cell");
|
AddMessageF("Creating extended LstmCell replacing previous lstm cell");
|
||||||
|
|
||||||
// Delete arrays and operators replaced by the LSTM cell operator. Order is
|
DeleteOpAndArrays(model, curr_op);
|
||||||
// important - DeleteArrayIfUnused() only succeeds if dependent operators
|
|
||||||
// have been removed first. Start at the output and work towards the input.
|
|
||||||
// Erase curr lstm op being replaced.
|
|
||||||
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
|
|
||||||
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
|
|
||||||
model->operators.erase(FindOp(*model, curr_op));
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -122,14 +122,10 @@ namespace toco {
|
|||||||
model->operators.emplace(add_op_it, prelu_op);
|
model->operators.emplace(add_op_it, prelu_op);
|
||||||
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));
|
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));
|
||||||
|
|
||||||
DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model);
|
DeleteArrayIfUnusedOutsideOfOp(neg_alpha_tensor_name, neg_neg_alpha_op,
|
||||||
DeleteArrayIfUsedOnce(add_op->inputs[0], model);
|
model);
|
||||||
DeleteArrayIfUsedOnce(add_op->inputs[1], model);
|
DeleteArrayIfUnusedOutsideOfOp(mul_op->inputs[1], mul_op, model);
|
||||||
DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
|
DeleteOpAndArrays(model, add_op);
|
||||||
// Remove the existing Add op that outputs the final result. If the other
|
|
||||||
// intermediate tensors aren't used by other ops, those will be removed by
|
|
||||||
// other graph transformation rules.
|
|
||||||
model->operators.erase(FindOp(*model, add_op));
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -99,19 +99,13 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
|
|||||||
|
|
||||||
// Create and emplace Relu1 node.
|
// Create and emplace Relu1 node.
|
||||||
auto* relu1_op = new Relu1Operator;
|
auto* relu1_op = new Relu1Operator;
|
||||||
|
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
|
||||||
relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]};
|
relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]};
|
||||||
relu1_op->outputs = op_1->outputs;
|
relu1_op->outputs = op_1->outputs;
|
||||||
model->operators.emplace(op_it, relu1_op);
|
model->operators.emplace(op_it, relu1_op);
|
||||||
|
|
||||||
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
|
DeleteOpAndArrays(model, op_0);
|
||||||
|
DeleteOpAndArrays(model, op_1);
|
||||||
// Erase op scalar inputs & operators. Note that we preserve the non-scalar
|
|
||||||
// input to the first op as that's been redirected to the relu1_op.
|
|
||||||
DeleteArrayIfUsedOnce(op_0->inputs[op_0_scalar_input_index], model);
|
|
||||||
DeleteArrayIfUsedOnce(op_1->inputs[0], model);
|
|
||||||
DeleteArrayIfUsedOnce(op_1->inputs[1], model);
|
|
||||||
model->operators.erase(FindOperator(model, op_0));
|
|
||||||
model->operators.erase(FindOperator(model, op_1));
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -546,10 +546,7 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
|
|||||||
// Check if the output of that Dequantize op was not used by any
|
// Check if the output of that Dequantize op was not used by any
|
||||||
// other operator. We will then erase that Dequantize op.
|
// other operator. We will then erase that Dequantize op.
|
||||||
if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
|
if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
|
||||||
if (IsDiscardableArray(*model, dequantize_op->outputs[0])) {
|
if (!IsDiscardableArray(*model, dequantize_op->outputs[0])) {
|
||||||
// Usual case: we can just discard the dequantize output.
|
|
||||||
model->EraseArray(dequantize_op->outputs[0]);
|
|
||||||
} else {
|
|
||||||
// The dequantize output is not discardable. Special care needed.
|
// The dequantize output is not discardable. Special care needed.
|
||||||
// If any of the model's output_arrays was pointing to the
|
// If any of the model's output_arrays was pointing to the
|
||||||
// Dequantize op's output, let it point to the Dequantize op's
|
// Dequantize op's output, let it point to the Dequantize op's
|
||||||
@ -583,13 +580,12 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
|
|||||||
model->flags.output_arrays(i),
|
model->flags.output_arrays(i),
|
||||||
dequantize_op->inputs[0]);
|
dequantize_op->inputs[0]);
|
||||||
model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
|
model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
|
||||||
model->EraseArray(dequantize_op->outputs[0]);
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
model->operators.erase(dequantize_it);
|
DeleteOpAndArrays(model, dequantize_op);
|
||||||
}
|
}
|
||||||
changed = true;
|
changed = true;
|
||||||
} else {
|
} else {
|
||||||
|
@ -54,8 +54,7 @@ namespace toco {
|
|||||||
|
|
||||||
// Remove the node and its output array.
|
// Remove the node and its output array.
|
||||||
AddMessageF("Removed final %s", LogName(*dequantize_op));
|
AddMessageF("Removed final %s", LogName(*dequantize_op));
|
||||||
model->EraseArray(output);
|
DeleteOpAndArrays(model, dequantize_op);
|
||||||
model->operators.erase(dequantize_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -59,13 +59,10 @@ namespace toco {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Drop trivial inputs.
|
// Drop trivial inputs.
|
||||||
for (const string& input : trivial_inputs) {
|
|
||||||
if (IsDiscardableArray(*model, input) &&
|
|
||||||
CountOpsWithInput(*model, input) == 1) {
|
|
||||||
model->EraseArray(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
concat_op->inputs = nontrivial_inputs;
|
concat_op->inputs = nontrivial_inputs;
|
||||||
|
for (const string& input : trivial_inputs) {
|
||||||
|
DeleteArrayIfUnusedOutsideOfOp(input, concat_op, model);
|
||||||
|
}
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -89,17 +89,6 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
|
|||||||
const string main_input_name = passthru_op->inputs[main_input_array_index];
|
const string main_input_name = passthru_op->inputs[main_input_array_index];
|
||||||
const string output_name = passthru_op->outputs[0];
|
const string output_name = passthru_op->outputs[0];
|
||||||
|
|
||||||
// Build the list of all input and output arrays of the passthrough node
|
|
||||||
// that we are considering removing. Any of these arrays is a candidate
|
|
||||||
// for being removed as well, if nothing else references it. Doing that
|
|
||||||
// arrays-removal together with the passthrough-node-removal proved too
|
|
||||||
// error-prone.
|
|
||||||
std::vector<string> removal_candidates;
|
|
||||||
for (const string& input : passthru_op->inputs) {
|
|
||||||
removal_candidates.push_back(input);
|
|
||||||
}
|
|
||||||
removal_candidates.push_back(output_name);
|
|
||||||
|
|
||||||
if (IsDiscardableArray(*model, output_name)) {
|
if (IsDiscardableArray(*model, output_name)) {
|
||||||
transformation->AddMessageF(
|
transformation->AddMessageF(
|
||||||
"Removing %s, keeping its non-constant input array %s and removing %s",
|
"Removing %s, keeping its non-constant input array %s and removing %s",
|
||||||
@ -135,38 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove the pass-through node.
|
// Remove the pass-through node.
|
||||||
CHECK_EQ(passthru_it->get(), passthru_op);
|
DeleteOpAndArrays(model, passthru_op);
|
||||||
model->operators.erase(passthru_it);
|
|
||||||
|
|
||||||
// Remove any array that is no longer used.
|
|
||||||
for (const string& removal_candidate : removal_candidates) {
|
|
||||||
bool is_referenced = false;
|
|
||||||
for (const auto& array : model->flags.input_arrays()) {
|
|
||||||
if (array.name() == removal_candidate) {
|
|
||||||
is_referenced = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const auto& array_name : model->flags.output_arrays()) {
|
|
||||||
if (array_name == removal_candidate) {
|
|
||||||
is_referenced = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const auto& op : model->operators) {
|
|
||||||
for (const string& input : op->inputs) {
|
|
||||||
if (input == removal_candidate) {
|
|
||||||
is_referenced = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const string& output : op->outputs) {
|
|
||||||
if (output == removal_candidate) {
|
|
||||||
is_referenced = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!is_referenced) {
|
|
||||||
model->EraseArray(removal_candidate);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -86,27 +86,7 @@ namespace toco {
|
|||||||
|
|
||||||
AddMessageF("Discarding %s because none of its outputs is used.",
|
AddMessageF("Discarding %s because none of its outputs is used.",
|
||||||
LogName(*op));
|
LogName(*op));
|
||||||
|
DeleteOpAndArrays(model, op);
|
||||||
// At that point we know that none of the outputs is used, so we will
|
|
||||||
// definitely remove the node and all its outputs.
|
|
||||||
|
|
||||||
// Remove any input array that not the output of another op, and only used by
|
|
||||||
// this op.
|
|
||||||
for (const auto& input : op->inputs) {
|
|
||||||
if (!GetOpWithOutput(*model, input)) {
|
|
||||||
DeleteArrayIfUsedOnce(input, model);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the node and its now-unused output arrays.
|
|
||||||
for (const auto& output : op->outputs) {
|
|
||||||
// If the output array is the model's input array, don't remove that.
|
|
||||||
// That's the case when cropping a model at a given --input_array.
|
|
||||||
if (IsDiscardableArray(*model, output)) {
|
|
||||||
model->EraseArray(output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -122,7 +122,7 @@ bool IsMoveOperator(OperatorType optype) {
|
|||||||
|
|
||||||
element_op->inputs[0] = input_name;
|
element_op->inputs[0] = input_name;
|
||||||
element_op->outputs[0] = new_intermediate_name;
|
element_op->outputs[0] = new_intermediate_name;
|
||||||
model->EraseArray(intermediate_name);
|
DeleteArrayIfUnused(intermediate_name, model);
|
||||||
move_op->inputs[0] = new_intermediate_name;
|
move_op->inputs[0] = new_intermediate_name;
|
||||||
move_op->outputs[0] = output_name;
|
move_op->outputs[0] = output_name;
|
||||||
} else {
|
} else {
|
||||||
|
@ -190,7 +190,7 @@ std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
|
|||||||
transpose_op->outputs[0] = new_intermediate_name;
|
transpose_op->outputs[0] = new_intermediate_name;
|
||||||
reshape_op->inputs[0] = new_intermediate_name;
|
reshape_op->inputs[0] = new_intermediate_name;
|
||||||
reshape_op->outputs[0] = output_name;
|
reshape_op->outputs[0] = output_name;
|
||||||
model->EraseArray(intermediate_name);
|
DeleteArrayIfUnused(intermediate_name, model);
|
||||||
} else {
|
} else {
|
||||||
// The intermediate array is now the output array.
|
// The intermediate array is now the output array.
|
||||||
for (int i = 0; i < model->operators.size(); i++) {
|
for (int i = 0; i < model->operators.size(); i++) {
|
||||||
|
@ -136,14 +136,7 @@ namespace toco {
|
|||||||
offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
|
offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the old param arrays
|
DeleteOpAndArrays(model, bn_op);
|
||||||
DeleteArrayIfUsedOnce(bn_op->inputs[1], model);
|
|
||||||
DeleteArrayIfUsedOnce(bn_op->inputs[2], model);
|
|
||||||
DeleteArrayIfUsedOnce(bn_op->inputs[3], model);
|
|
||||||
|
|
||||||
// Remove the old operator
|
|
||||||
DCHECK_EQ(bn_it->get(), bn_op);
|
|
||||||
model->operators.erase(bn_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -247,16 +247,7 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
|
|||||||
// Do the actual constants propagation
|
// Do the actual constants propagation
|
||||||
EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
|
EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
|
||||||
|
|
||||||
// Remove the binary operator and its inputs
|
DeleteOpAndArrays(model, binary_op);
|
||||||
if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
|
|
||||||
model->EraseArray(binary_op->inputs[0]);
|
|
||||||
}
|
|
||||||
if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
|
|
||||||
model->EraseArray(binary_op->inputs[1]);
|
|
||||||
}
|
|
||||||
AddMessageF("Resolved constant %s to the equivalent constant array",
|
|
||||||
LogName(*binary_op));
|
|
||||||
model->operators.erase(binary_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -206,16 +206,7 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
|
|||||||
LOG(FATAL) << "ArrayDataType not supported";
|
LOG(FATAL) << "ArrayDataType not supported";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove all the resolved arrays.
|
DeleteOpAndArrays(model, concat_op);
|
||||||
for (const string& input_name : concat_op->inputs) {
|
|
||||||
// Check to prevent removal of shared tensors.
|
|
||||||
if (CountOpsWithInput(*model, input_name) == 1) {
|
|
||||||
model->EraseArray(input_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove concatenate operator.
|
|
||||||
model->operators.erase(concat_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -132,13 +132,7 @@ void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
|
|||||||
tflite::FakeQuantizeArray(scale, nudged_min, nudged_max,
|
tflite::FakeQuantizeArray(scale, nudged_min, nudged_max,
|
||||||
input_buffer.data.data(), output_buffer.data.data(),
|
input_buffer.data.data(), output_buffer.data.data(),
|
||||||
size);
|
size);
|
||||||
|
DeleteOpAndArrays(model, fakequant_op);
|
||||||
if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
|
|
||||||
CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
|
|
||||||
model->EraseArray(fakequant_op->inputs[0]);
|
|
||||||
}
|
|
||||||
model->operators.erase(fakequant_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -109,19 +109,7 @@ bool ComputeFillArray(Model* model, FillOperator* op) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input arrays if no longer used
|
DeleteOpAndArrays(model, op);
|
||||||
if (IsDiscardableArray(*model, op->inputs[0]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[0]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[0]);
|
|
||||||
}
|
|
||||||
if (IsDiscardableArray(*model, op->inputs[1]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[1]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase the operator
|
|
||||||
model->operators.erase(fill_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -142,12 +142,7 @@ inline void Gather(const Array& input_array, const Array& coords_array,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input arrays if no longer used after we remove the op.
|
DeleteOpAndArrays(model, op);
|
||||||
DeleteArrayIfUsedOnce(op->inputs[0], model);
|
|
||||||
DeleteArrayIfUsedOnce(op->inputs[1], model);
|
|
||||||
|
|
||||||
// Erase the operator.
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -110,13 +110,7 @@ void Pack(Model* model, PackOperator const& op) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input arrays if no longer used
|
DeleteOpAndArrays(model, op);
|
||||||
for (const auto& input : op->inputs) {
|
|
||||||
toco::DeleteArrayIfUsedOnce(input, model);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase the operator
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -107,12 +107,7 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input arrays if no longer used
|
DeleteOpAndArrays(model, op);
|
||||||
toco::DeleteArrayIfUsedOnce(op->inputs[0], model);
|
|
||||||
|
|
||||||
// Erase the operator
|
|
||||||
model->operators.erase(it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -105,23 +105,7 @@ void FillRangeOutput(const Array& start_array, const Array& limit_array,
|
|||||||
delta_array, &output_array);
|
delta_array, &output_array);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the input array if no longer used
|
DeleteOpAndArrays(model, op);
|
||||||
if (IsDiscardableArray(*model, op->inputs[0]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[0]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[0]);
|
|
||||||
}
|
|
||||||
if (IsDiscardableArray(*model, op->inputs[1]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[1]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[1]);
|
|
||||||
}
|
|
||||||
if (IsDiscardableArray(*model, op->inputs[2]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[2]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[2]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the operator
|
|
||||||
model->operators.erase(it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -108,16 +108,7 @@ namespace toco {
|
|||||||
|
|
||||||
CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
|
CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
|
||||||
|
|
||||||
// Erase input arrays if no longer used.
|
DeleteOpAndArrays(model, op);
|
||||||
for (const auto& input : op->inputs) {
|
|
||||||
if (IsDiscardableArray(*model, input) &&
|
|
||||||
CountOpsWithInput(*model, input) == 1) {
|
|
||||||
model->EraseArray(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase the operator.
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -61,13 +61,7 @@ namespace toco {
|
|||||||
output_array.mutable_shape()->ReplaceDims(
|
output_array.mutable_shape()->ReplaceDims(
|
||||||
{static_cast<int>(output_buffer.data.size())});
|
{static_cast<int>(output_buffer.data.size())});
|
||||||
|
|
||||||
// Delete the input array if no longer used
|
DeleteOpAndArrays(model, op);
|
||||||
if (IsDiscardableArray(*model, op->inputs[0]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[0]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -158,15 +158,7 @@ bool Slice(SliceOperator const& op, Array const& input_array,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input array if no longer used.
|
DeleteOpAndArrays(model, op);
|
||||||
if (IsDiscardableArray(*model, op->inputs[0]) &&
|
|
||||||
CountOpsWithInput(*model, op->inputs[0]) == 1) {
|
|
||||||
model->EraseArray(op->inputs[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase the operator
|
|
||||||
model->operators.erase(it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -163,8 +163,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeleteOpAndArraysIfUnused(model, it->get());
|
DeleteOpAndArrays(model, op);
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -160,12 +160,7 @@ inline void Tile(const Array& input_array, const Array& multiples_array,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input arrays if no longer used after we remove the op.
|
DeleteOpAndArrays(model, op);
|
||||||
DeleteArrayIfUsedOnce(op->inputs[0], model);
|
|
||||||
DeleteArrayIfUsedOnce(op->inputs[1], model);
|
|
||||||
|
|
||||||
// Erase the operator.
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -171,16 +171,7 @@ void Transpose(Model* model, const Array& input_array,
|
|||||||
|
|
||||||
AddMessageF("Resolving constant transpose of %s", LogName(*op));
|
AddMessageF("Resolving constant transpose of %s", LogName(*op));
|
||||||
|
|
||||||
// Erase input arrays if no longer used.
|
DeleteOpAndArrays(model, op);
|
||||||
for (const auto& input : op->inputs) {
|
|
||||||
if (IsDiscardableArray(*model, input) &&
|
|
||||||
CountOpsWithInput(*model, input) == 1) {
|
|
||||||
model->EraseArray(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase the operator.
|
|
||||||
model->operators.erase(it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -347,14 +347,8 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
|
|||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "should not get here.";
|
LOG(FATAL) << "should not get here.";
|
||||||
}
|
}
|
||||||
for (const auto& input : unary_op->inputs) {
|
|
||||||
if (CountOpsWithInput(*model, input) == 1) {
|
DeleteOpAndArrays(model, unary_op);
|
||||||
model->EraseArray(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AddMessageF("Resolved constant %s to the equivalent constant array",
|
|
||||||
LogName(*unary_op));
|
|
||||||
model->operators.erase(unary_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,8 @@ namespace toco {
|
|||||||
// values, anymore. Delete them unless they are used by something
|
// values, anymore. Delete them unless they are used by something
|
||||||
// else.
|
// else.
|
||||||
for (int i = 1; i <= 2; i++) {
|
for (int i = 1; i <= 2; i++) {
|
||||||
DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model);
|
DeleteArrayIfUnusedOutsideOfOp(fakequant_op->inputs[i], fakequant_op,
|
||||||
|
model);
|
||||||
}
|
}
|
||||||
fakequant_op->inputs.resize(1);
|
fakequant_op->inputs.resize(1);
|
||||||
*modified = true;
|
*modified = true;
|
||||||
|
@ -49,7 +49,7 @@ namespace toco {
|
|||||||
op->axis = {axis_data[0]};
|
op->axis = {axis_data[0]};
|
||||||
|
|
||||||
// Drop the axis array as we no longer need it.
|
// Drop the axis array as we no longer need it.
|
||||||
DeleteArrayIfUsedOnce(op->inputs[2], model);
|
DeleteArrayIfUnusedOutsideOfOp(op->inputs[2], op, model);
|
||||||
op->inputs.resize(2);
|
op->inputs.resize(2);
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
|
@ -154,13 +154,7 @@ void FillArrayWithZeros(Array* array) {
|
|||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase input arrays to the multiply if no longer used
|
DeleteOpAndArrays(model, mul_op);
|
||||||
DeleteArrayIfUsedOnce(mul_op->inputs[0], model);
|
|
||||||
DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
|
|
||||||
|
|
||||||
// Erase the multiply operator.
|
|
||||||
model->operators.erase(mul_it);
|
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
|
|||||||
|
|
||||||
AddMessageF("Reordered axes for array %s", input_array_name);
|
AddMessageF("Reordered axes for array %s", input_array_name);
|
||||||
|
|
||||||
DeleteOpAndArraysIfUnused(model, op);
|
DeleteOpAndArrays(model, op);
|
||||||
RenameArray(model, output_array_name, input_array_name);
|
RenameArray(model, output_array_name, input_array_name);
|
||||||
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
|
@ -72,16 +72,8 @@ namespace toco {
|
|||||||
concatenation_op->outputs = {tf_concat_op->outputs[0]};
|
concatenation_op->outputs = {tf_concat_op->outputs[0]};
|
||||||
auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op);
|
auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op);
|
||||||
CHECK_EQ(depth_concat_it->get(), concatenation_op);
|
CHECK_EQ(depth_concat_it->get(), concatenation_op);
|
||||||
// Update invalidated iterator
|
|
||||||
concat_it = depth_concat_it + 1;
|
|
||||||
CHECK_EQ(concat_it->get(), tf_concat_op);
|
|
||||||
|
|
||||||
// Remove the axis array if it is not used by anything else.
|
DeleteOpAndArrays(model, tf_concat_op);
|
||||||
if (CountOpsWithInput(*model, axis_name) == 1) {
|
|
||||||
model->EraseArray(axis_name);
|
|
||||||
}
|
|
||||||
// Remove the TensorFlowConcat op
|
|
||||||
model->operators.erase(concat_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -129,6 +129,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
|||||||
|
|
||||||
// Construct the new FullyConnectedOperator.
|
// Construct the new FullyConnectedOperator.
|
||||||
auto* fc_op = new FullyConnectedOperator;
|
auto* fc_op = new FullyConnectedOperator;
|
||||||
|
fc_op->inputs = {input_lhs, input_rhs};
|
||||||
fc_op->outputs = matmul_op->outputs;
|
fc_op->outputs = matmul_op->outputs;
|
||||||
|
|
||||||
// Insert the newly constructed FullyConnectedOperator.
|
// Insert the newly constructed FullyConnectedOperator.
|
||||||
@ -173,14 +174,10 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
|||||||
}
|
}
|
||||||
CHECK_EQ(previous_op->inputs.size(), 2);
|
CHECK_EQ(previous_op->inputs.size(), 2);
|
||||||
input_lhs = previous_op->inputs[0];
|
input_lhs = previous_op->inputs[0];
|
||||||
|
fc_op->inputs = {input_lhs, input_rhs};
|
||||||
// Only remove Reshape node if no other node uses its output.
|
// Only remove Reshape node if no other node uses its output.
|
||||||
if (CountOpsWithInput(*model, previous_op_output) == 1) {
|
if (CountOpsWithInput(*model, previous_op_output) == 1) {
|
||||||
const auto& previous_op_shape = previous_op->inputs[1];
|
DeleteOpAndArrays(model, previous_op);
|
||||||
if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
|
|
||||||
!GetOpWithOutput(*model, previous_op_shape)) {
|
|
||||||
model->EraseArray(previous_op_shape);
|
|
||||||
}
|
|
||||||
model->operators.erase(previous_op_it);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// We may have just invalidated matmul_it, so let's refresh it now.
|
// We may have just invalidated matmul_it, so let's refresh it now.
|
||||||
@ -197,7 +194,6 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
|||||||
LogName(*matmul_op));
|
LogName(*matmul_op));
|
||||||
}
|
}
|
||||||
|
|
||||||
fc_op->inputs = {input_lhs, input_rhs};
|
|
||||||
|
|
||||||
// erase the MatMul operator
|
// erase the MatMul operator
|
||||||
model->operators.erase(matmul_it);
|
model->operators.erase(matmul_it);
|
||||||
|
@ -56,10 +56,7 @@ namespace toco {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the node and its output array.
|
DeleteOpAndArrays(model, merge_op);
|
||||||
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
|
|
||||||
model->EraseArray(merge_op->outputs[0]);
|
|
||||||
model->operators.erase(merge_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -121,7 +121,7 @@ namespace toco {
|
|||||||
}
|
}
|
||||||
// Remove the switch node itself.
|
// Remove the switch node itself.
|
||||||
AddMessageF("Removing already-resolved %s", LogName(*switch_op));
|
AddMessageF("Removing already-resolved %s", LogName(*switch_op));
|
||||||
model->operators.erase(switch_it);
|
DeleteOpAndArrays(model, switch_op);
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -107,6 +107,8 @@ class ResolveConstantConcatenationTest : public ::testing::Test {
|
|||||||
// together with 4 arrays as its inputs.
|
// together with 4 arrays as its inputs.
|
||||||
// It receives the dimension of concatenation as input.
|
// It receives the dimension of concatenation as input.
|
||||||
void PrepareModel(Model* model, int axis) {
|
void PrepareModel(Model* model, int axis) {
|
||||||
|
const string output_name("concat_op_output");
|
||||||
|
model->flags.add_output_arrays(output_name);
|
||||||
std::vector<string> concat_input_names = {"array0", "array1", "array2",
|
std::vector<string> concat_input_names = {"array0", "array1", "array2",
|
||||||
"array3"};
|
"array3"};
|
||||||
|
|
||||||
@ -141,7 +143,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test {
|
|||||||
auto* concatenation_op = new ConcatenationOperator;
|
auto* concatenation_op = new ConcatenationOperator;
|
||||||
concatenation_op->axis = axis;
|
concatenation_op->axis = axis;
|
||||||
concatenation_op->inputs = concat_input_names;
|
concatenation_op->inputs = concat_input_names;
|
||||||
concatenation_op->outputs = {"concat_op_outputs"};
|
concatenation_op->outputs = {output_name};
|
||||||
Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
|
Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
|
||||||
out_array.data_type = ArrayDataType::kFloat;
|
out_array.data_type = ArrayDataType::kFloat;
|
||||||
Shape* out_array_shape = out_array.mutable_shape();
|
Shape* out_array_shape = out_array.mutable_shape();
|
||||||
@ -172,8 +174,8 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
|
|||||||
.ok());
|
.ok());
|
||||||
EXPECT_THAT(model.GetArrayMap().size(), 1);
|
EXPECT_THAT(model.GetArrayMap().size(), 1);
|
||||||
|
|
||||||
auto& concatenated_array = (*model.GetArrayMap().begin()).second;
|
const auto& concatenated_array = model.GetArray(model.flags.output_arrays(0));
|
||||||
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
|
EXPECT_THAT(concatenated_array.GetBuffer<toco::ArrayDataType::kFloat>().data,
|
||||||
ElementsAreArray(ArrayFloatNear(
|
ElementsAreArray(ArrayFloatNear(
|
||||||
{0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
|
{0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
|
||||||
13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
|
13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
|
||||||
|
@ -31,9 +31,11 @@ void RunResolveSum(const std::vector<float>& input,
|
|||||||
const std::vector<int>& output_shape,
|
const std::vector<int>& output_shape,
|
||||||
const std::vector<float>& expected_output) {
|
const std::vector<float>& expected_output) {
|
||||||
Model model;
|
Model model;
|
||||||
|
const std::string output_name("output");
|
||||||
|
model.flags.add_output_arrays(output_name);
|
||||||
Array& input0 = model.GetOrCreateArray("input0");
|
Array& input0 = model.GetOrCreateArray("input0");
|
||||||
Array& input1 = model.GetOrCreateArray("input1");
|
Array& input1 = model.GetOrCreateArray("input1");
|
||||||
Array& output = model.GetOrCreateArray("output");
|
Array& output = model.GetOrCreateArray(output_name);
|
||||||
|
|
||||||
*input0.mutable_shape()->mutable_dims() = input_shape;
|
*input0.mutable_shape()->mutable_dims() = input_shape;
|
||||||
input0.data_type = ArrayDataType::kFloat;
|
input0.data_type = ArrayDataType::kFloat;
|
||||||
@ -48,7 +50,7 @@ void RunResolveSum(const std::vector<float>& input,
|
|||||||
auto sum_op = absl::make_unique<TensorFlowSumOperator>();
|
auto sum_op = absl::make_unique<TensorFlowSumOperator>();
|
||||||
sum_op->keep_dims = true;
|
sum_op->keep_dims = true;
|
||||||
sum_op->inputs = {"input0", "input1"};
|
sum_op->inputs = {"input0", "input1"};
|
||||||
sum_op->outputs = {"output"};
|
sum_op->outputs = {output_name};
|
||||||
model.operators.push_back(std::move(sum_op));
|
model.operators.push_back(std::move(sum_op));
|
||||||
bool modified;
|
bool modified;
|
||||||
ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok());
|
ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok());
|
||||||
|
@ -227,18 +227,18 @@ namespace toco {
|
|||||||
|
|
||||||
// Ensure the stitch output array is dead, as we don't want whatever was in it
|
// Ensure the stitch output array is dead, as we don't want whatever was in it
|
||||||
// previously now that we've redefined it. It'll be recreated when needed.
|
// previously now that we've redefined it. It'll be recreated when needed.
|
||||||
model->EraseArray(stitch_op->outputs[0]);
|
model->EraseArray(merged_gather_op->outputs[0]);
|
||||||
model->GetOrCreateArray(merged_gather_op->outputs[0]);
|
model->GetOrCreateArray(merged_gather_op->outputs[0]);
|
||||||
|
|
||||||
// Erase all the original ops.
|
// Erase all the original ops.
|
||||||
DeleteOpAndArraysIfUnused(model, div_op);
|
DeleteOpAndArrays(model, div_op);
|
||||||
DeleteOpAndArraysIfUnused(model, mod_op);
|
DeleteOpAndArrays(model, mod_op);
|
||||||
for (auto* gather_op : gather_ops) {
|
for (auto* gather_op : gather_ops) {
|
||||||
DeleteOpAndArraysIfUnused(model, gather_op);
|
DeleteOpAndArrays(model, gather_op);
|
||||||
}
|
}
|
||||||
DeleteOpAndArraysIfUnused(model, indices_partition_op);
|
DeleteOpAndArrays(model, indices_partition_op);
|
||||||
DeleteOpAndArraysIfUnused(model, data_partition_op);
|
DeleteOpAndArrays(model, data_partition_op);
|
||||||
DeleteOpAndArraysIfUnused(model, stitch_op);
|
DeleteOpAndArrays(model, stitch_op);
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -188,9 +188,8 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
|||||||
auto* matmul_op = new TensorFlowMatMulOperator;
|
auto* matmul_op = new TensorFlowMatMulOperator;
|
||||||
matmul_op->inputs = {input_lhs, input_rhs};
|
matmul_op->inputs = {input_lhs, input_rhs};
|
||||||
matmul_op->outputs = batch_op->outputs;
|
matmul_op->outputs = batch_op->outputs;
|
||||||
tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
|
model->operators.emplace(tail_it, matmul_op);
|
||||||
CHECK_EQ(tail_it->get(), batch_op);
|
DeleteOpAndArrays(model, batch_op);
|
||||||
model->operators.erase(tail_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
@ -254,16 +253,7 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
|||||||
reshape_result_op->outputs = {batch_op->outputs[0]};
|
reshape_result_op->outputs = {batch_op->outputs[0]};
|
||||||
model->operators.emplace(tail_it, reshape_result_op);
|
model->operators.emplace(tail_it, reshape_result_op);
|
||||||
|
|
||||||
// Remove the old batch matmul now that we've unrolled.
|
DeleteOpAndArrays(model, batch_op);
|
||||||
batch_op_it = model->operators.begin();
|
|
||||||
for (; batch_op_it != model->operators.end(); ++batch_op_it) {
|
|
||||||
if (batch_op_it->get() == batch_op) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
CHECK(batch_op_it != model->operators.end());
|
|
||||||
CHECK(batch_op_it->get() == batch_op);
|
|
||||||
model->operators.erase(batch_op_it);
|
|
||||||
*modified = true;
|
*modified = true;
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -165,19 +165,33 @@ bool DeleteArrayIfUnused(const string& array_name, Model* model) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) {
|
bool DeleteArrayIfUnusedOutsideOfOp(const string& array_name,
|
||||||
if (IsDiscardableArray(*model, array_name) &&
|
const Operator* op, Model* model) {
|
||||||
CountOpsWithInput(*model, array_name) == 1 &&
|
if (!IsDiscardableArray(*model, array_name)) {
|
||||||
GetOpWithOutput(*model, array_name) == nullptr) {
|
return false;
|
||||||
|
}
|
||||||
|
if (CountOpsWithInput(*model, array_name) > 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
|
||||||
|
if (op_having_this_as_input && op_having_this_as_input != op) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const Operator* op_having_this_as_output =
|
||||||
|
GetOpWithOutput(*model, array_name);
|
||||||
|
if (op_having_this_as_output && op_having_this_as_output != op) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
model->EraseArray(array_name);
|
model->EraseArray(array_name);
|
||||||
return true;
|
return true;
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeleteOpAndArraysIfUnused(Model* model, const Operator* op) {
|
void DeleteOpAndArrays(Model* model, const Operator* op) {
|
||||||
for (const string& array_name : op->inputs) {
|
for (const string& array_name : op->inputs) {
|
||||||
DeleteArrayIfUsedOnce(array_name, model);
|
DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
|
||||||
|
}
|
||||||
|
for (const string& array_name : op->outputs) {
|
||||||
|
DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
|
||||||
}
|
}
|
||||||
auto op_it = FindOp(*model, op);
|
auto op_it = FindOp(*model, op);
|
||||||
CHECK(op_it != model->operators.end());
|
CHECK(op_it != model->operators.end());
|
||||||
|
@ -68,11 +68,10 @@ int CountTrueOutputs(const Model& model, const Operator& op);
|
|||||||
|
|
||||||
int CountOpsWithInput(const Model& model, const string& array_name);
|
int CountOpsWithInput(const Model& model, const string& array_name);
|
||||||
bool DeleteArrayIfUnused(const string& array_name, Model* model);
|
bool DeleteArrayIfUnused(const string& array_name, Model* model);
|
||||||
bool DeleteArrayIfUsedOnce(const string& array_name, Model* model);
|
|
||||||
|
|
||||||
// Deletes the op and any of its input and output arrays if they are unused
|
// Deletes the op and any of its input and output arrays if they are unused
|
||||||
// after the op has been deleted.
|
// after the op has been deleted.
|
||||||
void DeleteOpAndArraysIfUnused(Model* model, const Operator* op);
|
void DeleteOpAndArrays(Model* model, const Operator* op);
|
||||||
|
|
||||||
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
|
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
|
||||||
const Model& model, const string& array_name);
|
const Model& model, const string& array_name);
|
||||||
@ -361,6 +360,11 @@ void UndoWeightsShuffling(Model* model);
|
|||||||
// Copies minmax, quantization_params, and narrow_range.
|
// Copies minmax, quantization_params, and narrow_range.
|
||||||
void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);
|
void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);
|
||||||
|
|
||||||
|
// Delete Array if it's discardable and not referenced as input or output array
|
||||||
|
// by any other op than the specified op.
|
||||||
|
bool DeleteArrayIfUnusedOutsideOfOp(const string& array_name,
|
||||||
|
const Operator* op, Model* model);
|
||||||
|
|
||||||
} // namespace toco
|
} // namespace toco
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_
|
#endif // TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user