Qualify uses of std::string
PiperOrigin-RevId: 317683481 Change-Id: I3f6bb93d67e18623bf5d049ae0cc03bd2f937e5f
This commit is contained in:
parent
197c34b17a
commit
7e59a2ab28
|
@ -76,8 +76,9 @@ namespace toco {
|
||||||
reshape_op->outputs = expand_op->outputs;
|
reshape_op->outputs = expand_op->outputs;
|
||||||
|
|
||||||
// Create a new input array
|
// Create a new input array
|
||||||
string axis_array_name = expand_op->inputs[1];
|
std::string axis_array_name = expand_op->inputs[1];
|
||||||
string shape_array_name = toco::AvailableArrayName(*model, axis_array_name);
|
std::string shape_array_name =
|
||||||
|
toco::AvailableArrayName(*model, axis_array_name);
|
||||||
Array& shape_array = model->GetOrCreateArray(shape_array_name);
|
Array& shape_array = model->GetOrCreateArray(shape_array_name);
|
||||||
*(shape_array.mutable_shape()->mutable_dims()) = {
|
*(shape_array.mutable_shape()->mutable_dims()) = {
|
||||||
1, static_cast<int>(reshape_dims.size())};
|
1, static_cast<int>(reshape_dims.size())};
|
||||||
|
|
|
@ -41,7 +41,7 @@ TensorFlowReshapeOperator* CreateReshapeFromReorderAxes(
|
||||||
input_shape.dims(3) * input_shape.dims(2)};
|
input_shape.dims(3) * input_shape.dims(2)};
|
||||||
|
|
||||||
// Create a new input array for Reshape.
|
// Create a new input array for Reshape.
|
||||||
string reshape_array_name =
|
std::string reshape_array_name =
|
||||||
AvailableArrayName(*model, reshape_op->outputs[0]);
|
AvailableArrayName(*model, reshape_op->outputs[0]);
|
||||||
reshape_op->inputs.push_back(reshape_array_name);
|
reshape_op->inputs.push_back(reshape_array_name);
|
||||||
|
|
||||||
|
@ -71,7 +71,8 @@ TransposeOperator* CreateTransposeFromReorderAxes(
|
||||||
GetShuffleShape(input_axes_order, output_axes_order, &permutations_data);
|
GetShuffleShape(input_axes_order, output_axes_order, &permutations_data);
|
||||||
|
|
||||||
// Create a new input permutations array for Transpose.
|
// Create a new input permutations array for Transpose.
|
||||||
string perm_array_name = AvailableArrayName(*model, transpose_op->outputs[0]);
|
std::string perm_array_name =
|
||||||
|
AvailableArrayName(*model, transpose_op->outputs[0]);
|
||||||
transpose_op->inputs.push_back(perm_array_name);
|
transpose_op->inputs.push_back(perm_array_name);
|
||||||
|
|
||||||
Array& perm_array = model->GetOrCreateArray(perm_array_name);
|
Array& perm_array = model->GetOrCreateArray(perm_array_name);
|
||||||
|
@ -104,7 +105,7 @@ TransposeOperator* CreateTransposeFromReorderAxes(
|
||||||
|
|
||||||
// Get input array. If kFakeQuant is the input into ReorderAxes, get the input
|
// Get input array. If kFakeQuant is the input into ReorderAxes, get the input
|
||||||
// array passed into kFakeQuant. kFakeQuant op is dropped when possible.
|
// array passed into kFakeQuant. kFakeQuant op is dropped when possible.
|
||||||
string constant_input_array_name = input_array_name;
|
std::string constant_input_array_name = input_array_name;
|
||||||
if (!input_array.buffer) {
|
if (!input_array.buffer) {
|
||||||
const auto* op_producing_input = GetOpWithOutput(*model, input_array_name);
|
const auto* op_producing_input = GetOpWithOutput(*model, input_array_name);
|
||||||
if (op_producing_input &&
|
if (op_producing_input &&
|
||||||
|
|
|
@ -59,7 +59,7 @@ namespace toco {
|
||||||
reshape_op->outputs = pack_op->outputs;
|
reshape_op->outputs = pack_op->outputs;
|
||||||
|
|
||||||
// Create shape param.
|
// Create shape param.
|
||||||
string shape_array_name =
|
std::string shape_array_name =
|
||||||
AvailableArrayName(*model, pack_op->outputs[0] + "_shape");
|
AvailableArrayName(*model, pack_op->outputs[0] + "_shape");
|
||||||
Array& shape_array = model->GetOrCreateArray(shape_array_name);
|
Array& shape_array = model->GetOrCreateArray(shape_array_name);
|
||||||
const int shape_array_dims = 1 + input_array.shape().dimensions_count();
|
const int shape_array_dims = 1 + input_array.shape().dimensions_count();
|
||||||
|
|
|
@ -90,8 +90,9 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
|
||||||
reshape_op->outputs = transpose_op->outputs;
|
reshape_op->outputs = transpose_op->outputs;
|
||||||
|
|
||||||
// Create a new input array for the shape input
|
// Create a new input array for the shape input
|
||||||
string perm_array_name = transpose_op->inputs[1];
|
std::string perm_array_name = transpose_op->inputs[1];
|
||||||
string shape_array_name = toco::AvailableArrayName(*model, perm_array_name);
|
std::string shape_array_name =
|
||||||
|
toco::AvailableArrayName(*model, perm_array_name);
|
||||||
Array& shape_array = model->GetOrCreateArray(shape_array_name);
|
Array& shape_array = model->GetOrCreateArray(shape_array_name);
|
||||||
*(shape_array.mutable_shape()->mutable_dims()) = {
|
*(shape_array.mutable_shape()->mutable_dims()) = {
|
||||||
1, static_cast<int>(output_dims.size())};
|
1, static_cast<int>(output_dims.size())};
|
||||||
|
|
|
@ -49,7 +49,7 @@ bool ProcessConvOperator(Model* model, ConvOperator* op) {
|
||||||
|
|
||||||
// Create the im2col array.
|
// Create the im2col array.
|
||||||
CHECK_EQ(op->outputs.size(), 1);
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
const string& im2col_array_name =
|
const std::string& im2col_array_name =
|
||||||
AvailableArrayName(*model, op->inputs[0] + "_im2col");
|
AvailableArrayName(*model, op->inputs[0] + "_im2col");
|
||||||
model->GetOrCreateArray(im2col_array_name);
|
model->GetOrCreateArray(im2col_array_name);
|
||||||
op->outputs.push_back(im2col_array_name);
|
op->outputs.push_back(im2col_array_name);
|
||||||
|
@ -65,7 +65,7 @@ bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
|
||||||
|
|
||||||
// Always create an im2col array for transpose_conv.
|
// Always create an im2col array for transpose_conv.
|
||||||
CHECK_EQ(op->outputs.size(), 1);
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
const string& im2col_array_name = AvailableArrayName(
|
const std::string& im2col_array_name = AvailableArrayName(
|
||||||
*model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col");
|
*model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col");
|
||||||
model->GetOrCreateArray(im2col_array_name);
|
model->GetOrCreateArray(im2col_array_name);
|
||||||
op->outputs.push_back(im2col_array_name);
|
op->outputs.push_back(im2col_array_name);
|
||||||
|
|
|
@ -41,7 +41,7 @@ void DequantizeBuffer(Array* array) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
|
std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
|
||||||
Model* model, const string& array_name) {
|
Model* model, const std::string& array_name) {
|
||||||
for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
|
for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
|
||||||
for (const auto& input : it->get()->inputs) {
|
for (const auto& input : it->get()->inputs) {
|
||||||
if (input == array_name) {
|
if (input == array_name) {
|
||||||
|
@ -52,7 +52,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
|
||||||
return model->operators.end();
|
return model->operators.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClearArrayQuantizationParams(const string& array_name, Model* model) {
|
void ClearArrayQuantizationParams(const std::string& array_name, Model* model) {
|
||||||
auto* array = &model->GetArray(array_name);
|
auto* array = &model->GetArray(array_name);
|
||||||
CHECK(array->quantization_params);
|
CHECK(array->quantization_params);
|
||||||
for (auto& input_array : *model->flags.mutable_input_arrays()) {
|
for (auto& input_array : *model->flags.mutable_input_arrays()) {
|
||||||
|
@ -75,7 +75,7 @@ void ClearArrayQuantizationParams(const string& array_name, Model* model) {
|
||||||
array->quantization_params = nullptr;
|
array->quantization_params = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DequantizeArray(const string& array_name,
|
bool DequantizeArray(const std::string& array_name,
|
||||||
GraphTransformation* transformation, Model* model) {
|
GraphTransformation* transformation, Model* model) {
|
||||||
auto* array = &model->GetArray(array_name);
|
auto* array = &model->GetArray(array_name);
|
||||||
if (!array->quantization_params) {
|
if (!array->quantization_params) {
|
||||||
|
@ -133,7 +133,7 @@ bool DequantizeArray(const string& array_name,
|
||||||
if (IsInputArray(*model, array_name)) {
|
if (IsInputArray(*model, array_name)) {
|
||||||
must_insert_fakequant_after = true;
|
must_insert_fakequant_after = true;
|
||||||
}
|
}
|
||||||
for (const string& output_array : model->flags.output_arrays()) {
|
for (const std::string& output_array : model->flags.output_arrays()) {
|
||||||
if (array_name == output_array) {
|
if (array_name == output_array) {
|
||||||
must_insert_fakequant_before = true;
|
must_insert_fakequant_before = true;
|
||||||
}
|
}
|
||||||
|
@ -152,7 +152,7 @@ bool DequantizeArray(const string& array_name,
|
||||||
auto* fakequant_op = new FakeQuantOperator;
|
auto* fakequant_op = new FakeQuantOperator;
|
||||||
model->operators.emplace(FindFirstOpWithInput(model, array_name),
|
model->operators.emplace(FindFirstOpWithInput(model, array_name),
|
||||||
fakequant_op);
|
fakequant_op);
|
||||||
const string& new_array_name = AvailableArrayName(*model, array_name);
|
const std::string& new_array_name = AvailableArrayName(*model, array_name);
|
||||||
auto& new_array = model->GetOrCreateArray(new_array_name);
|
auto& new_array = model->GetOrCreateArray(new_array_name);
|
||||||
new_array.data_type = ArrayDataType::kFloat;
|
new_array.data_type = ArrayDataType::kFloat;
|
||||||
new_array.copy_shape(array->shape());
|
new_array.copy_shape(array->shape());
|
||||||
|
@ -162,7 +162,7 @@ bool DequantizeArray(const string& array_name,
|
||||||
fakequant_op->narrow_range = array->narrow_range;
|
fakequant_op->narrow_range = array->narrow_range;
|
||||||
if (must_insert_fakequant_before) {
|
if (must_insert_fakequant_before) {
|
||||||
for (const auto& op : model->operators) {
|
for (const auto& op : model->operators) {
|
||||||
for (string& output : op->outputs) {
|
for (std::string& output : op->outputs) {
|
||||||
if (output == array_name) {
|
if (output == array_name) {
|
||||||
output = new_array_name;
|
output = new_array_name;
|
||||||
}
|
}
|
||||||
|
@ -172,7 +172,7 @@ bool DequantizeArray(const string& array_name,
|
||||||
fakequant_op->outputs = {array_name};
|
fakequant_op->outputs = {array_name};
|
||||||
} else {
|
} else {
|
||||||
for (const auto& op : model->operators) {
|
for (const auto& op : model->operators) {
|
||||||
for (string& input : op->inputs) {
|
for (std::string& input : op->inputs) {
|
||||||
if (input == array_name) {
|
if (input == array_name) {
|
||||||
input = new_array_name;
|
input = new_array_name;
|
||||||
}
|
}
|
||||||
|
@ -209,15 +209,15 @@ bool DequantizeArray(const string& array_name,
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> arrays;
|
std::vector<std::string> arrays;
|
||||||
for (const string& input : op->inputs) {
|
for (const std::string& input : op->inputs) {
|
||||||
arrays.push_back(input);
|
arrays.push_back(input);
|
||||||
}
|
}
|
||||||
for (const string& output : op->outputs) {
|
for (const std::string& output : op->outputs) {
|
||||||
arrays.push_back(output);
|
arrays.push_back(output);
|
||||||
}
|
}
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
for (const string& array : arrays) {
|
for (const std::string& array : arrays) {
|
||||||
if (!model->IsOptionalArray(array)) {
|
if (!model->IsOptionalArray(array)) {
|
||||||
changed |= DequantizeArray(array, this, model);
|
changed |= DequantizeArray(array, this, model);
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace toco {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
|
int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
|
||||||
const string& weights_name = op.inputs[1];
|
const std::string& weights_name = op.inputs[1];
|
||||||
const auto& weights_shape = model.GetArray(weights_name).shape();
|
const auto& weights_shape = model.GetArray(weights_name).shape();
|
||||||
if (op.type == OperatorType::kConv ||
|
if (op.type == OperatorType::kConv ||
|
||||||
op.type == OperatorType::kFullyConnected ||
|
op.type == OperatorType::kFullyConnected ||
|
||||||
|
@ -56,13 +56,14 @@ bool ProcessLinearOperator(Model* model, Operator* op) {
|
||||||
if (CheckOpInputSize(*op)) {
|
if (CheckOpInputSize(*op)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const string& weights_name = op->inputs[1];
|
const std::string& weights_name = op->inputs[1];
|
||||||
if (!model->GetArray(weights_name).has_shape()) {
|
if (!model->GetArray(weights_name).has_shape()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const int depth = GetOutputDepthFromWeights(*model, *op);
|
const int depth = GetOutputDepthFromWeights(*model, *op);
|
||||||
const string& bias_name = AvailableArrayName(*model, output_name + "_bias");
|
const std::string& bias_name =
|
||||||
|
AvailableArrayName(*model, output_name + "_bias");
|
||||||
op->inputs.push_back(bias_name);
|
op->inputs.push_back(bias_name);
|
||||||
auto& bias_array = model->GetOrCreateArray(bias_name);
|
auto& bias_array = model->GetOrCreateArray(bias_name);
|
||||||
bias_array.data_type = ArrayDataType::kFloat;
|
bias_array.data_type = ArrayDataType::kFloat;
|
||||||
|
|
|
@ -152,7 +152,7 @@ namespace toco {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& name = op.inputs[weights_index];
|
const std::string& name = op.inputs[weights_index];
|
||||||
auto& array = model->GetArray(name);
|
auto& array = model->GetArray(name);
|
||||||
if (!array.buffer) {
|
if (!array.buffer) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
|
|
@ -260,7 +260,7 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const string& output_array : model->flags.output_arrays()) {
|
for (const std::string& output_array : model->flags.output_arrays()) {
|
||||||
if (preceding_op->outputs[0] == output_array) {
|
if (preceding_op->outputs[0] == output_array) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace toco {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void PrintModelStats(const string& label, const Model& model) {
|
void PrintModelStats(const std::string& label, const Model& model) {
|
||||||
int quantized_arrays = 0;
|
int quantized_arrays = 0;
|
||||||
for (const auto& array : model.GetArrayMap()) {
|
for (const auto& array : model.GetArrayMap()) {
|
||||||
if (array.second->quantization_params) {
|
if (array.second->quantization_params) {
|
||||||
|
@ -57,8 +57,8 @@ void PrintModelStats(const string& label, const Model& model) {
|
||||||
void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
||||||
// Identify the set of arrays that are in 'useful' connected components
|
// Identify the set of arrays that are in 'useful' connected components
|
||||||
// of the graph, which means connected to output arrays.
|
// of the graph, which means connected to output arrays.
|
||||||
std::unordered_set<string> useful_arrays;
|
std::unordered_set<std::string> useful_arrays;
|
||||||
for (const string& output_array : model->flags.output_arrays()) {
|
for (const std::string& output_array : model->flags.output_arrays()) {
|
||||||
useful_arrays.insert(output_array);
|
useful_arrays.insert(output_array);
|
||||||
}
|
}
|
||||||
bool found_new_useful_arrays;
|
bool found_new_useful_arrays;
|
||||||
|
@ -66,15 +66,15 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
||||||
found_new_useful_arrays = false;
|
found_new_useful_arrays = false;
|
||||||
for (const auto& op : model->operators) {
|
for (const auto& op : model->operators) {
|
||||||
bool op_touches_useful_arrays = false;
|
bool op_touches_useful_arrays = false;
|
||||||
for (const string& output : op->outputs) {
|
for (const std::string& output : op->outputs) {
|
||||||
op_touches_useful_arrays |= useful_arrays.count(output);
|
op_touches_useful_arrays |= useful_arrays.count(output);
|
||||||
}
|
}
|
||||||
if (op_touches_useful_arrays) {
|
if (op_touches_useful_arrays) {
|
||||||
for (const string& input : op->inputs) {
|
for (const std::string& input : op->inputs) {
|
||||||
found_new_useful_arrays |= !useful_arrays.count(input);
|
found_new_useful_arrays |= !useful_arrays.count(input);
|
||||||
useful_arrays.insert(input);
|
useful_arrays.insert(input);
|
||||||
}
|
}
|
||||||
for (const string& output : op->outputs) {
|
for (const std::string& output : op->outputs) {
|
||||||
found_new_useful_arrays |= !useful_arrays.count(output);
|
found_new_useful_arrays |= !useful_arrays.count(output);
|
||||||
useful_arrays.insert(output);
|
useful_arrays.insert(output);
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
||||||
}
|
}
|
||||||
} while (found_new_useful_arrays);
|
} while (found_new_useful_arrays);
|
||||||
// Erase arrays that aren't useful, and that are discardable.
|
// Erase arrays that aren't useful, and that are discardable.
|
||||||
model->EraseArrays([&](const string& name) {
|
model->EraseArrays([&](const std::string& name) {
|
||||||
return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
|
return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
|
||||||
});
|
});
|
||||||
// Erase operators that do not produce a useful output array.
|
// Erase operators that do not produce a useful output array.
|
||||||
|
@ -101,7 +101,7 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
||||||
if (useful_arrays.count((*it)->outputs[0])) {
|
if (useful_arrays.count((*it)->outputs[0])) {
|
||||||
++it;
|
++it;
|
||||||
} else {
|
} else {
|
||||||
for (const string& output : (*it)->outputs) {
|
for (const std::string& output : (*it)->outputs) {
|
||||||
CHECK(!useful_arrays.count(output));
|
CHECK(!useful_arrays.count(output));
|
||||||
}
|
}
|
||||||
it = model->operators.erase(it);
|
it = model->operators.erase(it);
|
||||||
|
@ -156,7 +156,7 @@ bool GraphTransformationsPass(int increment, Model* model,
|
||||||
<< " at op_index=" << op_index << "/"
|
<< " at op_index=" << op_index << "/"
|
||||||
<< model->operators.size() - 1;
|
<< model->operators.size() - 1;
|
||||||
}
|
}
|
||||||
for (const string& message : transformation->Messages()) {
|
for (const std::string& message : transformation->Messages()) {
|
||||||
VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
|
VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
|
||||||
<< " at op_index=" << op_index << "/"
|
<< " at op_index=" << op_index << "/"
|
||||||
<< model->operators.size() - 1 << ": " << message;
|
<< model->operators.size() - 1 << ": " << message;
|
||||||
|
@ -191,7 +191,7 @@ bool GraphTransformationsPass(int increment, Model* model,
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
tensorflow::Status RunGraphTransformationsWithStatus(
|
tensorflow::Status RunGraphTransformationsWithStatus(
|
||||||
Model* model, const string& msg,
|
Model* model, const std::string& msg,
|
||||||
const GraphTransformationsSet& transformations) {
|
const GraphTransformationsSet& transformations) {
|
||||||
PrintModelStats(toco::port::StringF("Before %s", msg), *model);
|
PrintModelStats(toco::port::StringF("Before %s", msg), *model);
|
||||||
int pass_index = 0;
|
int pass_index = 0;
|
||||||
|
|
|
@ -33,7 +33,7 @@ class GraphTransformation {
|
||||||
virtual ~GraphTransformation() {}
|
virtual ~GraphTransformation() {}
|
||||||
// Returns the list of messages that this graph transformation
|
// Returns the list of messages that this graph transformation
|
||||||
// generated since ClearMessages() was called.
|
// generated since ClearMessages() was called.
|
||||||
const std::vector<string>& Messages() const { return messages_; }
|
const std::vector<std::string>& Messages() const { return messages_; }
|
||||||
// Clears the list of messages; should be called after every
|
// Clears the list of messages; should be called after every
|
||||||
// run of this graph transformation.
|
// run of this graph transformation.
|
||||||
void ClearMessages() { return messages_.clear(); }
|
void ClearMessages() { return messages_.clear(); }
|
||||||
|
@ -48,7 +48,7 @@ class GraphTransformation {
|
||||||
GraphTransformation() {}
|
GraphTransformation() {}
|
||||||
|
|
||||||
// List of messages generated by this graph transformation.
|
// List of messages generated by this graph transformation.
|
||||||
std::vector<string> messages_;
|
std::vector<std::string> messages_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GraphTransformation(const GraphTransformation& other) = delete;
|
GraphTransformation(const GraphTransformation& other) = delete;
|
||||||
|
@ -74,7 +74,7 @@ class GraphTransformationsSet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void Add(GraphTransformation* transformation) {
|
void Add(GraphTransformation* transformation) {
|
||||||
const string& name = transformation->Name();
|
const std::string& name = transformation->Name();
|
||||||
CHECK(!names_.count(name));
|
CHECK(!names_.count(name));
|
||||||
names_.insert(name);
|
names_.insert(name);
|
||||||
transformations_.emplace_back(transformation);
|
transformations_.emplace_back(transformation);
|
||||||
|
@ -92,7 +92,7 @@ class GraphTransformationsSet {
|
||||||
GraphTransformationsSet(const GraphTransformationsSet&& other) = delete;
|
GraphTransformationsSet(const GraphTransformationsSet&& other) = delete;
|
||||||
std::vector<std::unique_ptr<GraphTransformation>> transformations_;
|
std::vector<std::unique_ptr<GraphTransformation>> transformations_;
|
||||||
// Names of transformations in the set. Only used to guard against dupes.
|
// Names of transformations in the set. Only used to guard against dupes.
|
||||||
std::unordered_set<string> names_;
|
std::unordered_set<std::string> names_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Run the given list of graph transformations on the model.
|
// Run the given list of graph transformations on the model.
|
||||||
|
@ -103,11 +103,11 @@ class GraphTransformationsSet {
|
||||||
// the resulting raw pointers, and this RunGraphTransformations
|
// the resulting raw pointers, and this RunGraphTransformations
|
||||||
// takes care of delete'ing these pointers.
|
// takes care of delete'ing these pointers.
|
||||||
tensorflow::Status RunGraphTransformationsWithStatus(
|
tensorflow::Status RunGraphTransformationsWithStatus(
|
||||||
Model* model, const string& msg,
|
Model* model, const std::string& msg,
|
||||||
const GraphTransformationsSet& transformations);
|
const GraphTransformationsSet& transformations);
|
||||||
|
|
||||||
inline void RunGraphTransformations(
|
inline void RunGraphTransformations(
|
||||||
Model* model, const string& msg,
|
Model* model, const std::string& msg,
|
||||||
const GraphTransformationsSet& transformations) {
|
const GraphTransformationsSet& transformations) {
|
||||||
auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
|
auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
|
||||||
CHECK(s.ok()) << s.error_message();
|
CHECK(s.ok()) << s.error_message();
|
||||||
|
@ -232,7 +232,7 @@ class PropagateDefaultMinMax : public GraphTransformation {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool SetArrayMinMax(const string& array_name, Array* array);
|
bool SetArrayMinMax(const std::string& array_name, Array* array);
|
||||||
std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
|
std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -197,7 +197,7 @@ void ConstructBidirectionalSequenceOp(
|
||||||
constexpr int kBwInputActivationStartIndex = 37;
|
constexpr int kBwInputActivationStartIndex = 37;
|
||||||
constexpr int kAuxInputStartIndex = 39;
|
constexpr int kAuxInputStartIndex = 39;
|
||||||
(*bi_op)->inputs.reserve(kBidirectionalSequenceLstmInputsCount);
|
(*bi_op)->inputs.reserve(kBidirectionalSequenceLstmInputsCount);
|
||||||
const string& input_array_name =
|
const std::string& input_array_name =
|
||||||
AvailableArrayName(*model, "bidirectional_sequence_lstm_input_0");
|
AvailableArrayName(*model, "bidirectional_sequence_lstm_input_0");
|
||||||
model->GetOrCreateArray(input_array_name);
|
model->GetOrCreateArray(input_array_name);
|
||||||
// The input will be changed later.
|
// The input will be changed later.
|
||||||
|
@ -232,7 +232,7 @@ void ConstructBidirectionalSequenceOp(
|
||||||
|
|
||||||
// TODO(renjieliu): Deal with Auxiliary input and weights for 39 - 47.
|
// TODO(renjieliu): Deal with Auxiliary input and weights for 39 - 47.
|
||||||
for (; i <= kBidirectionalSequenceLstmInputsCount; ++i) {
|
for (; i <= kBidirectionalSequenceLstmInputsCount; ++i) {
|
||||||
const string& temp_array_name = AvailableArrayName(
|
const std::string& temp_array_name = AvailableArrayName(
|
||||||
*model, "bidirectional_sequence_lstm_temp_" + std::to_string(i));
|
*model, "bidirectional_sequence_lstm_temp_" + std::to_string(i));
|
||||||
model->CreateOptionalArray(temp_array_name);
|
model->CreateOptionalArray(temp_array_name);
|
||||||
(*bi_op)->inputs.push_back(temp_array_name);
|
(*bi_op)->inputs.push_back(temp_array_name);
|
||||||
|
@ -240,9 +240,9 @@ void ConstructBidirectionalSequenceOp(
|
||||||
|
|
||||||
// Deal with outputs.
|
// Deal with outputs.
|
||||||
(*bi_op)->outputs.reserve(2);
|
(*bi_op)->outputs.reserve(2);
|
||||||
const string& fw_output_array_name =
|
const std::string& fw_output_array_name =
|
||||||
AvailableArrayName(*model, "bidirectional_sequence_lstm_fw_output_0");
|
AvailableArrayName(*model, "bidirectional_sequence_lstm_fw_output_0");
|
||||||
const string& bw_output_array_name =
|
const std::string& bw_output_array_name =
|
||||||
AvailableArrayName(*model, "bidirectional_sequence_lstm_bw_output_0");
|
AvailableArrayName(*model, "bidirectional_sequence_lstm_bw_output_0");
|
||||||
model->GetOrCreateArray(fw_output_array_name);
|
model->GetOrCreateArray(fw_output_array_name);
|
||||||
model->GetOrCreateArray(bw_output_array_name);
|
model->GetOrCreateArray(bw_output_array_name);
|
||||||
|
@ -260,7 +260,7 @@ void ConstructBidirectionalSequenceOp(
|
||||||
constexpr int kBwInputsStartIndex = 5;
|
constexpr int kBwInputsStartIndex = 5;
|
||||||
constexpr int kAuxInputsStartIndex = 9;
|
constexpr int kAuxInputsStartIndex = 9;
|
||||||
(*bi_op)->inputs.reserve(kBidirectionalSequenceRnnInputsCount);
|
(*bi_op)->inputs.reserve(kBidirectionalSequenceRnnInputsCount);
|
||||||
const string& input_array_name =
|
const std::string& input_array_name =
|
||||||
AvailableArrayName(*model, "bidirectional_sequence_rnn_input_0");
|
AvailableArrayName(*model, "bidirectional_sequence_rnn_input_0");
|
||||||
model->GetOrCreateArray(input_array_name);
|
model->GetOrCreateArray(input_array_name);
|
||||||
// The input will be changed later.
|
// The input will be changed later.
|
||||||
|
@ -280,7 +280,7 @@ void ConstructBidirectionalSequenceOp(
|
||||||
|
|
||||||
// TODO(renjieliu): Deal with optional weights.
|
// TODO(renjieliu): Deal with optional weights.
|
||||||
for (; i < kBidirectionalSequenceRnnInputsCount; ++i) {
|
for (; i < kBidirectionalSequenceRnnInputsCount; ++i) {
|
||||||
const string& temp_array_name = AvailableArrayName(
|
const std::string& temp_array_name = AvailableArrayName(
|
||||||
*model, "bidirectional_sequence_rnn_temp_" + std::to_string(i));
|
*model, "bidirectional_sequence_rnn_temp_" + std::to_string(i));
|
||||||
model->CreateOptionalArray(temp_array_name);
|
model->CreateOptionalArray(temp_array_name);
|
||||||
(*bi_op)->inputs.push_back(temp_array_name);
|
(*bi_op)->inputs.push_back(temp_array_name);
|
||||||
|
@ -288,9 +288,9 @@ void ConstructBidirectionalSequenceOp(
|
||||||
|
|
||||||
// Deal with outputs.
|
// Deal with outputs.
|
||||||
(*bi_op)->outputs.reserve(2);
|
(*bi_op)->outputs.reserve(2);
|
||||||
const string& fw_output_array_name =
|
const std::string& fw_output_array_name =
|
||||||
AvailableArrayName(*model, "bidirectional_sequence_rnn_fw_output_0");
|
AvailableArrayName(*model, "bidirectional_sequence_rnn_fw_output_0");
|
||||||
const string& bw_output_array_name =
|
const std::string& bw_output_array_name =
|
||||||
AvailableArrayName(*model, "bidirectional_sequence_rnn_bw_output_0");
|
AvailableArrayName(*model, "bidirectional_sequence_rnn_bw_output_0");
|
||||||
model->GetOrCreateArray(fw_output_array_name);
|
model->GetOrCreateArray(fw_output_array_name);
|
||||||
model->GetOrCreateArray(bw_output_array_name);
|
model->GetOrCreateArray(bw_output_array_name);
|
||||||
|
@ -318,7 +318,7 @@ void GroupFwBwSequenceOps(Model* model, std::stack<Operator*> fw_sequence_ops,
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RewireBidirectionalSequenceSequenceOpsConnections(
|
void RewireBidirectionalSequenceSequenceOpsConnections(
|
||||||
OperatorType operator_type, const string& input_array_name,
|
OperatorType operator_type, const std::string& input_array_name,
|
||||||
const std::vector<T*>& bidirectional_sequence_ops,
|
const std::vector<T*>& bidirectional_sequence_ops,
|
||||||
std::vector<std::unique_ptr<Operator>>::iterator* op_it, Model* model) {
|
std::vector<std::unique_ptr<Operator>>::iterator* op_it, Model* model) {
|
||||||
int aux_input_index = -1;
|
int aux_input_index = -1;
|
||||||
|
@ -333,8 +333,8 @@ void RewireBidirectionalSequenceSequenceOpsConnections(
|
||||||
// Should not reach here.
|
// Should not reach here.
|
||||||
DCHECK(false);
|
DCHECK(false);
|
||||||
}
|
}
|
||||||
string cur_fw_input = input_array_name;
|
std::string cur_fw_input = input_array_name;
|
||||||
string cur_bw_input = input_array_name;
|
std::string cur_bw_input = input_array_name;
|
||||||
for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) {
|
for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) {
|
||||||
DeleteArrayIfUnusedOutsideOfOp(bidirectional_sequence_ops[i]->inputs[0],
|
DeleteArrayIfUnusedOutsideOfOp(bidirectional_sequence_ops[i]->inputs[0],
|
||||||
bidirectional_sequence_ops[i], model);
|
bidirectional_sequence_ops[i], model);
|
||||||
|
@ -371,8 +371,8 @@ void RewireFinalUnpackOutputs(const UnpackOperator& original_unpack_operator,
|
||||||
(*final_unpack_operator)->num = original_unpack_operator.num;
|
(*final_unpack_operator)->num = original_unpack_operator.num;
|
||||||
|
|
||||||
for (size_t i = 0; i < original_unpack_operator.outputs.size(); ++i) {
|
for (size_t i = 0; i < original_unpack_operator.outputs.size(); ++i) {
|
||||||
const string& output_array_name = original_unpack_operator.outputs[i];
|
const std::string& output_array_name = original_unpack_operator.outputs[i];
|
||||||
const string& final_unpack_output_array_name = AvailableArrayName(
|
const std::string& final_unpack_output_array_name = AvailableArrayName(
|
||||||
*model, "bidirectional_sequence_unpack_" + std::to_string(i));
|
*model, "bidirectional_sequence_unpack_" + std::to_string(i));
|
||||||
model->GetOrCreateArray(final_unpack_output_array_name);
|
model->GetOrCreateArray(final_unpack_output_array_name);
|
||||||
(*final_unpack_operator)->outputs.push_back(final_unpack_output_array_name);
|
(*final_unpack_operator)->outputs.push_back(final_unpack_output_array_name);
|
||||||
|
@ -381,7 +381,7 @@ void RewireFinalUnpackOutputs(const UnpackOperator& original_unpack_operator,
|
||||||
// If there's a following op after the unpack, it must be a concat op.
|
// If there's a following op after the unpack, it must be a concat op.
|
||||||
DCHECK(unpack_following_op->type == OperatorType::kConcatenation);
|
DCHECK(unpack_following_op->type == OperatorType::kConcatenation);
|
||||||
// For every output of the concat, rewire the outputs.
|
// For every output of the concat, rewire the outputs.
|
||||||
for (const string& concat_output : unpack_following_op->outputs) {
|
for (const std::string& concat_output : unpack_following_op->outputs) {
|
||||||
(*final_unpack_operator)->outputs[i] = concat_output;
|
(*final_unpack_operator)->outputs[i] = concat_output;
|
||||||
}
|
}
|
||||||
// Remove the concat op.
|
// Remove the concat op.
|
||||||
|
@ -454,7 +454,7 @@ template <typename T>
|
||||||
&bidirectional_sequence_ops);
|
&bidirectional_sequence_ops);
|
||||||
|
|
||||||
// Rewire the inputs & outputs.
|
// Rewire the inputs & outputs.
|
||||||
string current_input = first_fw_sequence_input->outputs[0];
|
std::string current_input = first_fw_sequence_input->outputs[0];
|
||||||
RewireBidirectionalSequenceSequenceOpsConnections(
|
RewireBidirectionalSequenceSequenceOpsConnections(
|
||||||
operator_type, current_input, bidirectional_sequence_ops, &op_it, model);
|
operator_type, current_input, bidirectional_sequence_ops, &op_it, model);
|
||||||
|
|
||||||
|
@ -525,7 +525,7 @@ template <typename T>
|
||||||
&bidirectional_sequence_lstm_ops);
|
&bidirectional_sequence_lstm_ops);
|
||||||
|
|
||||||
// Rewire the inputs & outputs.
|
// Rewire the inputs & outputs.
|
||||||
string current_input = first_fw_lstm_input->outputs[0];
|
std::string current_input = first_fw_lstm_input->outputs[0];
|
||||||
RewireBidirectionalSequenceSequenceOpsConnections(
|
RewireBidirectionalSequenceSequenceOpsConnections(
|
||||||
OperatorType::kBidirectionalSequenceLstm, current_input,
|
OperatorType::kBidirectionalSequenceLstm, current_input,
|
||||||
bidirectional_sequence_lstm_ops, &op_it, model);
|
bidirectional_sequence_lstm_ops, &op_it, model);
|
||||||
|
@ -601,7 +601,7 @@ template <typename T>
|
||||||
&bidirectional_sequence_rnn_ops);
|
&bidirectional_sequence_rnn_ops);
|
||||||
|
|
||||||
// Rewire the inputs & outputs.
|
// Rewire the inputs & outputs.
|
||||||
string current_input = first_fw_rnn_input->outputs[0];
|
std::string current_input = first_fw_rnn_input->outputs[0];
|
||||||
RewireBidirectionalSequenceSequenceOpsConnections(
|
RewireBidirectionalSequenceSequenceOpsConnections(
|
||||||
OperatorType::kBidirectionalSequenceRnn, current_input,
|
OperatorType::kBidirectionalSequenceRnn, current_input,
|
||||||
bidirectional_sequence_rnn_ops, &op_it, model);
|
bidirectional_sequence_rnn_ops, &op_it, model);
|
||||||
|
|
|
@ -279,10 +279,10 @@ bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
|
||||||
// If multiple of these arrays have MinMax, then these are required
|
// If multiple of these arrays have MinMax, then these are required
|
||||||
// to agree with each other.
|
// to agree with each other.
|
||||||
bool PropagateMinMaxAmongArrays(Model* model,
|
bool PropagateMinMaxAmongArrays(Model* model,
|
||||||
const std::vector<string>& array_names) {
|
const std::vector<std::string>& array_names) {
|
||||||
string reference_array_name;
|
std::string reference_array_name;
|
||||||
MinMax* reference_minmax = nullptr;
|
MinMax* reference_minmax = nullptr;
|
||||||
for (const string& array_name : array_names) {
|
for (const std::string& array_name : array_names) {
|
||||||
if (model->GetArray(array_name).minmax) {
|
if (model->GetArray(array_name).minmax) {
|
||||||
reference_array_name = array_name;
|
reference_array_name = array_name;
|
||||||
reference_minmax = model->GetArray(array_name).minmax.get();
|
reference_minmax = model->GetArray(array_name).minmax.get();
|
||||||
|
@ -294,7 +294,7 @@ bool PropagateMinMaxAmongArrays(Model* model,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
for (const string& array_name : array_names) {
|
for (const std::string& array_name : array_names) {
|
||||||
auto& array = model->GetArray(array_name);
|
auto& array = model->GetArray(array_name);
|
||||||
if (array.minmax) {
|
if (array.minmax) {
|
||||||
CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
|
CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
|
||||||
|
|
|
@ -206,7 +206,7 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conv Op
|
// Conv Op
|
||||||
const string& input_of_conv_op =
|
const std::string& input_of_conv_op =
|
||||||
has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
|
has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
|
||||||
auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
|
auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
|
|
@ -78,7 +78,7 @@ using util::IsBinaryOp;
|
||||||
// 1. non-constant input of add_with_relu6_op
|
// 1. non-constant input of add_with_relu6_op
|
||||||
// 2. 1/6
|
// 2. 1/6
|
||||||
// 3. (and add_with_relu6_op[0].outputs[0] - which we already know!)
|
// 3. (and add_with_relu6_op[0].outputs[0] - which we already know!)
|
||||||
std::vector<string> mul_inputs = mul_op->inputs;
|
std::vector<std::string> mul_inputs = mul_op->inputs;
|
||||||
mul_inputs.insert(mul_inputs.end(), output_op->inputs.begin(),
|
mul_inputs.insert(mul_inputs.end(), output_op->inputs.begin(),
|
||||||
output_op->inputs.end());
|
output_op->inputs.end());
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
|
||||||
return it;
|
return it;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValidateSourceOp(const Model& model, const string& array_name,
|
bool ValidateSourceOp(const Model& model, const std::string& array_name,
|
||||||
OperatorType op_type, Operator** source_op) {
|
OperatorType op_type, Operator** source_op) {
|
||||||
if (op_type == OperatorType::kNone) {
|
if (op_type == OperatorType::kNone) {
|
||||||
CHECK(!source_op);
|
CHECK(!source_op);
|
||||||
|
@ -184,7 +184,7 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
|
||||||
&state_remember_mul)) {
|
&state_remember_mul)) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
const string prev_state = state_forget_mul->inputs[0];
|
const std::string prev_state = state_forget_mul->inputs[0];
|
||||||
|
|
||||||
// State forget gate
|
// State forget gate
|
||||||
Operator* state_forget_sig;
|
Operator* state_forget_sig;
|
||||||
|
@ -271,16 +271,16 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
|
||||||
LogName(*lstm_cell_op));
|
LogName(*lstm_cell_op));
|
||||||
|
|
||||||
// Create temp arrays used internally during runtime.
|
// Create temp arrays used internally during runtime.
|
||||||
const string base_name(FindLongestCommonPrefix(
|
const std::string base_name(FindLongestCommonPrefix(
|
||||||
lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT],
|
lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT],
|
||||||
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
|
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
|
||||||
const string& concat_temp_array_name =
|
const std::string& concat_temp_array_name =
|
||||||
AvailableArrayName(*model, base_name + "concat_temp");
|
AvailableArrayName(*model, base_name + "concat_temp");
|
||||||
auto& concat_temp_array = model->GetOrCreateArray(concat_temp_array_name);
|
auto& concat_temp_array = model->GetOrCreateArray(concat_temp_array_name);
|
||||||
concat_temp_array.data_type =
|
concat_temp_array.data_type =
|
||||||
model->GetArray(concat_inputs->outputs[0]).data_type;
|
model->GetArray(concat_inputs->outputs[0]).data_type;
|
||||||
lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
|
lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
|
||||||
const string& activ_temp_array_name =
|
const std::string& activ_temp_array_name =
|
||||||
AvailableArrayName(*model, base_name + "activ_temp");
|
AvailableArrayName(*model, base_name + "activ_temp");
|
||||||
auto& activ_temp_array = model->GetOrCreateArray(activ_temp_array_name);
|
auto& activ_temp_array = model->GetOrCreateArray(activ_temp_array_name);
|
||||||
activ_temp_array.data_type =
|
activ_temp_array.data_type =
|
||||||
|
|
|
@ -45,12 +45,12 @@ namespace toco {
|
||||||
|
|
||||||
// Identify prev_activ_input, prev_state_input as required Op inputs,
|
// Identify prev_activ_input, prev_state_input as required Op inputs,
|
||||||
// using the rnn_states in the model flag.
|
// using the rnn_states in the model flag.
|
||||||
string prev_activ_input;
|
std::string prev_activ_input;
|
||||||
if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor],
|
if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor],
|
||||||
&prev_activ_input)) {
|
&prev_activ_input)) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
string prev_state_input;
|
std::string prev_state_input;
|
||||||
if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor],
|
if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor],
|
||||||
&prev_state_input)) {
|
&prev_state_input)) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
|
@ -72,9 +72,10 @@ namespace toco {
|
||||||
CHECK_EQ(num_cell, num_output);
|
CHECK_EQ(num_cell, num_output);
|
||||||
|
|
||||||
// Create tensorflow_graphdef style's one big weight tensor.
|
// Create tensorflow_graphdef style's one big weight tensor.
|
||||||
const string base_name(FindLongestCommonPrefix(
|
const std::string base_name(FindLongestCommonPrefix(
|
||||||
src_op->outputs[kOutputTensor], src_op->outputs[kCellStateTensor]));
|
src_op->outputs[kOutputTensor], src_op->outputs[kCellStateTensor]));
|
||||||
string merged_weights = AvailableArrayName(*model, base_name + "weights");
|
std::string merged_weights =
|
||||||
|
AvailableArrayName(*model, base_name + "weights");
|
||||||
auto& array = model->GetOrCreateArray(merged_weights);
|
auto& array = model->GetOrCreateArray(merged_weights);
|
||||||
array.data_type = ArrayDataType::kFloat;
|
array.data_type = ArrayDataType::kFloat;
|
||||||
int weights_dim1 = 4 * num_cell;
|
int weights_dim1 = 4 * num_cell;
|
||||||
|
@ -117,7 +118,7 @@ namespace toco {
|
||||||
num_cell * 3, num_input);
|
num_cell * 3, num_input);
|
||||||
|
|
||||||
// Create tensorflow_graphdef style's one big bias tensor.
|
// Create tensorflow_graphdef style's one big bias tensor.
|
||||||
string merged_biases = AvailableArrayName(*model, base_name + "biases");
|
std::string merged_biases = AvailableArrayName(*model, base_name + "biases");
|
||||||
auto& bias_array = model->GetOrCreateArray(merged_biases);
|
auto& bias_array = model->GetOrCreateArray(merged_biases);
|
||||||
bias_array.data_type = ArrayDataType::kFloat;
|
bias_array.data_type = ArrayDataType::kFloat;
|
||||||
bias_array.copy_shape(Shape({weights_dim1}));
|
bias_array.copy_shape(Shape({weights_dim1}));
|
||||||
|
@ -160,7 +161,7 @@ namespace toco {
|
||||||
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
|
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
|
||||||
src_op->outputs[kOutputStateTensor];
|
src_op->outputs[kOutputStateTensor];
|
||||||
// Create a new temp array for the fourth output.
|
// Create a new temp array for the fourth output.
|
||||||
const string& concat_temp_array_name =
|
const std::string& concat_temp_array_name =
|
||||||
AvailableArrayName(*model, base_name + "concat_temp");
|
AvailableArrayName(*model, base_name + "concat_temp");
|
||||||
model->GetOrCreateArray(concat_temp_array_name);
|
model->GetOrCreateArray(concat_temp_array_name);
|
||||||
lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
|
lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
|
||||||
|
|
|
@ -86,7 +86,7 @@ namespace toco {
|
||||||
// Get original weight tensor and decompose 1 tensor to 8 sub tensors.
|
// Get original weight tensor and decompose 1 tensor to 8 sub tensors.
|
||||||
Array& kernel =
|
Array& kernel =
|
||||||
model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
|
model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
|
||||||
const string base_name(FindLongestCommonPrefix(
|
const std::string base_name(FindLongestCommonPrefix(
|
||||||
curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT],
|
curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT],
|
||||||
curr_op->outputs[LstmCellOperator::STATE_OUTPUT]));
|
curr_op->outputs[LstmCellOperator::STATE_OUTPUT]));
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
|
|
||||||
void CreateOptionalArray(Model* model, string* input_array_buffer,
|
void CreateOptionalArray(Model* model, std::string* input_array_buffer,
|
||||||
const string& array_name) {
|
const std::string& array_name) {
|
||||||
*input_array_buffer = array_name;
|
*input_array_buffer = array_name;
|
||||||
model->CreateOptionalArray(array_name);
|
model->CreateOptionalArray(array_name);
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
|
Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
|
||||||
string* array_name,
|
std::string* array_name,
|
||||||
const Shape& shape) {
|
const Shape& shape) {
|
||||||
*array_name = AvailableArrayName(*model, *array_name);
|
*array_name = AvailableArrayName(*model, *array_name);
|
||||||
auto& array = model->GetOrCreateArray(*array_name);
|
auto& array = model->GetOrCreateArray(*array_name);
|
||||||
|
@ -51,8 +51,8 @@ Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CopySubArrayToArray(Model* model, string* array_name,
|
void CopySubArrayToArray(Model* model, std::string* array_name,
|
||||||
const string& tensor_name, int dim1_size,
|
const std::string& tensor_name, int dim1_size,
|
||||||
int dim2_size, const Array& original_array,
|
int dim2_size, const Array& original_array,
|
||||||
int start_idx1, int start_idx2) {
|
int start_idx1, int start_idx2) {
|
||||||
// Determine whether it's bias or not, create shape, buffer.
|
// Determine whether it's bias or not, create shape, buffer.
|
||||||
|
@ -83,8 +83,9 @@ void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
|
||||||
dim1_copy_size, dim2_copy_size);
|
dim1_copy_size, dim2_copy_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
|
bool GetMatchingRnnArray(Model* model,
|
||||||
string* rnn_array) {
|
const std::string& back_edge_source_array,
|
||||||
|
std::string* rnn_array) {
|
||||||
for (const auto& rnn_state : model->flags.rnn_states()) {
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
if (rnn_state.back_edge_source_array() == back_edge_source_array) {
|
if (rnn_state.back_edge_source_array() == back_edge_source_array) {
|
||||||
*rnn_array = rnn_state.state_array();
|
*rnn_array = rnn_state.state_array();
|
||||||
|
|
|
@ -62,12 +62,12 @@ enum ExtendedLstmCellOutputs {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create optional array used for optional tensor in ExtendedLstmCell inputs.
|
// Create optional array used for optional tensor in ExtendedLstmCell inputs.
|
||||||
void CreateOptionalArray(Model* model, string* input_array_buffer,
|
void CreateOptionalArray(Model* model, std::string* input_array_buffer,
|
||||||
const string& array_name);
|
const std::string& array_name);
|
||||||
|
|
||||||
// Create float array and get its buffer.
|
// Create float array and get its buffer.
|
||||||
Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
|
Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
|
||||||
string* array_name,
|
std::string* array_name,
|
||||||
const Shape& shape);
|
const Shape& shape);
|
||||||
|
|
||||||
// Copy data from one array to the other one (supports 1D and 2D array),
|
// Copy data from one array to the other one (supports 1D and 2D array),
|
||||||
|
@ -91,8 +91,8 @@ void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
|
||||||
|
|
||||||
// Copy a subset of array data and create a smaller array,
|
// Copy a subset of array data and create a smaller array,
|
||||||
// mostly used for spliting weights and bias for Lstm cell.
|
// mostly used for spliting weights and bias for Lstm cell.
|
||||||
void CopySubArrayToArray(Model* model, string* array_name,
|
void CopySubArrayToArray(Model* model, std::string* array_name,
|
||||||
const string& tensor_name, int dim1_size,
|
const std::string& tensor_name, int dim1_size,
|
||||||
int dim2_size, const Array& original_array,
|
int dim2_size, const Array& original_array,
|
||||||
int start_idx1, int start_idx2);
|
int start_idx1, int start_idx2);
|
||||||
|
|
||||||
|
@ -103,8 +103,9 @@ void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
|
||||||
int start_idx1, int start_idx2);
|
int start_idx1, int start_idx2);
|
||||||
|
|
||||||
// Get mating rnn array inputs using rnn_states flag.
|
// Get mating rnn array inputs using rnn_states flag.
|
||||||
bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
|
bool GetMatchingRnnArray(Model* model,
|
||||||
string* rnn_array);
|
const std::string& back_edge_source_array,
|
||||||
|
std::string* rnn_array);
|
||||||
|
|
||||||
} // namespace toco
|
} // namespace toco
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,8 @@ namespace toco {
|
||||||
// generate this output to be removed by graph transformations. Note that there
|
// generate this output to be removed by graph transformations. Note that there
|
||||||
// may be more than one operator that takes the input_array as their input, and
|
// may be more than one operator that takes the input_array as their input, and
|
||||||
// that some of these may be removed by graph transformations.
|
// that some of these may be removed by graph transformations.
|
||||||
bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
|
bool AddDequantizeOperatorToInput(const std::string& input_name,
|
||||||
|
const Operator* op,
|
||||||
GraphTransformation* transformation,
|
GraphTransformation* transformation,
|
||||||
Model* model) {
|
Model* model) {
|
||||||
// An operator with the required output may be a dequantize operator already
|
// An operator with the required output may be a dequantize operator already
|
||||||
|
@ -65,7 +66,7 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
|
||||||
const auto& dequantized_input_name =
|
const auto& dequantized_input_name =
|
||||||
AvailableArrayName(*model, input_name + "_dequantized");
|
AvailableArrayName(*model, input_name + "_dequantized");
|
||||||
for (auto& other_op : model->operators) {
|
for (auto& other_op : model->operators) {
|
||||||
for (string& other_op_input : other_op->inputs) {
|
for (std::string& other_op_input : other_op->inputs) {
|
||||||
if (other_op_input == input_name) {
|
if (other_op_input == input_name) {
|
||||||
other_op_input = dequantized_input_name;
|
other_op_input = dequantized_input_name;
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,8 +117,8 @@ std::vector<int32> ReshapeToTranspose(const Model& model,
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const string intermediate_name = reshape_op->inputs[0];
|
const std::string intermediate_name = reshape_op->inputs[0];
|
||||||
const string output_name = reshape_op->outputs[0];
|
const std::string output_name = reshape_op->outputs[0];
|
||||||
|
|
||||||
// Guarantee the input is only consume by the reshape.
|
// Guarantee the input is only consume by the reshape.
|
||||||
if (CountOpsWithInput(*model, intermediate_name) != 1) {
|
if (CountOpsWithInput(*model, intermediate_name) != 1) {
|
||||||
|
|
|
@ -141,7 +141,7 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// EXTRA CHECKS ON CONNECTING ARRAY
|
// EXTRA CHECKS ON CONNECTING ARRAY
|
||||||
for (const string& output_array : model->flags.output_arrays()) {
|
for (const std::string& output_array : model->flags.output_arrays()) {
|
||||||
if (binary_op->inputs[variable_input_idx] == output_array) {
|
if (binary_op->inputs[variable_input_idx] == output_array) {
|
||||||
AddMessageF(
|
AddMessageF(
|
||||||
"Not moving %s because the output of reshape op %s is an output op.",
|
"Not moving %s because the output of reshape op %s is an output op.",
|
||||||
|
|
|
@ -52,7 +52,7 @@ namespace toco {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter to the list of supported ops.
|
// Filter to the list of supported ops.
|
||||||
string src_op_input;
|
std::string src_op_input;
|
||||||
switch (src_op->type) {
|
switch (src_op->type) {
|
||||||
case OperatorType::kGather:
|
case OperatorType::kGather:
|
||||||
src_op_input = src_op->inputs[0];
|
src_op_input = src_op->inputs[0];
|
||||||
|
|
|
@ -48,7 +48,7 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
||||||
}
|
}
|
||||||
// Record data types of output before processing, so we can see at the
|
// Record data types of output before processing, so we can see at the
|
||||||
// end if we changed anything, and return the correct boolean value.
|
// end if we changed anything, and return the correct boolean value.
|
||||||
std::unordered_map<string, ArrayDataType> old_output_data_types;
|
std::unordered_map<std::string, ArrayDataType> old_output_data_types;
|
||||||
for (const auto& output : op->outputs) {
|
for (const auto& output : op->outputs) {
|
||||||
old_output_data_types[output] = model->GetArray(output).data_type;
|
old_output_data_types[output] = model->GetArray(output).data_type;
|
||||||
}
|
}
|
||||||
|
@ -171,7 +171,7 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
for (int i = 0; i < op->outputs.size(); ++i) {
|
for (int i = 0; i < op->outputs.size(); ++i) {
|
||||||
const string& output = op->outputs[i];
|
const std::string& output = op->outputs[i];
|
||||||
const ArrayDataType data_type = unsupported_op->output_data_types[i];
|
const ArrayDataType data_type = unsupported_op->output_data_types[i];
|
||||||
model->GetArray(output).data_type = data_type;
|
model->GetArray(output).data_type = data_type;
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ bool SupportsMinMax(const Array& array) {
|
||||||
|
|
||||||
// Sets the min/max on the given array, adjusting the reference_minmax for the
|
// Sets the min/max on the given array, adjusting the reference_minmax for the
|
||||||
// final data type of the array if it is already specified.
|
// final data type of the array if it is already specified.
|
||||||
bool PropagateDefaultMinMax::SetArrayMinMax(const string& array_name,
|
bool PropagateDefaultMinMax::SetArrayMinMax(const std::string& array_name,
|
||||||
Array* array) {
|
Array* array) {
|
||||||
CHECK(!array->minmax);
|
CHECK(!array->minmax);
|
||||||
|
|
||||||
|
|
|
@ -268,7 +268,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
|
||||||
const auto& weights_shape = weights_array.shape();
|
const auto& weights_shape = weights_array.shape();
|
||||||
CHECK_EQ(weights_shape.dimensions_count(), 4);
|
CHECK_EQ(weights_shape.dimensions_count(), 4);
|
||||||
|
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const int input_depth = input_shape.dims(3);
|
const int input_depth = input_shape.dims(3);
|
||||||
const int output_depth = weights_shape.dims(3);
|
const int output_depth = weights_shape.dims(3);
|
||||||
// TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
|
// TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
|
||||||
|
@ -302,7 +302,7 @@ void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
|
||||||
const auto& input_shape = input_array.shape();
|
const auto& input_shape = input_array.shape();
|
||||||
CHECK_EQ(input_shape.dimensions_count(), 4);
|
CHECK_EQ(input_shape.dimensions_count(), 4);
|
||||||
|
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const int block_size = op->block_size;
|
const int block_size = op->block_size;
|
||||||
CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
|
CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
|
||||||
const int batch = input_shape.dims(0);
|
const int batch = input_shape.dims(0);
|
||||||
|
@ -325,7 +325,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
|
||||||
const auto& input_shape = input_array.shape();
|
const auto& input_shape = input_array.shape();
|
||||||
CHECK_EQ(input_shape.dimensions_count(), 4);
|
CHECK_EQ(input_shape.dimensions_count(), 4);
|
||||||
|
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const int block_size = op->block_size;
|
const int block_size = op->block_size;
|
||||||
CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
|
CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
|
||||||
const int batch = input_shape.dims(0);
|
const int batch = input_shape.dims(0);
|
||||||
|
@ -470,7 +470,7 @@ void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
auto& output_array = model->GetArray(output_name);
|
auto& output_array = model->GetArray(output_name);
|
||||||
if (output_array.has_shape()) {
|
if (output_array.has_shape()) {
|
||||||
return;
|
return;
|
||||||
|
@ -487,7 +487,7 @@ void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
|
||||||
if (!input0_array.has_shape() || !input1_array.has_shape()) {
|
if (!input0_array.has_shape() || !input1_array.has_shape()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
auto& output_array = model->GetArray(output_name);
|
auto& output_array = model->GetArray(output_name);
|
||||||
ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
|
ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
|
||||||
&output_array);
|
&output_array);
|
||||||
|
@ -639,14 +639,14 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
|
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
|
||||||
const string& input_name = op->inputs[0];
|
const std::string& input_name = op->inputs[0];
|
||||||
const auto& input_array = model->GetArray(input_name);
|
const auto& input_array = model->GetArray(input_name);
|
||||||
// Yield until input dims have been resolved.
|
// Yield until input dims have been resolved.
|
||||||
if (!input_array.has_shape()) {
|
if (!input_array.has_shape()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const auto& input_shape = input_array.shape();
|
const auto& input_shape = input_array.shape();
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
Shape* output_shape = model->GetArray(output_name).mutable_shape();
|
Shape* output_shape = model->GetArray(output_name).mutable_shape();
|
||||||
ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
|
ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
|
||||||
output_shape);
|
output_shape);
|
||||||
|
@ -757,7 +757,7 @@ void ProcessRangeOperator(Model* model, RangeOperator* op) {
|
||||||
|
|
||||||
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
|
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
|
||||||
CHECK_EQ(op->inputs.size(), 2);
|
CHECK_EQ(op->inputs.size(), 2);
|
||||||
const string& input_name = op->inputs[1];
|
const std::string& input_name = op->inputs[1];
|
||||||
const auto& input_array = model->GetArray(input_name);
|
const auto& input_array = model->GetArray(input_name);
|
||||||
// Yield until input dims have been resolved.
|
// Yield until input dims have been resolved.
|
||||||
if (!input_array.has_shape()) {
|
if (!input_array.has_shape()) {
|
||||||
|
@ -892,7 +892,7 @@ void ProcessTensorFlowSplitVOperator(Model* model,
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
|
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
|
||||||
const string& input_name = op->inputs[0];
|
const std::string& input_name = op->inputs[0];
|
||||||
const auto& input_array = model->GetArray(input_name);
|
const auto& input_array = model->GetArray(input_name);
|
||||||
// Yield until input dims have been resolved.
|
// Yield until input dims have been resolved.
|
||||||
if (!input_array.has_shape()) {
|
if (!input_array.has_shape()) {
|
||||||
|
@ -900,7 +900,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
|
||||||
}
|
}
|
||||||
const auto& input_shape = input_array.shape();
|
const auto& input_shape = input_array.shape();
|
||||||
CHECK_EQ(input_shape.dimensions_count(), 4);
|
CHECK_EQ(input_shape.dimensions_count(), 4);
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const int output_depth = input_shape.dims(3);
|
const int output_depth = input_shape.dims(3);
|
||||||
ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
|
ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
|
||||||
op->stride_width, op->stride_height, 1, 1, op->padding.type,
|
op->stride_width, op->stride_height, 1, 1, op->padding.type,
|
||||||
|
@ -909,7 +909,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
|
void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
|
||||||
const string& input_name = op->inputs[0];
|
const std::string& input_name = op->inputs[0];
|
||||||
const auto& input_array = model->GetArray(input_name);
|
const auto& input_array = model->GetArray(input_name);
|
||||||
// Yield until input dims have been resolved.
|
// Yield until input dims have been resolved.
|
||||||
if (!input_array.has_shape()) {
|
if (!input_array.has_shape()) {
|
||||||
|
@ -917,7 +917,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
|
||||||
}
|
}
|
||||||
const auto& input_shape = input_array.shape();
|
const auto& input_shape = input_array.shape();
|
||||||
CHECK_EQ(input_shape.dimensions_count(), 4);
|
CHECK_EQ(input_shape.dimensions_count(), 4);
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const int output_depth = input_shape.dims(3);
|
const int output_depth = input_shape.dims(3);
|
||||||
ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
|
ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
|
||||||
op->stride_width, op->stride_height, 1, 1, op->padding.type,
|
op->stride_width, op->stride_height, 1, 1, op->padding.type,
|
||||||
|
@ -926,7 +926,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
|
void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
|
||||||
const string& input_name = op->inputs[0];
|
const std::string& input_name = op->inputs[0];
|
||||||
const auto& input_array = model->GetArray(input_name);
|
const auto& input_array = model->GetArray(input_name);
|
||||||
// Yield until input dims have been resolved.
|
// Yield until input dims have been resolved.
|
||||||
if (!input_array.has_shape()) {
|
if (!input_array.has_shape()) {
|
||||||
|
@ -936,7 +936,7 @@ void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
|
||||||
if (input_shape.dimensions_count() < 4) {
|
if (input_shape.dimensions_count() < 4) {
|
||||||
LOG(FATAL) << "missing dimensions for " << input_name;
|
LOG(FATAL) << "missing dimensions for " << input_name;
|
||||||
}
|
}
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
const int output_depth = input_shape.dims(3);
|
const int output_depth = input_shape.dims(3);
|
||||||
ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
|
ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
|
||||||
op->stride_width, op->stride_height, 1, 1, op->padding.type,
|
op->stride_width, op->stride_height, 1, 1, op->padding.type,
|
||||||
|
@ -954,7 +954,7 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
|
||||||
}
|
}
|
||||||
const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
|
const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
|
||||||
|
|
||||||
const string& output_size_name = op->inputs[1];
|
const std::string& output_size_name = op->inputs[1];
|
||||||
const auto& output_size_array = model->GetArray(output_size_name);
|
const auto& output_size_array = model->GetArray(output_size_name);
|
||||||
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
|
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
|
||||||
CHECK(output_size_array.has_shape());
|
CHECK(output_size_array.has_shape());
|
||||||
|
@ -982,7 +982,7 @@ void ProcessResizeNearestNeighborOperator(Model* model,
|
||||||
}
|
}
|
||||||
const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
|
const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
|
||||||
|
|
||||||
const string& output_size_name = op->inputs[1];
|
const std::string& output_size_name = op->inputs[1];
|
||||||
const auto& output_size_array = model->GetArray(output_size_name);
|
const auto& output_size_array = model->GetArray(output_size_name);
|
||||||
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
|
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
|
||||||
CHECK(output_size_array.has_shape());
|
CHECK(output_size_array.has_shape());
|
||||||
|
@ -1862,7 +1862,7 @@ void ProcessArgMinMaxOperator(Model* model, Op* op) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
auto& output_array = model->GetArray(output_name);
|
auto& output_array = model->GetArray(output_name);
|
||||||
if (output_array.has_shape()) {
|
if (output_array.has_shape()) {
|
||||||
return;
|
return;
|
||||||
|
@ -1880,7 +1880,7 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
|
||||||
// Output should not go over four dimensions.
|
// Output should not go over four dimensions.
|
||||||
CHECK_LE(output_shape_array.shape().dims(0), 4);
|
CHECK_LE(output_shape_array.shape().dims(0), 4);
|
||||||
|
|
||||||
const string& output_name = op->outputs[0];
|
const std::string& output_name = op->outputs[0];
|
||||||
Array& output_array = model->GetArray(output_name);
|
Array& output_array = model->GetArray(output_name);
|
||||||
if (output_array.has_shape()) return;
|
if (output_array.has_shape()) return;
|
||||||
|
|
||||||
|
@ -2015,7 +2015,7 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
|
||||||
output_dims.push_back(input_dims[i]);
|
output_dims.push_back(input_dims[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const string& output_name : op->outputs) {
|
for (const std::string& output_name : op->outputs) {
|
||||||
auto& output_array = model->GetArray(output_name);
|
auto& output_array = model->GetArray(output_name);
|
||||||
if (output_array.has_shape()) {
|
if (output_array.has_shape()) {
|
||||||
return;
|
return;
|
||||||
|
@ -2149,7 +2149,7 @@ void ProcessScatterNdOperator(Model* model, ScatterNdOperator* op) {
|
||||||
*modified = false;
|
*modified = false;
|
||||||
auto it = model->operators.begin() + op_index;
|
auto it = model->operators.begin() + op_index;
|
||||||
auto* op = it->get();
|
auto* op = it->get();
|
||||||
std::unordered_map<string, std::vector<int>> old_output_dims;
|
std::unordered_map<std::string, std::vector<int>> old_output_dims;
|
||||||
for (const auto& output : op->outputs) {
|
for (const auto& output : op->outputs) {
|
||||||
if (model->GetArray(output).has_shape()) {
|
if (model->GetArray(output).has_shape()) {
|
||||||
old_output_dims[output] = model->GetArray(output).shape().dims();
|
old_output_dims[output] = model->GetArray(output).shape().dims();
|
||||||
|
@ -2400,7 +2400,7 @@ void ProcessScatterNdOperator(Model* model, ScatterNdOperator* op) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
for (int i = 0; i < op->outputs.size(); ++i) {
|
for (int i = 0; i < op->outputs.size(); ++i) {
|
||||||
const string& output = op->outputs[i];
|
const std::string& output = op->outputs[i];
|
||||||
model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
|
model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -164,7 +164,7 @@ std::unique_ptr<GenericBuffer> QuantizeBuffer(
|
||||||
|
|
||||||
template <ArrayDataType A>
|
template <ArrayDataType A>
|
||||||
void QuantizeArray(GraphTransformation* transformation, Model* model,
|
void QuantizeArray(GraphTransformation* transformation, Model* model,
|
||||||
const string& name,
|
const std::string& name,
|
||||||
const QuantizationParams& quantization_params) {
|
const QuantizationParams& quantization_params) {
|
||||||
auto& array = model->GetArray(name);
|
auto& array = model->GetArray(name);
|
||||||
CHECK(array.data_type == ArrayDataType::kFloat);
|
CHECK(array.data_type == ArrayDataType::kFloat);
|
||||||
|
@ -184,7 +184,7 @@ void QuantizeArray(GraphTransformation* transformation, Model* model,
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void QuantizeArray(GraphTransformation* transformation, Model* model,
|
void QuantizeArray(GraphTransformation* transformation, Model* model,
|
||||||
const string& name, ArrayDataType quantized_data_type,
|
const std::string& name, ArrayDataType quantized_data_type,
|
||||||
const QuantizationParams& quantization_params) {
|
const QuantizationParams& quantization_params) {
|
||||||
ArrayDataType adjusted_data_type = quantized_data_type;
|
ArrayDataType adjusted_data_type = quantized_data_type;
|
||||||
auto& array = model->GetArray(name);
|
auto& array = model->GetArray(name);
|
||||||
|
|
|
@ -47,7 +47,7 @@ void ChooseQuantizationParamsForArrayAndQuantizedDataType(
|
||||||
// Quantizes an array by setting its data type and (if constant) quantizing
|
// Quantizes an array by setting its data type and (if constant) quantizing
|
||||||
// all values in the array.
|
// all values in the array.
|
||||||
void QuantizeArray(GraphTransformation* transformation, Model* model,
|
void QuantizeArray(GraphTransformation* transformation, Model* model,
|
||||||
const string& name, ArrayDataType quantized_data_type,
|
const std::string& name, ArrayDataType quantized_data_type,
|
||||||
const QuantizationParams& quantization_params);
|
const QuantizationParams& quantization_params);
|
||||||
|
|
||||||
// Returns true if the given array, when quantized, contains only values between
|
// Returns true if the given array, when quantized, contains only values between
|
||||||
|
|
|
@ -121,7 +121,7 @@ bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) {
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
|
const MinMax& GetOrComputeMinMax(Model* model, const std::string& array_name) {
|
||||||
auto& array = model->GetArray(array_name);
|
auto& array = model->GetArray(array_name);
|
||||||
// Normally we should have a MinMax recorded on this Array,
|
// Normally we should have a MinMax recorded on this Array,
|
||||||
// so we just use it.
|
// so we just use it.
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace {
|
||||||
|
|
||||||
bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,
|
bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,
|
||||||
const FakeQuantOperator& fq_op,
|
const FakeQuantOperator& fq_op,
|
||||||
const string& array_name) {
|
const std::string& array_name) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto& annotated_array = model->GetArray(array_name);
|
auto& annotated_array = model->GetArray(array_name);
|
||||||
if (!annotated_array.minmax) {
|
if (!annotated_array.minmax) {
|
||||||
|
|
|
@ -43,8 +43,8 @@ bool TransformsToIdentity(std::vector<int> const& perm1,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReplaceOpInputsWith(Model* model, const string& lookfor,
|
void ReplaceOpInputsWith(Model* model, const std::string& lookfor,
|
||||||
const string& replacewith) {
|
const std::string& replacewith) {
|
||||||
for (const auto& op : model->operators) {
|
for (const auto& op : model->operators) {
|
||||||
for (int i = 0; i < op->inputs.size(); ++i) {
|
for (int i = 0; i < op->inputs.size(); ++i) {
|
||||||
if (op->inputs[i] == lookfor) {
|
if (op->inputs[i] == lookfor) {
|
||||||
|
|
|
@ -41,9 +41,9 @@ namespace toco {
|
||||||
if (concat_op->type != OperatorType::kConcatenation) {
|
if (concat_op->type != OperatorType::kConcatenation) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
std::vector<string> trivial_inputs;
|
std::vector<std::string> trivial_inputs;
|
||||||
std::vector<string> nontrivial_inputs;
|
std::vector<std::string> nontrivial_inputs;
|
||||||
for (const string& input : concat_op->inputs) {
|
for (const std::string& input : concat_op->inputs) {
|
||||||
const auto& input_array = model->GetArray(input);
|
const auto& input_array = model->GetArray(input);
|
||||||
const bool is_trivial =
|
const bool is_trivial =
|
||||||
input_array.has_shape() && input_array.shape().dimensions_count() == 0;
|
input_array.has_shape() && input_array.shape().dimensions_count() == 0;
|
||||||
|
@ -60,7 +60,7 @@ namespace toco {
|
||||||
|
|
||||||
// Drop trivial inputs.
|
// Drop trivial inputs.
|
||||||
concat_op->inputs = nontrivial_inputs;
|
concat_op->inputs = nontrivial_inputs;
|
||||||
for (const string& input : trivial_inputs) {
|
for (const std::string& input : trivial_inputs) {
|
||||||
DeleteArrayIfUnusedOutsideOfOp(input, concat_op, model);
|
DeleteArrayIfUnusedOutsideOfOp(input, concat_op, model);
|
||||||
}
|
}
|
||||||
*modified = true;
|
*modified = true;
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace {
|
||||||
// array instead. from_array is assumed to be discardable, and consequently
|
// array instead. from_array is assumed to be discardable, and consequently
|
||||||
// this only updates operator edges (since discardable arrays only
|
// this only updates operator edges (since discardable arrays only
|
||||||
// appear there, and not e.g. in model flags).
|
// appear there, and not e.g. in model flags).
|
||||||
void Reroute(const string& from, const string& to, Model* model) {
|
void Reroute(const std::string& from, const std::string& to, Model* model) {
|
||||||
for (const auto& op : model->operators) {
|
for (const auto& op : model->operators) {
|
||||||
for (auto& output : op->outputs) {
|
for (auto& output : op->outputs) {
|
||||||
if (output == from) {
|
if (output == from) {
|
||||||
|
@ -92,8 +92,9 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const string main_input_name = passthru_op->inputs[main_input_array_index];
|
const std::string main_input_name =
|
||||||
const string output_name = passthru_op->outputs[0];
|
passthru_op->inputs[main_input_array_index];
|
||||||
|
const std::string output_name = passthru_op->outputs[0];
|
||||||
|
|
||||||
if (IsDiscardableArray(*model, output_name)) {
|
if (IsDiscardableArray(*model, output_name)) {
|
||||||
transformation->AddMessageF(
|
transformation->AddMessageF(
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace {
|
||||||
|
|
||||||
bool IsTrivialUnfusedActivationFunc(GraphTransformation* transformation,
|
bool IsTrivialUnfusedActivationFunc(GraphTransformation* transformation,
|
||||||
const Model& model, OperatorType op_type,
|
const Model& model, OperatorType op_type,
|
||||||
const string& input_array_name) {
|
const std::string& input_array_name) {
|
||||||
double clamp_min;
|
double clamp_min;
|
||||||
double clamp_max;
|
double clamp_max;
|
||||||
switch (op_type) {
|
switch (op_type) {
|
||||||
|
@ -60,7 +60,7 @@ bool IsTrivialUnfusedActivationFunc(GraphTransformation* transformation,
|
||||||
bool IsTrivialFusedActivationFunc(
|
bool IsTrivialFusedActivationFunc(
|
||||||
GraphTransformation* transformation, const Model& model,
|
GraphTransformation* transformation, const Model& model,
|
||||||
FusedActivationFunctionType activation_function,
|
FusedActivationFunctionType activation_function,
|
||||||
const string& output_array_name) {
|
const std::string& output_array_name) {
|
||||||
double clamp_min;
|
double clamp_min;
|
||||||
double clamp_max;
|
double clamp_max;
|
||||||
switch (activation_function) {
|
switch (activation_function) {
|
||||||
|
|
|
@ -31,8 +31,8 @@ namespace toco {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
|
bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
|
||||||
OperatorType op_type, const string& input_array_name,
|
OperatorType op_type, const std::string& input_array_name,
|
||||||
const string& clamp_value_array_name) {
|
const std::string& clamp_value_array_name) {
|
||||||
const auto& clamp_value_array = model.GetArray(clamp_value_array_name);
|
const auto& clamp_value_array = model.GetArray(clamp_value_array_name);
|
||||||
if (!IsConstantParameterArray(model, clamp_value_array_name)) {
|
if (!IsConstantParameterArray(model, clamp_value_array_name)) {
|
||||||
transformation->AddMessageF("Clip value array %s is non-constant",
|
transformation->AddMessageF("Clip value array %s is non-constant",
|
||||||
|
|
|
@ -58,7 +58,7 @@ namespace toco {
|
||||||
if (found_output_as_rnn_state_array) {
|
if (found_output_as_rnn_state_array) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (const string& output_array : model->flags.output_arrays()) {
|
for (const std::string& output_array : model->flags.output_arrays()) {
|
||||||
if (output == output_array) {
|
if (output == output_array) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,7 @@ bool IsMoveOperator(OperatorType optype) {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const string intermediate_name = element_op->inputs[0];
|
const std::string intermediate_name = element_op->inputs[0];
|
||||||
auto it = FindOpWithOutput(*model, intermediate_name);
|
auto it = FindOpWithOutput(*model, intermediate_name);
|
||||||
if (it == model->operators.end()) {
|
if (it == model->operators.end()) {
|
||||||
AddMessageF("No preceding operator");
|
AddMessageF("No preceding operator");
|
||||||
|
@ -103,8 +103,8 @@ bool IsMoveOperator(OperatorType optype) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// op->inputs may change so we need to keep a value by copy.
|
// op->inputs may change so we need to keep a value by copy.
|
||||||
const string input_name = move_op->inputs[0];
|
const std::string input_name = move_op->inputs[0];
|
||||||
const string output_name = element_op->outputs[0];
|
const std::string output_name = element_op->outputs[0];
|
||||||
|
|
||||||
AddMessageF("Swapping around operators with %s and %s", LogName(*element_op),
|
AddMessageF("Swapping around operators with %s and %s", LogName(*element_op),
|
||||||
LogName(*move_op));
|
LogName(*move_op));
|
||||||
|
|
|
@ -138,9 +138,9 @@ std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need to copy to keep static if permutated.
|
// Need to copy to keep static if permutated.
|
||||||
const string input_name = reshape_op->inputs[0];
|
const std::string input_name = reshape_op->inputs[0];
|
||||||
const string intermediate_name = reshape_op->outputs[0];
|
const std::string intermediate_name = reshape_op->outputs[0];
|
||||||
const string output_name = transpose_op->outputs[0];
|
const std::string output_name = transpose_op->outputs[0];
|
||||||
|
|
||||||
// Intermediate should not be consumed by any other operators.
|
// Intermediate should not be consumed by any other operators.
|
||||||
if (CountOpsWithInput(*model, intermediate_name) != 1) {
|
if (CountOpsWithInput(*model, intermediate_name) != 1) {
|
||||||
|
|
|
@ -62,12 +62,14 @@ namespace toco {
|
||||||
// Create the new Mul, Add operators
|
// Create the new Mul, Add operators
|
||||||
auto* mul_op = new MulOperator;
|
auto* mul_op = new MulOperator;
|
||||||
auto* add_op = new AddOperator;
|
auto* add_op = new AddOperator;
|
||||||
const string mul_name =
|
const std::string mul_name =
|
||||||
AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
|
AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
|
||||||
const string add_name =
|
const std::string add_name =
|
||||||
AvailableArrayName(*model, bn_op->outputs[0] + "_add");
|
AvailableArrayName(*model, bn_op->outputs[0] + "_add");
|
||||||
const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
|
const std::string mul_param_name =
|
||||||
const string add_param_name = AvailableArrayName(*model, add_name + "_param");
|
AvailableArrayName(*model, mul_name + "_param");
|
||||||
|
const std::string add_param_name =
|
||||||
|
AvailableArrayName(*model, add_name + "_param");
|
||||||
mul_op->inputs = {bn_op->inputs[0], mul_param_name};
|
mul_op->inputs = {bn_op->inputs[0], mul_param_name};
|
||||||
mul_op->outputs = {mul_name};
|
mul_op->outputs = {mul_name};
|
||||||
add_op->inputs = {mul_name, add_param_name};
|
add_op->inputs = {mul_name, add_param_name};
|
||||||
|
|
|
@ -147,7 +147,7 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
|
||||||
const auto* concat_op =
|
const auto* concat_op =
|
||||||
static_cast<const ConcatenationOperator*>(concat_base_op);
|
static_cast<const ConcatenationOperator*>(concat_base_op);
|
||||||
|
|
||||||
for (const string& input_name : concat_op->inputs) {
|
for (const std::string& input_name : concat_op->inputs) {
|
||||||
// We only expect constant unquantized arrays as input, otherwise we return.
|
// We only expect constant unquantized arrays as input, otherwise we return.
|
||||||
// We also make sure the shapes of the input arrays are known and they are
|
// We also make sure the shapes of the input arrays are known and they are
|
||||||
// all discardable.
|
// all discardable.
|
||||||
|
@ -166,10 +166,10 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
|
||||||
const int concatenation_axis = concat_op->axis;
|
const int concatenation_axis = concat_op->axis;
|
||||||
|
|
||||||
CHECK_EQ(concat_op->outputs.size(), 1);
|
CHECK_EQ(concat_op->outputs.size(), 1);
|
||||||
string concatenated_array_name = concat_op->outputs[0];
|
std::string concatenated_array_name = concat_op->outputs[0];
|
||||||
Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name);
|
Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name);
|
||||||
std::vector<Array*> input_arrays;
|
std::vector<Array*> input_arrays;
|
||||||
for (const string& input_name : concat_op->inputs) {
|
for (const std::string& input_name : concat_op->inputs) {
|
||||||
input_arrays.push_back(&model->GetArray(input_name));
|
input_arrays.push_back(&model->GetArray(input_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,19 +27,19 @@ namespace toco {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void RenameArray(Model* model, const string& oldname,
|
void RenameArray(Model* model, const std::string& oldname,
|
||||||
const string& desired_newname) {
|
const std::string& desired_newname) {
|
||||||
const string& newname = AvailableArrayName(*model, desired_newname);
|
const std::string& newname = AvailableArrayName(*model, desired_newname);
|
||||||
auto& arrays = model->GetMutableArrayMap();
|
auto& arrays = model->GetMutableArrayMap();
|
||||||
arrays[newname] = std::move(arrays[oldname]);
|
arrays[newname] = std::move(arrays[oldname]);
|
||||||
arrays.erase(oldname);
|
arrays.erase(oldname);
|
||||||
for (const auto& op : model->operators) {
|
for (const auto& op : model->operators) {
|
||||||
for (string& input : op->inputs) {
|
for (std::string& input : op->inputs) {
|
||||||
if (input == oldname) {
|
if (input == oldname) {
|
||||||
input = newname;
|
input = newname;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (string& output : op->outputs) {
|
for (std::string& output : op->outputs) {
|
||||||
if (output == oldname) {
|
if (output == oldname) {
|
||||||
output = newname;
|
output = newname;
|
||||||
}
|
}
|
||||||
|
@ -89,8 +89,8 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
|
||||||
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
|
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
|
||||||
|
|
||||||
// Intentionally copies, not references.
|
// Intentionally copies, not references.
|
||||||
const string input_array_name = reorder_op->inputs[0];
|
const std::string input_array_name = reorder_op->inputs[0];
|
||||||
const string output_array_name = reorder_op->outputs[0];
|
const std::string output_array_name = reorder_op->outputs[0];
|
||||||
|
|
||||||
auto& input_array = model->GetArray(input_array_name);
|
auto& input_array = model->GetArray(input_array_name);
|
||||||
auto& output_array = model->GetArray(output_array_name);
|
auto& output_array = model->GetArray(output_array_name);
|
||||||
|
|
|
@ -44,8 +44,8 @@ namespace toco {
|
||||||
if (tf_concat_op->type == OperatorType::kConcatV2) {
|
if (tf_concat_op->type == OperatorType::kConcatV2) {
|
||||||
axis_pos = tf_concat_op->inputs.size() - 1;
|
axis_pos = tf_concat_op->inputs.size() - 1;
|
||||||
}
|
}
|
||||||
const string axis_name = tf_concat_op->inputs[axis_pos];
|
const std::string axis_name = tf_concat_op->inputs[axis_pos];
|
||||||
std::vector<string> concat_input_names;
|
std::vector<std::string> concat_input_names;
|
||||||
for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) {
|
for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) {
|
||||||
if (i != axis_pos) {
|
if (i != axis_pos) {
|
||||||
concat_input_names.push_back(tf_concat_op->inputs[i]);
|
concat_input_names.push_back(tf_concat_op->inputs[i]);
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace toco {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
||||||
const string& array_name) {
|
const std::string& array_name) {
|
||||||
for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
|
for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
|
||||||
Operator* op = it->get();
|
Operator* op = it->get();
|
||||||
if (op->type != OperatorType::kTranspose) {
|
if (op->type != OperatorType::kTranspose) {
|
||||||
|
@ -74,8 +74,8 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
||||||
DCHECK_EQ(matmul_it->get(), matmul_op);
|
DCHECK_EQ(matmul_it->get(), matmul_op);
|
||||||
};
|
};
|
||||||
|
|
||||||
string input_lhs = matmul_op->inputs[0];
|
std::string input_lhs = matmul_op->inputs[0];
|
||||||
string input_rhs = matmul_op->inputs[1];
|
std::string input_rhs = matmul_op->inputs[1];
|
||||||
|
|
||||||
// Handle `transpose_a` with best effort: If the dimension of lhs is known,
|
// Handle `transpose_a` with best effort: If the dimension of lhs is known,
|
||||||
// insert a `Transpose` op.
|
// insert a `Transpose` op.
|
||||||
|
|
|
@ -37,7 +37,7 @@ namespace toco {
|
||||||
|
|
||||||
CHECK_EQ(switch_op->inputs.size(), 2);
|
CHECK_EQ(switch_op->inputs.size(), 2);
|
||||||
CHECK_EQ(switch_op->outputs.size(), 2);
|
CHECK_EQ(switch_op->outputs.size(), 2);
|
||||||
const string& predicate_name = switch_op->inputs[1];
|
const std::string& predicate_name = switch_op->inputs[1];
|
||||||
// If the predicate array hasn't been resolved to a constant yet,
|
// If the predicate array hasn't been resolved to a constant yet,
|
||||||
// we need to yield.
|
// we need to yield.
|
||||||
if (!IsConstantParameterArray(*model, predicate_name)) {
|
if (!IsConstantParameterArray(*model, predicate_name)) {
|
||||||
|
|
|
@ -37,7 +37,7 @@ namespace toco {
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
const Array& input_array = model->GetArray(fc_op->inputs[0]);
|
const Array& input_array = model->GetArray(fc_op->inputs[0]);
|
||||||
const string& weights_name = fc_op->inputs[1];
|
const std::string& weights_name = fc_op->inputs[1];
|
||||||
Array& weights_array = model->GetArray(weights_name);
|
Array& weights_array = model->GetArray(weights_name);
|
||||||
const Array& output_array = model->GetArray(fc_op->outputs[0]);
|
const Array& output_array = model->GetArray(fc_op->outputs[0]);
|
||||||
// Exit if this FC op isn't quantized with uint8 inputs and int16 outputs,
|
// Exit if this FC op isn't quantized with uint8 inputs and int16 outputs,
|
||||||
|
@ -143,7 +143,7 @@ namespace toco {
|
||||||
// Add a second output array to this FC op, serving as a workspace to perform
|
// Add a second output array to this FC op, serving as a workspace to perform
|
||||||
// runtime shuffling/xoring of its input activations.
|
// runtime shuffling/xoring of its input activations.
|
||||||
CHECK_EQ(fc_op->outputs.size(), 1);
|
CHECK_EQ(fc_op->outputs.size(), 1);
|
||||||
const string& shuffled_input_workspace_array_name =
|
const std::string& shuffled_input_workspace_array_name =
|
||||||
AvailableArrayName(*model, fc_op->inputs[0] + "_shuffled");
|
AvailableArrayName(*model, fc_op->inputs[0] + "_shuffled");
|
||||||
fc_op->outputs.push_back(shuffled_input_workspace_array_name);
|
fc_op->outputs.push_back(shuffled_input_workspace_array_name);
|
||||||
auto& shuffled_input_workspace_array =
|
auto& shuffled_input_workspace_array =
|
||||||
|
|
|
@ -64,7 +64,7 @@ namespace toco {
|
||||||
// Wire up arrays, constructing a new intermediate array to connect the
|
// Wire up arrays, constructing a new intermediate array to connect the
|
||||||
// op to its new unfused activation function.
|
// op to its new unfused activation function.
|
||||||
ac_op->outputs = op->outputs;
|
ac_op->outputs = op->outputs;
|
||||||
const string& tmp_array_name =
|
const std::string& tmp_array_name =
|
||||||
AvailableArrayName(*model, op->outputs[0] + "_unfused");
|
AvailableArrayName(*model, op->outputs[0] + "_unfused");
|
||||||
CHECK(!model->HasArray(tmp_array_name));
|
CHECK(!model->HasArray(tmp_array_name));
|
||||||
|
|
||||||
|
|
|
@ -55,8 +55,8 @@ namespace toco {
|
||||||
auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());
|
auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());
|
||||||
|
|
||||||
// Split up the DynamicStitch inputs into the indices and data.
|
// Split up the DynamicStitch inputs into the indices and data.
|
||||||
std::vector<string> stitch_indices_inputs;
|
std::vector<std::string> stitch_indices_inputs;
|
||||||
std::vector<string> stitch_data_inputs;
|
std::vector<std::string> stitch_data_inputs;
|
||||||
for (size_t i = 0; i < stitch_op->num_partitions; ++i) {
|
for (size_t i = 0; i < stitch_op->num_partitions; ++i) {
|
||||||
stitch_indices_inputs.push_back(stitch_op->inputs[i]);
|
stitch_indices_inputs.push_back(stitch_op->inputs[i]);
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,8 @@ namespace toco {
|
||||||
|
|
||||||
// Validate all indices come from the same DynamicPartition.
|
// Validate all indices come from the same DynamicPartition.
|
||||||
DynamicPartitionOperator* indices_partition_op = nullptr;
|
DynamicPartitionOperator* indices_partition_op = nullptr;
|
||||||
for (const string& indices_partition_output_name : stitch_indices_inputs) {
|
for (const std::string& indices_partition_output_name :
|
||||||
|
stitch_indices_inputs) {
|
||||||
auto* op = GetOpWithOutput(*model, indices_partition_output_name);
|
auto* op = GetOpWithOutput(*model, indices_partition_output_name);
|
||||||
CHECK(op) << "Source of " << indices_partition_output_name << " not found";
|
CHECK(op) << "Source of " << indices_partition_output_name << " not found";
|
||||||
if (op->type != OperatorType::kDynamicPartition) {
|
if (op->type != OperatorType::kDynamicPartition) {
|
||||||
|
@ -112,7 +113,7 @@ namespace toco {
|
||||||
|
|
||||||
// Find all of the gathers used for the data inputs.
|
// Find all of the gathers used for the data inputs.
|
||||||
std::vector<GatherOperator*> gather_ops;
|
std::vector<GatherOperator*> gather_ops;
|
||||||
for (const string& gather_output_name : stitch_data_inputs) {
|
for (const std::string& gather_output_name : stitch_data_inputs) {
|
||||||
auto* op = GetOpWithOutput(*model, gather_output_name);
|
auto* op = GetOpWithOutput(*model, gather_output_name);
|
||||||
CHECK(op) << "Source of " << gather_output_name << " not found";
|
CHECK(op) << "Source of " << gather_output_name << " not found";
|
||||||
if (op->type != OperatorType::kGather) {
|
if (op->type != OperatorType::kGather) {
|
||||||
|
|
|
@ -34,9 +34,10 @@ absl::InlinedVector<int64, 4> ToInlinedVector(const std::vector<int>& vec) {
|
||||||
return absl::InlinedVector<int64, 4>(vec.begin(), vec.end());
|
return absl::InlinedVector<int64, 4>(vec.begin(), vec.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> SliceInput(
|
std::vector<std::string> SliceInput(
|
||||||
const string& input, const string& base_name, const string& input_name,
|
const std::string& input, const std::string& base_name,
|
||||||
const int batch_size, const Array& input_array, Model* model,
|
const std::string& input_name, const int batch_size,
|
||||||
|
const Array& input_array, Model* model,
|
||||||
std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
|
std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
|
||||||
int rank = input_array.shape().dimensions_count();
|
int rank = input_array.shape().dimensions_count();
|
||||||
int num_rows = input_array.shape().dims(rank - 2);
|
int num_rows = input_array.shape().dims(rank - 2);
|
||||||
|
@ -54,7 +55,7 @@ std::vector<string> SliceInput(
|
||||||
*tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
|
*tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
|
||||||
|
|
||||||
// Slice along each batch index and remember the slice output for future use.
|
// Slice along each batch index and remember the slice output for future use.
|
||||||
std::vector<string> slice_outputs;
|
std::vector<std::string> slice_outputs;
|
||||||
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||||
std::string batch_name =
|
std::string batch_name =
|
||||||
absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
|
absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
|
||||||
|
@ -110,10 +111,10 @@ std::vector<int32> GetTransposeShape(const Shape& input_shape,
|
||||||
return output_shape;
|
return output_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
TransposeOperator* TransposeInput(const string& input, Model* model) {
|
TransposeOperator* TransposeInput(const std::string& input, Model* model) {
|
||||||
const auto& input_array = model->GetArray(input);
|
const auto& input_array = model->GetArray(input);
|
||||||
const auto perm_array = GetTransposePerm(input_array);
|
const auto perm_array = GetTransposePerm(input_array);
|
||||||
const string perm_array_name = CreateInt32Array(
|
const std::string perm_array_name = CreateInt32Array(
|
||||||
model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
|
model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
|
||||||
auto* transpose_op = new TransposeOperator;
|
auto* transpose_op = new TransposeOperator;
|
||||||
transpose_op->inputs = {input, perm_array_name};
|
transpose_op->inputs = {input, perm_array_name};
|
||||||
|
@ -141,8 +142,8 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
||||||
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
|
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
|
||||||
auto& tail_it = batch_op_it;
|
auto& tail_it = batch_op_it;
|
||||||
|
|
||||||
string input_lhs = batch_op->inputs[0];
|
std::string input_lhs = batch_op->inputs[0];
|
||||||
string input_rhs = batch_op->inputs[1];
|
std::string input_rhs = batch_op->inputs[1];
|
||||||
const auto& input_lhs_array = model->GetArray(input_lhs);
|
const auto& input_lhs_array = model->GetArray(input_lhs);
|
||||||
const auto& input_rhs_array = model->GetArray(input_rhs);
|
const auto& input_rhs_array = model->GetArray(input_rhs);
|
||||||
if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
|
if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
|
||||||
|
@ -195,19 +196,19 @@ TransposeOperator* TransposeInput(const string& input, Model* model) {
|
||||||
}
|
}
|
||||||
AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
|
AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
|
||||||
bcast.output_batch_size());
|
bcast.output_batch_size());
|
||||||
string base_name = std::string(batch_op->outputs[0]);
|
std::string base_name = std::string(batch_op->outputs[0]);
|
||||||
|
|
||||||
// Compute slices for each batch in the LHS and RHS.
|
// Compute slices for each batch in the LHS and RHS.
|
||||||
std::vector<string> slice_a_outputs =
|
std::vector<std::string> slice_a_outputs =
|
||||||
SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
|
SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
|
||||||
model, &tail_it);
|
model, &tail_it);
|
||||||
std::vector<string> slice_b_outputs =
|
std::vector<std::string> slice_b_outputs =
|
||||||
SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
|
SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
|
||||||
model, &tail_it);
|
model, &tail_it);
|
||||||
|
|
||||||
// Compute (single batch) MatMul for each output batch. The MatMul outputs are
|
// Compute (single batch) MatMul for each output batch. The MatMul outputs are
|
||||||
// then packed together into one output Tensor.
|
// then packed together into one output Tensor.
|
||||||
std::vector<string> pack_inputs;
|
std::vector<std::string> pack_inputs;
|
||||||
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
|
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
|
||||||
std::string batch_name =
|
std::string batch_name =
|
||||||
absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
|
absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
|
||||||
|
|
Loading…
Reference in New Issue