support MatrixSetDiagV2 op
PiperOrigin-RevId: 256211310
This commit is contained in:
parent
fa33109764
commit
fc59dc2cff
@ -357,10 +357,6 @@ def generated_test_models_failing(conversion_mode):
|
||||
# TODO(b/135758082): L2Norm is broken in future forward
|
||||
# compatibility horizon
|
||||
"l2norm",
|
||||
# TODO(b/135756979): Eye/MatrixDiag is broken in future
|
||||
# forward compatibility horizon
|
||||
"matrix_set_diag",
|
||||
"eye",
|
||||
]
|
||||
return []
|
||||
|
||||
|
@ -179,6 +179,7 @@ cc_library(
|
||||
srcs = [
|
||||
"graph_transformations/convert_expanddims_to_reshape.cc",
|
||||
"graph_transformations/convert_matrix_diag_v2_to_v1.cc",
|
||||
"graph_transformations/convert_matrix_set_diag_v2_to_v1.cc",
|
||||
"graph_transformations/convert_pure_conv_to_depthwise.cc",
|
||||
"graph_transformations/convert_reorder_axes.cc",
|
||||
"graph_transformations/convert_squeeze_to_reshape.cc",
|
||||
|
@ -0,0 +1,80 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/tooling_util.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
::tensorflow::Status ConvertMatrixSetDiagV2ToV1::Run(Model* model,
|
||||
std::size_t op_index,
|
||||
bool* modified) {
|
||||
*modified = false;
|
||||
auto it = model->operators.begin() + op_index;
|
||||
const auto* op = it->get();
|
||||
if (op->type != OperatorType::kMatrixSetDiagV2) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
if (op->inputs.size() != 3) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"The input size of op %s should be 3", LogName(*op));
|
||||
}
|
||||
|
||||
const auto& input_k = model->GetArray(op->inputs[2]);
|
||||
|
||||
if (!input_k.buffer) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
if (input_k.GetBuffer<ArrayDataType::kInt32>().data.size() != 1) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Array for argument k of op %s should contains exact one element",
|
||||
LogName(*op));
|
||||
}
|
||||
|
||||
int k = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
|
||||
|
||||
if (k != 0) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"parameter k of op ", LogName(*op),
|
||||
" is expected to be 0, other values are not supported currently");
|
||||
}
|
||||
|
||||
auto* matrix_set_diag_op = new MatrixSetDiagOperator;
|
||||
matrix_set_diag_op->inputs.push_back(op->inputs[0]);
|
||||
matrix_set_diag_op->inputs.push_back(op->inputs[1]);
|
||||
matrix_set_diag_op->outputs.push_back(op->outputs[0]);
|
||||
|
||||
AddMessageF("Replacing %s with %s", LogName(*op),
|
||||
LogName(*matrix_set_diag_op));
|
||||
|
||||
// Replace the operator in the graph.
|
||||
model->operators.emplace(it, matrix_set_diag_op);
|
||||
DeleteOpAndArrays(model, op);
|
||||
|
||||
*modified = true;
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace toco
|
@ -123,6 +123,7 @@ inline void RunGraphTransformations(
|
||||
|
||||
// List of all graph transformations
|
||||
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2ToV1)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
|
||||
|
@ -2426,6 +2426,10 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
|
||||
// The sizes of the outputs are only known in runtime based on the input.
|
||||
// Ignore shape progapation here and defer that to the interpreter.
|
||||
break;
|
||||
case OperatorType::kMatrixSetDiagV2:
|
||||
// MatrixSetDiagV2 operators are converted to MatrixSetDiag,
|
||||
// after which their shapes are propagated.
|
||||
break;
|
||||
case OperatorType::kMatrixDiagV2:
|
||||
// MatrixDiagV2 operators are converted to MatrixDiag, after which their
|
||||
// shapes are propagated.
|
||||
|
@ -2518,6 +2518,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
|
||||
{"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
|
||||
{"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
|
||||
{"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
|
||||
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
||||
{"MaxPool", ConvertMaxPoolOperator},
|
||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
||||
|
@ -174,6 +174,7 @@ enum class OperatorType : uint8 {
|
||||
kMatrixDiag,
|
||||
kMatrixSetDiag,
|
||||
kMatrixDiagV2,
|
||||
kMatrixSetDiagV2
|
||||
};
|
||||
|
||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||
@ -2128,6 +2129,14 @@ struct MatrixSetDiagOperator : Operator {
|
||||
MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
|
||||
};
|
||||
|
||||
// Matrix Set Diag Operator V2:
|
||||
// Construct a batched diagonal tensor with given input and diagonal values.
|
||||
// Not fully supported, constains 1 extra inputs compared to MatrixSetDiag,
|
||||
// support default parameters settings which performs the same as MatrixSetDiag
|
||||
struct MatrixSetDiagV2Operator : Operator {
|
||||
MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {}
|
||||
};
|
||||
|
||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||
|
@ -55,6 +55,7 @@ void MakeGeneralGraphTransformationsSet(
|
||||
CHECK(transformations->empty());
|
||||
transformations->Add(new ConvertExpandDimsToReshape);
|
||||
transformations->Add(new ConvertMatrixDiagV2ToV1);
|
||||
transformations->Add(new ConvertMatrixSetDiagV2ToV1);
|
||||
transformations->Add(new ConvertSqueezeToReshape);
|
||||
transformations->Add(new ConvertTrivialAddNToAdd);
|
||||
transformations->Add(new ConvertTrivialPackToReshape);
|
||||
|
@ -448,6 +448,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
|
||||
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
|
||||
HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
|
||||
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled op type";
|
||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||
|
Loading…
Reference in New Issue
Block a user