support MatrixSetDiagV2 op
PiperOrigin-RevId: 256211310
This commit is contained in:
parent
fa33109764
commit
fc59dc2cff
@ -357,10 +357,6 @@ def generated_test_models_failing(conversion_mode):
|
|||||||
# TODO(b/135758082): L2Norm is broken in future forward
|
# TODO(b/135758082): L2Norm is broken in future forward
|
||||||
# compatibility horizon
|
# compatibility horizon
|
||||||
"l2norm",
|
"l2norm",
|
||||||
# TODO(b/135756979): Eye/MatrixDiag is broken in future
|
|
||||||
# forward compatibility horizon
|
|
||||||
"matrix_set_diag",
|
|
||||||
"eye",
|
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -179,6 +179,7 @@ cc_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"graph_transformations/convert_expanddims_to_reshape.cc",
|
"graph_transformations/convert_expanddims_to_reshape.cc",
|
||||||
"graph_transformations/convert_matrix_diag_v2_to_v1.cc",
|
"graph_transformations/convert_matrix_diag_v2_to_v1.cc",
|
||||||
|
"graph_transformations/convert_matrix_set_diag_v2_to_v1.cc",
|
||||||
"graph_transformations/convert_pure_conv_to_depthwise.cc",
|
"graph_transformations/convert_pure_conv_to_depthwise.cc",
|
||||||
"graph_transformations/convert_reorder_axes.cc",
|
"graph_transformations/convert_reorder_axes.cc",
|
||||||
"graph_transformations/convert_squeeze_to_reshape.cc",
|
"graph_transformations/convert_squeeze_to_reshape.cc",
|
||||||
|
@ -0,0 +1,80 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
||||||
|
#include "tensorflow/lite/toco/model.h"
|
||||||
|
#include "tensorflow/lite/toco/tooling_util.h"
|
||||||
|
|
||||||
|
namespace toco {
|
||||||
|
|
||||||
|
::tensorflow::Status ConvertMatrixSetDiagV2ToV1::Run(Model* model,
|
||||||
|
std::size_t op_index,
|
||||||
|
bool* modified) {
|
||||||
|
*modified = false;
|
||||||
|
auto it = model->operators.begin() + op_index;
|
||||||
|
const auto* op = it->get();
|
||||||
|
if (op->type != OperatorType::kMatrixSetDiagV2) {
|
||||||
|
return ::tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op->inputs.size() != 3) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"The input size of op %s should be 3", LogName(*op));
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& input_k = model->GetArray(op->inputs[2]);
|
||||||
|
|
||||||
|
if (!input_k.buffer) {
|
||||||
|
return ::tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_k.GetBuffer<ArrayDataType::kInt32>().data.size() != 1) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"Array for argument k of op %s should contains exact one element",
|
||||||
|
LogName(*op));
|
||||||
|
}
|
||||||
|
|
||||||
|
int k = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
|
||||||
|
|
||||||
|
if (k != 0) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"parameter k of op ", LogName(*op),
|
||||||
|
" is expected to be 0, other values are not supported currently");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* matrix_set_diag_op = new MatrixSetDiagOperator;
|
||||||
|
matrix_set_diag_op->inputs.push_back(op->inputs[0]);
|
||||||
|
matrix_set_diag_op->inputs.push_back(op->inputs[1]);
|
||||||
|
matrix_set_diag_op->outputs.push_back(op->outputs[0]);
|
||||||
|
|
||||||
|
AddMessageF("Replacing %s with %s", LogName(*op),
|
||||||
|
LogName(*matrix_set_diag_op));
|
||||||
|
|
||||||
|
// Replace the operator in the graph.
|
||||||
|
model->operators.emplace(it, matrix_set_diag_op);
|
||||||
|
DeleteOpAndArrays(model, op);
|
||||||
|
|
||||||
|
*modified = true;
|
||||||
|
return ::tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace toco
|
@ -123,6 +123,7 @@ inline void RunGraphTransformations(
|
|||||||
|
|
||||||
// List of all graph transformations
|
// List of all graph transformations
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
|
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
|
||||||
|
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2ToV1)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1)
|
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
|
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
|
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
|
||||||
|
@ -2426,6 +2426,10 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
|
|||||||
// The sizes of the outputs are only known in runtime based on the input.
|
// The sizes of the outputs are only known in runtime based on the input.
|
||||||
// Ignore shape progapation here and defer that to the interpreter.
|
// Ignore shape progapation here and defer that to the interpreter.
|
||||||
break;
|
break;
|
||||||
|
case OperatorType::kMatrixSetDiagV2:
|
||||||
|
// MatrixSetDiagV2 operators are converted to MatrixSetDiag,
|
||||||
|
// after which their shapes are propagated.
|
||||||
|
break;
|
||||||
case OperatorType::kMatrixDiagV2:
|
case OperatorType::kMatrixDiagV2:
|
||||||
// MatrixDiagV2 operators are converted to MatrixDiag, after which their
|
// MatrixDiagV2 operators are converted to MatrixDiag, after which their
|
||||||
// shapes are propagated.
|
// shapes are propagated.
|
||||||
|
@ -2518,6 +2518,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
|||||||
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
|
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
|
||||||
{"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
|
{"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
|
||||||
{"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
|
{"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
|
||||||
|
{"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
|
||||||
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
||||||
{"MaxPool", ConvertMaxPoolOperator},
|
{"MaxPool", ConvertMaxPoolOperator},
|
||||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
||||||
|
@ -174,6 +174,7 @@ enum class OperatorType : uint8 {
|
|||||||
kMatrixDiag,
|
kMatrixDiag,
|
||||||
kMatrixSetDiag,
|
kMatrixSetDiag,
|
||||||
kMatrixDiagV2,
|
kMatrixDiagV2,
|
||||||
|
kMatrixSetDiagV2
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||||
@ -2128,6 +2129,14 @@ struct MatrixSetDiagOperator : Operator {
|
|||||||
MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
|
MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Matrix Set Diag Operator V2:
|
||||||
|
// Construct a batched diagonal tensor with given input and diagonal values.
|
||||||
|
// Not fully supported, constains 1 extra inputs compared to MatrixSetDiag,
|
||||||
|
// support default parameters settings which performs the same as MatrixSetDiag
|
||||||
|
struct MatrixSetDiagV2Operator : Operator {
|
||||||
|
MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {}
|
||||||
|
};
|
||||||
|
|
||||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||||
|
@ -55,6 +55,7 @@ void MakeGeneralGraphTransformationsSet(
|
|||||||
CHECK(transformations->empty());
|
CHECK(transformations->empty());
|
||||||
transformations->Add(new ConvertExpandDimsToReshape);
|
transformations->Add(new ConvertExpandDimsToReshape);
|
||||||
transformations->Add(new ConvertMatrixDiagV2ToV1);
|
transformations->Add(new ConvertMatrixDiagV2ToV1);
|
||||||
|
transformations->Add(new ConvertMatrixSetDiagV2ToV1);
|
||||||
transformations->Add(new ConvertSqueezeToReshape);
|
transformations->Add(new ConvertSqueezeToReshape);
|
||||||
transformations->Add(new ConvertTrivialAddNToAdd);
|
transformations->Add(new ConvertTrivialAddNToAdd);
|
||||||
transformations->Add(new ConvertTrivialPackToReshape);
|
transformations->Add(new ConvertTrivialPackToReshape);
|
||||||
|
@ -448,6 +448,7 @@ const char* OperatorTypeName(OperatorType type) {
|
|||||||
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
|
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
|
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
|
HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
|
||||||
|
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unhandled op type";
|
LOG(FATAL) << "Unhandled op type";
|
||||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||||
|
Loading…
Reference in New Issue
Block a user