diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index c00a6d4a2a4..90ba5dfb4a3 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -357,10 +357,6 @@ def generated_test_models_failing(conversion_mode): # TODO(b/135758082): L2Norm is broken in future forward # compatibility horizon "l2norm", - # TODO(b/135756979): Eye/MatrixDiag is broken in future - # forward compatibility horizon - "matrix_set_diag", - "eye", ] return [] diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 329430643f9..80382864c71 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -179,6 +179,7 @@ cc_library( srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", "graph_transformations/convert_matrix_diag_v2_to_v1.cc", + "graph_transformations/convert_matrix_set_diag_v2_to_v1.cc", "graph_transformations/convert_pure_conv_to_depthwise.cc", "graph_transformations/convert_reorder_axes.cc", "graph_transformations/convert_squeeze_to_reshape.cc", diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc new file mode 100644 index 00000000000..61288f626b6 --- /dev/null +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" + +namespace toco { + +::tensorflow::Status ConvertMatrixSetDiagV2ToV1::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; + auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + if (op->type != OperatorType::kMatrixSetDiagV2) { + return ::tensorflow::Status::OK(); + } + + if (op->inputs.size() != 3) { + return tensorflow::errors::InvalidArgument( + "The input size of op %s should be 3", LogName(*op)); + } + + const auto& input_k = model->GetArray(op->inputs[2]); + + if (!input_k.buffer) { + return ::tensorflow::Status::OK(); + } + + if (input_k.GetBuffer().data.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Array for argument k of op %s should contains exact one element", + LogName(*op)); + } + + int k = input_k.GetBuffer().data[0]; + + if (k != 0) { + return tensorflow::errors::InvalidArgument( + "parameter k of op ", LogName(*op), + " is expected to be 0, other values are not supported currently"); + } + + auto* matrix_set_diag_op = new MatrixSetDiagOperator; + matrix_set_diag_op->inputs.push_back(op->inputs[0]); + matrix_set_diag_op->inputs.push_back(op->inputs[1]); + matrix_set_diag_op->outputs.push_back(op->outputs[0]); + + AddMessageF("Replacing %s with %s", LogName(*op), + LogName(*matrix_set_diag_op)); + + // Replace the operator in the graph. + model->operators.emplace(it, matrix_set_diag_op); + DeleteOpAndArrays(model, op); + + *modified = true; + return ::tensorflow::Status::OK(); +} + +} // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index 83d0f2fb907..c53e07031f2 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -123,6 +123,7 @@ inline void RunGraphTransformations( // List of all graph transformations DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2ToV1) DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index e7e70ad8c83..7f953a34e9c 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -2426,6 +2426,10 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) { // The sizes of the outputs are only known in runtime based on the input. // Ignore shape progapation here and defer that to the interpreter. break; + case OperatorType::kMatrixSetDiagV2: + // MatrixSetDiagV2 operators are converted to MatrixSetDiag, + // after which their shapes are propagated. + break; case OperatorType::kMatrixDiagV2: // MatrixDiagV2 operators are converted to MatrixDiag, after which their // shapes are propagated. diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index 14943254fee..859fa0f6147 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -2518,6 +2518,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"MatrixDiag", ConvertSimpleOperator}, {"MatrixDiagV2", ConvertSimpleOperator}, {"MatrixSetDiag", ConvertSimpleOperator}, + {"MatrixSetDiagV2", ConvertSimpleOperator}, {"Max", ConvertReduceOperator}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator}, diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 496740ef56f..7a95e5db582 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -174,6 +174,7 @@ enum class OperatorType : uint8 { kMatrixDiag, kMatrixSetDiag, kMatrixDiagV2, + kMatrixSetDiagV2 }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -2128,6 +2129,14 @@ struct MatrixSetDiagOperator : Operator { MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {} }; +// Matrix Set Diag Operator V2: +// Construct a batched diagonal tensor with given input and diagonal values. +// Not fully supported, constains 1 extra inputs compared to MatrixSetDiag, +// support default parameters settings which performs the same as MatrixSetDiag +struct MatrixSetDiagV2Operator : Operator { + MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index b7c03452f01..96f9d7602f1 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -55,6 +55,7 @@ void MakeGeneralGraphTransformationsSet( CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); transformations->Add(new ConvertMatrixDiagV2ToV1); + transformations->Add(new ConvertMatrixSetDiagV2ToV1); transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialPackToReshape); diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index f06621b73d3..3978cf5ee1a 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -448,6 +448,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(MatrixDiag) HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag) HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2) + HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE