From d79bb04e2a3d4a6086135ae234f1bb61201c7df6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2019 14:29:41 -0700 Subject: [PATCH] Implement Matrix Diag PiperOrigin-RevId: 239061891 --- tensorflow/lite/build_def.bzl | 1 + tensorflow/lite/builtin_ops.h | 1 + tensorflow/lite/c/builtin_op_data.h | 4 + .../lite/core/api/flatbuffer_conversions.cc | 4 +- tensorflow/lite/kernels/BUILD | 13 ++ tensorflow/lite/kernels/matrix_diag.cc | 136 ++++++++++++++++++ tensorflow/lite/kernels/matrix_diag_test.cc | 110 ++++++++++++++ tensorflow/lite/kernels/register.cc | 2 + tensorflow/lite/kernels/test_util.h | 50 +++++++ tensorflow/lite/nnapi_delegate.cc | 1 + tensorflow/lite/schema/schema.fbs | 6 + tensorflow/lite/schema/schema_generated.h | 124 +++++++++++++++- tensorflow/lite/testing/generate_examples.py | 27 ++++ .../propagate_array_data_types.cc | 7 + .../propagate_fixed_sizes.cc | 21 +++ tensorflow/lite/toco/import_tensorflow.cc | 1 + tensorflow/lite/toco/model.h | 11 +- tensorflow/lite/toco/tflite/operator.cc | 3 +- tensorflow/lite/toco/tflite/operator_test.cc | 7 + tensorflow/lite/toco/tooling_util.cc | 1 + 20 files changed, 520 insertions(+), 10 deletions(-) create mode 100644 tensorflow/lite/kernels/matrix_diag.cc create mode 100644 tensorflow/lite/kernels/matrix_diag_test.cc diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 2b30309d83a..c5ace905d4f 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -274,6 +274,7 @@ def generated_test_models(): "logical_or", "logical_xor", "lstm", + "matrix_diag", "max_pool", "maximum", "mean", diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 7b4efdf4a36..8142e7353f5 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -138,6 +138,7 @@ typedef enum { kTfLiteBuiltinRank = 110, kTfLiteBuiltinElu = 111, kTfLiteBuiltinReverseSequence = 112, + kTfLiteBuiltinMatrixDiag = 113, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index 5d1c92d36f5..cdc44be3c7e 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -373,6 +373,10 @@ typedef struct { int batch_dim; } TfLiteReverseSequenceParams; +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixDiagParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 2ba64f51d9a..db6bd8b9532 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -683,8 +683,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - - // Below are the ops with no builtin_data strcture. + // Below are the ops with no builtin_data structure. 
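+      // MATRIX_DIAG (added below) also falls in this group: its
+      // MatrixDiagOptions table carries no fields, so no builtin_data is
+      // attached for it.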
case BuiltinOperator_ABS: case BuiltinOperator_BATCH_TO_SPACE_ND: // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are @@ -708,6 +707,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LOG: case BuiltinOperator_LOGISTIC: case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MATRIX_DIAG: case BuiltinOperator_MAXIMUM: case BuiltinOperator_MINIMUM: case BuiltinOperator_NEG: diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index dece7113951..f8faa460c60 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -188,6 +188,7 @@ cc_library( "logical.cc", "lsh_projection.cc", "lstm.cc", + "matrix_diag.cc", "maximum_minimum.cc", "mfcc.cc", "mirror_pad.cc", @@ -1388,3 +1389,15 @@ cc_test( "@com_google_googletest//:gtest", ], ) + +cc_test( + name = "matrix_diag_test", + size = "small", + srcs = ["matrix_diag_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/lite/kernels/matrix_diag.cc b/tensorflow/lite/kernels/matrix_diag.cc new file mode 100644 index 00000000000..d30187103fd --- /dev/null +++ b/tensorflow/lite/kernels/matrix_diag.cc @@ -0,0 +1,136 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace matrix_diag { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteIntArray* input_dims = input->dims; + int input_dims_size = input_dims->size; + TF_LITE_ENSURE(context, input_dims_size >= 1); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Resize the output tensor. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1); + for (int i = 0; i < input_dims_size; i++) { + output_shape->data[i] = input_dims->data[i]; + } + // Last dimension in the output is the same as the last dimension in the + // input. 
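+  // For example, a [2, 3] input is resized to a [2, 3, 3] output: a batch of
+  // two 3x3 diagonal matrices.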
+  output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
+  output->type = input->type;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, output, output_shape));
+
+  return kTfLiteOk;
+}
+
+// Fill the tensor to make a diagonal matrix in each batch, i.e., when
+// row index and column index are the same, fill with the next input value.
+// All other entries get zero.
+// TODO(b/128636574) Move to reference_ops.
+template <typename T>
+void FillDiagImpl(const T* in, T* out, const int batch_size,
+                  const int row_size, const int col_size) {
+  int idx = 0;
+  for (int b = 0; b < batch_size; b++) {
+    for (int i = 0; i < row_size; i++) {
+      for (int j = 0; j < col_size; ++j) {
+        // input values go on the diagonal, 0 elsewhere
+        if (i == j) {
+          out[i * col_size + j] = in[idx];
+          idx++;
+        } else {
+          out[i * col_size + j] = 0;
+        }
+      }
+    }
+    out += row_size * col_size;
+  }
+}
+
+template <typename T>
+void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
+              const int batch_size, const int row_size, const int col_size) {
+  FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output),
+                  batch_size, row_size, col_size);
+}
+
+// Fill a tensor with given input on the diagonal, zero elsewhere
+void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
+  const int num_output_dims = output->dims->size;
+  int batch_size = 1;
+  for (int i = 0; i < num_output_dims - 2; ++i) {
+    batch_size *= output->dims->data[i];
+  }
+
+  const int row_size = output->dims->data[num_output_dims - 2];
+  const int col_size = output->dims->data[num_output_dims - 1];
+  switch (output->type) {
+    case kTfLiteInt64: {
+      return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
+    }
+    case kTfLiteInt32: {
+      return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
+    }
+    case kTfLiteInt16: {
+      return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
+    }
+    case kTfLiteInt8: {
+      return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
+    }
+    case kTfLiteUInt8: {
+      return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
+    }
+    default:
+      return FillDiag<float>(input, output, batch_size, row_size, col_size);
+  }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  FillDiagHelper(input, output);
+  return kTfLiteOk;
+}
+
+}  // namespace matrix_diag
+
+TfLiteRegistration* Register_MATRIX_DIAG() {
+  static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
+                                 matrix_diag::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/matrix_diag_test.cc b/tensorflow/lite/kernels/matrix_diag_test.cc
new file mode 100644
index 00000000000..757209a4b3b
--- /dev/null
+++ b/tensorflow/lite/kernels/matrix_diag_test.cc
@@ -0,0 +1,110 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class MatrixDiagOpModel : public SingleOpModel {
+ public:
+  explicit MatrixDiagOpModel(const TensorData& input) {
+    input_ = AddInput(input);
+    output_ = AddOutput({input.type, {}});
+
+    SetBuiltinOp(BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions,
+                 CreateMatrixDiagOptions(builder_).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  int input() { return input_; }
+
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+  TfLiteType GetOutputType() {
+    TfLiteTensor* t = interpreter_->tensor(output_);
+    return t->type;
+  }
+
+ private:
+  int input_;
+  int output_;
+};
+
+// Use the machinery of TYPED_TEST_SUITE to test all supported types.
+// See
+// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
+// for details.
+template <typename T>
+class MatrixDiagOpTest : public ::testing::Test {};
+
+using TypesUnderTest =
+    ::testing::Types<TypeUnion<int32_t>, TypeUnion<float>, TypeUnion<int16_t>,
+                     TypeUnion<int8_t>, TypeUnion<uint8_t>>;
+TYPED_TEST_SUITE(MatrixDiagOpTest, TypesUnderTest);
+
+TYPED_TEST(MatrixDiagOpTest, ThreeByThreeDiag) {
+  MatrixDiagOpModel<typename TypeParam::ScalarType> model(
+      {TypeParam::tensor_type, {3}});
+  model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
+                                                                {1, 2, 3});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0,  //
+                                                   0, 2, 0,  //
+                                                   0, 0, 3}));
+  EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
+}
+
+// Additional special cases.
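+// For a rank-2 input, each row is diagonalized independently: a [2, 4] input
+// produces a [2, 4, 4] output containing two 4x4 diagonal matrices, as
+// checked below.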
+TEST(MatrixDiagTest, Int32TestTwoDimDiag) { + MatrixDiagOpModel model({TensorType_INT32, {2, 4}}); + model.PopulateTensor(model.input(), {1, 2, 3, 4, 5, 6, 7, 8}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, // + 0, 2, 0, 0, // + 0, 0, 3, 0, // + 0, 0, 0, 4, // + 5, 0, 0, 0, // + 0, 6, 0, 0, // + 0, 0, 7, 0, // + 0, 0, 0, 8})); + EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32); +} + +TEST(MatrixDiagTest, DegenenerateCase) { + MatrixDiagOpModel model({TensorType_UINT8, {1}}); + model.PopulateTensor(model.input(), {1}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteUInt8); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index eada2111255..dc02d85f5f9 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -139,6 +139,7 @@ TfLiteRegistration* Register_GATHER_ND(); TfLiteRegistration* Register_WHERE(); TfLiteRegistration* Register_ELU(); TfLiteRegistration* Register_REVERSE_SEQUENCE(); +TfLiteRegistration* Register_MATRIX_DIAG(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -374,6 +375,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_WHERE, Register_WHERE()); AddBuiltin(BuiltinOperator_ELU, Register_ELU()); AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE()); + AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 4ea54837e8a..cf5a111092c 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -522,6 +522,56 @@ TensorType GetTensorType() { // Strings have a special implementation that is in test_util.cc template <> std::vector SingleOpModel::ExtractVector(int index); + +// The TypeUnion struct specializations hold a collection of related types. +// Each struct holds: 1. a primitive type (e.g. float), 2. a TensorType (e.g. +// TensorType_FLOAT32, and 3. a TfLiteType (e.g. kTfLiteFloat32). 
The latter +// two are actually enum values and not raw types, but these specializations +// make it easy to use gUnit Typed Test Suite: +// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests +template +struct TypeUnion; + +template <> +struct TypeUnion { + public: + static const TensorType tensor_type = TensorType::TensorType_FLOAT32; + static const TfLiteType tflite_type = TfLiteType::kTfLiteFloat32; + typedef float_t ScalarType; +}; + +template <> +struct TypeUnion { + public: + static const TensorType tensor_type = TensorType::TensorType_INT32; + static const TfLiteType tflite_type = TfLiteType::kTfLiteInt32; + typedef int32_t ScalarType; +}; + +template <> +struct TypeUnion { + public: + static const TensorType tensor_type = TensorType::TensorType_INT16; + static const TfLiteType tflite_type = TfLiteType::kTfLiteInt16; + typedef int16_t ScalarType; +}; + +template <> +struct TypeUnion { + public: + static const TensorType tensor_type = TensorType::TensorType_INT8; + static const TfLiteType tflite_type = TfLiteType::kTfLiteInt8; + typedef int8_t ScalarType; +}; + +template <> +struct TypeUnion { + public: + static const TensorType tensor_type = TensorType::TensorType_UINT8; + static const TfLiteType tflite_type = TfLiteType::kTfLiteUInt8; + typedef uint8_t ScalarType; +}; + } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc index 443651b9910..da11c0f6ed2 100644 --- a/tensorflow/lite/nnapi_delegate.cc +++ b/tensorflow/lite/nnapi_delegate.cc @@ -671,6 +671,7 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_RANK: case tflite::BuiltinOperator_ELU: case tflite::BuiltinOperator_REVERSE_SEQUENCE: + case tflite::BuiltinOperator_MATRIX_DIAG: logError("Op code %d is currently not delegated to NNAPI", builtin); return kTfLiteError; break; diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index c6c61a602a8..353fb494f00 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -226,6 +226,7 @@ enum BuiltinOperator : byte { RANK = 110, ELU = 111, REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, } // Options for the builtin operators. @@ -317,6 +318,7 @@ union BuiltinOptions { WhereOptions, RankOptions, ReverseSequenceOptions, + MatrixDiagOptions, } enum Padding : byte { SAME, VALID } @@ -756,6 +758,10 @@ table ReverseSequenceOptions { seq_dim:int; batch_dim:int = 0; } + +table MatrixDiagOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. 
table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 2a55698a616..161c8b8e650 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -292,6 +292,9 @@ struct WhereOptionsT; struct ReverseSequenceOptions; struct ReverseSequenceOptionsT; +struct MatrixDiagOptions; +struct MatrixDiagOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -554,11 +557,12 @@ enum BuiltinOperator { BuiltinOperator_RANK = 110, BuiltinOperator_ELU = 111, BuiltinOperator_REVERSE_SEQUENCE = 112, + BuiltinOperator_MATRIX_DIAG = 113, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_REVERSE_SEQUENCE + BuiltinOperator_MAX = BuiltinOperator_MATRIX_DIAG }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -671,7 +675,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] { BuiltinOperator_WHERE, BuiltinOperator_RANK, BuiltinOperator_ELU, - BuiltinOperator_REVERSE_SEQUENCE + BuiltinOperator_REVERSE_SEQUENCE, + BuiltinOperator_MATRIX_DIAG }; return values; } @@ -791,6 +796,7 @@ inline const char * const *EnumNamesBuiltinOperator() { "RANK", "ELU", "REVERSE_SEQUENCE", + "MATRIX_DIAG", nullptr }; return names; @@ -890,11 +896,12 @@ enum BuiltinOptions { BuiltinOptions_WhereOptions = 85, BuiltinOptions_RankOptions = 86, BuiltinOptions_ReverseSequenceOptions = 87, + BuiltinOptions_MatrixDiagOptions = 88, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_ReverseSequenceOptions + BuiltinOptions_MAX = BuiltinOptions_MatrixDiagOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -983,7 +990,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] { BuiltinOptions_CosOptions, BuiltinOptions_WhereOptions, BuiltinOptions_RankOptions, - BuiltinOptions_ReverseSequenceOptions + BuiltinOptions_ReverseSequenceOptions, + BuiltinOptions_MatrixDiagOptions }; return values; } @@ -1078,6 +1086,7 @@ inline const char * const *EnumNamesBuiltinOptions() { "WhereOptions", "RankOptions", "ReverseSequenceOptions", + "MatrixDiagOptions", nullptr }; return names; @@ -1440,6 +1449,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2167,6 +2180,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_ReverseSequenceOptions ? reinterpret_cast(value) : nullptr; } + MatrixDiagOptionsT *AsMatrixDiagOptions() { + return type == BuiltinOptions_MatrixDiagOptions ? + reinterpret_cast(value) : nullptr; + } + const MatrixDiagOptionsT *AsMatrixDiagOptions() const { + return type == BuiltinOptions_MatrixDiagOptions ? 
+ reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -7569,6 +7590,46 @@ inline flatbuffers::Offset CreateReverseSequenceOptions( flatbuffers::Offset CreateReverseSequenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct MatrixDiagOptionsT : public flatbuffers::NativeTable { + typedef MatrixDiagOptions TableType; + MatrixDiagOptionsT() { + } +}; + +struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MatrixDiagOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + MatrixDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MatrixDiagOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit MatrixDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MatrixDiagOptionsBuilder &operator=(const MatrixDiagOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMatrixDiagOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + MatrixDiagOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -7963,6 +8024,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const { return builtin_options_type() == BuiltinOptions_ReverseSequenceOptions ? static_cast(builtin_options()) : nullptr; } + const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const { + return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? 
static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -8342,6 +8406,10 @@ template<> inline const ReverseSequenceOptions *Operator::builtin_options_as inline const MatrixDiagOptions *Operator::builtin_options_as() const { + return builtin_options_as_MatrixDiagOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -11156,6 +11224,29 @@ inline flatbuffers::Offset CreateReverseSequenceOptions( _batch_dim); } +inline MatrixDiagOptionsT *MatrixDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new MatrixDiagOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void MatrixDiagOptions::UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset MatrixDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateMatrixDiagOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateMatrixDiagOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -11762,6 +11853,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -12128,6 +12223,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -12482,6 +12581,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(value); + return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -12836,6 +12939,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new ReverseSequenceOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_MatrixDiagOptions: { + value = new MatrixDiagOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -13278,6 +13385,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py index 213d214c132..84da8e06d97 100644 --- a/tensorflow/lite/testing/generate_examples.py +++ 
b/tensorflow/lite/testing/generate_examples.py @@ -4361,6 +4361,33 @@ def make_reverse_sequence_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_matrix_diag_tests(zip_path): + """Make a set of tests for tf.matrix_diag op.""" + + test_parameters = [ + { + "input_shape": [[3], [2, 3], [3, 4, 5], [2, 4, 6, 8]], + "input_dtype": [tf.int32, tf.float32], + }, + ] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + outs = tf.matrix_diag(input_tensor) + return [input_tensor], [outs] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + @test_util.enable_control_flow_v2 def make_unidirectional_sequence_lstm_tests(zip_path): """Make a set of tests to do unidirectional_sequence_lstm.""" diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index cb66a2372fd..1cffee785a8 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -286,6 +286,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, // have data type fields for all their arrays. break; } + case OperatorType::kMatrixDiag: { + CHECK_EQ(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index ca72d0037a9..33b66dab8cb 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -2060,6 +2060,24 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { idx_output_array.copy_shape(input_array.shape()); } +void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) { + CHECK_EQ(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run + return; + } + // Get the input_shape + auto& input_array = model->GetArray(op->inputs[0]); + Shape* mutable_shape = input_array.mutable_shape(); + std::vector* dims = mutable_shape->mutable_dims(); + int dims_size = dims->size(); + int last_dim = (*dims)[dims_size - 1]; + dims->push_back(last_dim); + output_array.copy_shape(*mutable_shape); +} + } // namespace ::tensorflow::Status PropagateFixedSizes::Run(Model* model, @@ -2363,6 +2381,9 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { // tensor. Ignore shape propagation here and defer that to the // interpreter. break; + case OperatorType::kMatrixDiag: + ProcessMatrixDiagOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. 
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index fdf72bde057..9e3df6a9b59 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -2471,6 +2471,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"LogicalNot", ConvertSimpleOperator}, {"LogSoftmax", ConvertSimpleOperator}, {"MatMul", ConvertMatMulOperator}, + {"MatrixDiag", ConvertSimpleOperator}, {"Max", ConvertReduceOperator}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator}, diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 46f70c9e379..27a5545a1f9 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -168,7 +168,8 @@ enum class OperatorType : uint8 { kGatherNd, kWhere, kElu, - kReverseSequence + kReverseSequence, + kMatrixDiag }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -2075,6 +2076,14 @@ struct WhereOperator : Operator { WhereOperator() : Operator(OperatorType::kWhere) {} }; +// Matrix Diag Operator: +// Construct a batched diagonal tensor with given batched diagonal values. +// Inputs: A tensor of values that will be on the diagonal of the returned +// tensor. +struct MatrixDiagOperator : Operator { + MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 2919f81571a..2f2a01fde97 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -2476,7 +2476,8 @@ std::vector> BuildOperatorList( ops.push_back( MakeUnique(::tflite::BuiltinOperator_REVERSE_SEQUENCE, OperatorType::kReverseSequence)); - + ops.push_back(MakeUnique>( + "MATRIX_DIAG", OperatorType::kMatrixDiag)); // Custom Operators. ops.push_back( MakeUnique("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 1b13f8076a0..8bf843a3040 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -705,6 +705,13 @@ TEST_F(OperatorTest, BuiltinReverseSequence) { EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim); } +TEST_F(OperatorTest, BuiltinMatrixDiag) { + MatrixDiagOperator op; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op); +} + // Test version for a simple Op with 2 versions and the input type controls the // version. template diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index ca2477fed1a..25ea8f53c5f 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -427,6 +427,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Cos) HANDLE_OPERATORTYPENAME_CASE(Where) HANDLE_OPERATORTYPENAME_CASE(ReverseSequence) + HANDLE_OPERATORTYPENAME_CASE(MatrixDiag) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE
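The kernel's inner loop is easiest to see in isolation. The sketch below is illustrative only (it is not part of this change, and FillDiagSketch is a made-up name); it mirrors the batched, row-major fill performed by FillDiagImpl above and prints the 3x3 diagonal matrix produced for the input {1, 2, 3}, matching the ThreeByThreeDiag test expectation.

#include <cstdio>
#include <vector>

// Standalone mirror of the diagonal-fill loop (illustrative sketch). For each
// batch, the next input value goes to position (i, i); every other position
// is zeroed.
template <typename T>
void FillDiagSketch(const T* in, T* out, int batch_size, int row_size,
                    int col_size) {
  int idx = 0;
  for (int b = 0; b < batch_size; ++b) {
    for (int i = 0; i < row_size; ++i) {
      for (int j = 0; j < col_size; ++j) {
        out[i * col_size + j] = (i == j) ? in[idx++] : static_cast<T>(0);
      }
    }
    out += row_size * col_size;
  }
}

int main() {
  const std::vector<int> input = {1, 2, 3};
  std::vector<int> output(3 * 3);
  FillDiagSketch(input.data(), output.data(), /*batch_size=*/1,
                 /*row_size=*/3, /*col_size=*/3);
  // Prints:
  // 1 0 0
  // 0 2 0
  // 0 0 3
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) std::printf("%d ", output[i * 3 + j]);
    std::printf("\n");
  }
  return 0;
}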