diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index c5ace905d4f..76db6049cf2 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -248,6 +248,7 @@ def generated_test_models(): "equal", "exp", "expand_dims", + "eye", "fill", "floor", "floor_div", @@ -275,6 +276,7 @@ def generated_test_models(): "logical_xor", "lstm", "matrix_diag", + "matrix_set_diag", "max_pool", "maximum", "mean", diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 1598498332f..914fd7fc23c 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -140,6 +140,7 @@ typedef enum { kTfLiteBuiltinReverseSequence = 112, kTfLiteBuiltinMatrixDiag = 113, kTfLiteBuiltinQuantize = 114, + kTfLiteBuiltinMatrixSetDiag = 115, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index cdc44be3c7e..2b64db4419a 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -377,6 +377,10 @@ typedef struct { EmptyStructPlaceholder placeholder; } TfLiteMatrixDiagParams; +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixSetDiagParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 2b7ce58a160..c906fc82576 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -708,6 +708,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LOGISTIC: case BuiltinOperator_LOG_SOFTMAX: case BuiltinOperator_MATRIX_DIAG: + case BuiltinOperator_MATRIX_SET_DIAG: case BuiltinOperator_MAXIMUM: case BuiltinOperator_MINIMUM: case BuiltinOperator_NEG: diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 36939160acc..0432862fbb1 100644 --- 
a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -189,6 +189,7 @@ cc_library( "lsh_projection.cc", "lstm.cc", "matrix_diag.cc", + "matrix_set_diag.cc", "maximum_minimum.cc", "mfcc.cc", "mirror_pad.cc", @@ -1415,3 +1416,15 @@ cc_test( "@com_google_googletest//:gtest", ], ) + +cc_test( + name = "matrix_set_diag_test", + size = "small", + srcs = ["matrix_set_diag_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/lite/kernels/matrix_set_diag.cc b/tensorflow/lite/kernels/matrix_set_diag.cc new file mode 100644 index 00000000000..8ce613e82a7 --- /dev/null +++ b/tensorflow/lite/kernels/matrix_set_diag.cc @@ -0,0 +1,147 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace matrix_set_diag { + +constexpr int kInputTensor = 0; +constexpr int kDiagonalTensor = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteIntArray* input_dims = input->dims; + int input_dims_size = input_dims->size; + TF_LITE_ENSURE(context, input_dims_size >= 2); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size); + for (int i = 0; i < input_dims_size; i++) { + output_shape->data[i] = input_dims->data[i]; + } + + // Resize the output tensor to the same size as the input tensor. + output->type = input->type; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_shape)); + + return kTfLiteOk; +} + +// Fill the tensor to make a diagonal matrix in each batch, i.e., when +// row index and column index are the same, fill with the next diagonal value. +// All other entries are the same as the input value. +// TODO(b/128636574) Move to reference_ops. 
+template <typename T> +void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size, + const int row_size, const int col_size) { + int idx = 0; + for (int b = 0; b < batch_size; b++) { + for (int i = 0; i < row_size; i++) { + for (int j = 0; j < col_size; ++j) { + // diag values go on the diagonal, in values elsewhere + if (i == j) { + out[i * col_size + j] = diag[idx]; + idx++; + } else { + out[i * col_size + j] = in[i * col_size + j]; + } + } + } + out += row_size * col_size; + in += row_size * col_size; + } +} + +template <typename T> +void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag, + TfLiteTensor* output, const int batch_size, const int row_size, + const int col_size) { + FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag), + GetTensorData<T>(output), batch_size, row_size, col_size); +} + +// Fill a tensor with given "diag" values on the diagonal, input values +// elsewhere. +void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag, + TfLiteTensor* output) { + const int num_output_dims = output->dims->size; + int batch_size = 1; + for (int i = 0; i < num_output_dims - 2; ++i) { + batch_size *= output->dims->data[i]; + } + + const int row_size = output->dims->data[num_output_dims - 2]; + const int col_size = output->dims->data[num_output_dims - 1]; + switch (output->type) { + case kTfLiteInt64: { + return FillDiag<int64_t>(input, diag, output, batch_size, row_size, + col_size); + } + case kTfLiteInt32: { + return FillDiag<int32_t>(input, diag, output, batch_size, row_size, + col_size); + } + case kTfLiteInt16: { + return FillDiag<int16_t>(input, diag, output, batch_size, row_size, + col_size); + } + case kTfLiteInt8: { + return FillDiag<int8_t>(input, diag, output, batch_size, row_size, + col_size); + } + case kTfLiteUInt8: { + return FillDiag<uint8_t>(input, diag, output, batch_size, row_size, + col_size); + } + default: + return FillDiag<float>(input, diag, output, batch_size, row_size, + col_size); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* 
output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor); + FillDiagHelper(input, diag, output); + return kTfLiteOk; +} + +} // namespace matrix_set_diag + +TfLiteRegistration* Register_MATRIX_SET_DIAG() { + static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare, + matrix_set_diag::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/matrix_set_diag_test.cc b/tensorflow/lite/kernels/matrix_set_diag_test.cc new file mode 100644 index 00000000000..a41c717f733 --- /dev/null +++ b/tensorflow/lite/kernels/matrix_set_diag_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template <typename T> +class MatrixSetDiagOpModel : public SingleOpModel { + public: + explicit MatrixSetDiagOpModel(const TensorData& input, + const TensorData& diag) { + input_ = AddInput(input); + diag_ = AddInput(diag); + output_ = AddOutput({input.type, {}}); + + SetBuiltinOp(BuiltinOperator_MATRIX_SET_DIAG, + BuiltinOptions_MatrixSetDiagOptions, + CreateMatrixSetDiagOptions(builder_).Union()); + BuildInterpreter({GetShape(input_), GetShape(diag_)}); + } + + int input() { return input_; } + int diag() { return diag_; } + + std::vector<T> GetOutput() { return ExtractVector<T>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + TfLiteType GetOutputType() { + TfLiteTensor* t = interpreter_->tensor(output_); + return t->type; + } + + private: + int input_; + int diag_; + int output_; +}; + +// Use the machinery of TYPED_TEST_SUITE to test all supported types. +// See +// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests +// for details. 
+template <typename T> +class MatrixSetDiagOpTest : public ::testing::Test {}; + +using TypesUnderTest = + ::testing::Types<TypeUnion<int32_t>, TypeUnion<float>, TypeUnion<int16_t>, + TypeUnion<int8_t>, TypeUnion<uint8_t>>; + +TYPED_TEST_SUITE(MatrixSetDiagOpTest, TypesUnderTest); + +TYPED_TEST(MatrixSetDiagOpTest, ThreeByThreeDiagScatter) { + MatrixSetDiagOpModel<typename TypeParam::ScalarType> model( + {TypeParam::tensor_type, {3, 3}}, {TypeParam::tensor_type, {3}}); + model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(), + {7, 1, 2, // + 3, 8, 4, // + 5, 6, 9}); + model.template PopulateTensor<typename TypeParam::ScalarType>(model.diag(), + {0, 4, 2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2, // + 3, 4, 4, // + 5, 6, 2})); + EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type); +} + +TEST(MatrixSetDiagTest, Int32TestMoreColumnsThanRows) { + MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 3}}, + {TensorType_INT32, {2}}); + model.PopulateTensor<int32_t>(model.input(), {0, 0, 0, // + 9, 9, 9}); + model.PopulateTensor<int32_t>(model.diag(), {1, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, // + 9, 1, 9})); + EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32); +} + +TEST(MatrixSetDiagTest, Int32TestTwoDimDiag) { + MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 4, 4}}, + {TensorType_INT32, {2, 4}}); + model.PopulateTensor<int32_t>(model.input(), {5, 5, 5, 5, // + 5, 5, 5, 5, // + 5, 5, 5, 5, // + 5, 5, 5, 5, // + 1, 1, 1, 1, // + 1, 1, 1, 1, // + 1, 1, 1, 1, // + 1, 1, 1, 1}); + model.PopulateTensor<int32_t>(model.diag(), {1, 2, 3, 4, 5, 6, 7, 8}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 5, 5, 5, // + 5, 2, 5, 5, // + 5, 5, 3, 5, // + 5, 5, 5, 4, // + 5, 1, 1, 1, // + 1, 6, 1, 1, // + 1, 1, 7, 1, // + 1, 1, 1, 8})); + EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32); +} + +} // namespace +} // namespace tflite + +int main(int 
argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 7f6f5c513bf..3a80230536f 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -141,6 +141,7 @@ TfLiteRegistration* Register_ELU(); TfLiteRegistration* Register_REVERSE_SEQUENCE(); TfLiteRegistration* Register_MATRIX_DIAG(); TfLiteRegistration* Register_QUANTIZE(); +TfLiteRegistration* Register_MATRIX_SET_DIAG(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -378,6 +379,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE()); AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG()); AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE()); + AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. 
diff --git a/tensorflow/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc index 5b1a75f5526..c8807e3ae29 100644 --- a/tensorflow/lite/nnapi_delegate.cc +++ b/tensorflow/lite/nnapi_delegate.cc @@ -673,6 +673,7 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_REVERSE_SEQUENCE: case tflite::BuiltinOperator_MATRIX_DIAG: case tflite::BuiltinOperator_QUANTIZE: + case tflite::BuiltinOperator_MATRIX_SET_DIAG: logError("Op code %d is currently not delegated to NNAPI", builtin); return kTfLiteError; break; diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 4343ebc2edc..3dbdacd3832 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -228,6 +228,7 @@ enum BuiltinOperator : byte { REVERSE_SEQUENCE = 112, MATRIX_DIAG = 113, QUANTIZE = 114, + MATRIX_SET_DIAG = 115 } // Options for the builtin operators. @@ -321,6 +322,7 @@ union BuiltinOptions { ReverseSequenceOptions, MatrixDiagOptions, QuantizeOptions, + MatrixSetDiagOptions } enum Padding : byte { SAME, VALID } @@ -767,6 +769,9 @@ table MatrixDiagOptions { table QuantizeOptions { } +table MatrixSetDiagOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. 
table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index f765efe8201..3520eff51d9 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -298,6 +298,9 @@ struct MatrixDiagOptionsT; struct QuantizeOptions; struct QuantizeOptionsT; +struct MatrixSetDiagOptions; +struct MatrixSetDiagOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -562,11 +565,12 @@ enum BuiltinOperator { BuiltinOperator_REVERSE_SEQUENCE = 112, BuiltinOperator_MATRIX_DIAG = 113, BuiltinOperator_QUANTIZE = 114, + BuiltinOperator_MATRIX_SET_DIAG = 115, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_QUANTIZE + BuiltinOperator_MAX = BuiltinOperator_MATRIX_SET_DIAG }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -681,7 +685,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] { BuiltinOperator_ELU, BuiltinOperator_REVERSE_SEQUENCE, BuiltinOperator_MATRIX_DIAG, - BuiltinOperator_QUANTIZE + BuiltinOperator_QUANTIZE, + BuiltinOperator_MATRIX_SET_DIAG }; return values; } @@ -803,6 +808,7 @@ inline const char * const *EnumNamesBuiltinOperator() { "REVERSE_SEQUENCE", "MATRIX_DIAG", "QUANTIZE", + "MATRIX_SET_DIAG", nullptr }; return names; @@ -904,11 +910,12 @@ enum BuiltinOptions { BuiltinOptions_ReverseSequenceOptions = 87, BuiltinOptions_MatrixDiagOptions = 88, BuiltinOptions_QuantizeOptions = 89, + BuiltinOptions_MatrixSetDiagOptions = 90, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_QuantizeOptions + BuiltinOptions_MAX = BuiltinOptions_MatrixSetDiagOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[91] { static const 
BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -999,7 +1006,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] { BuiltinOptions_RankOptions, BuiltinOptions_ReverseSequenceOptions, BuiltinOptions_MatrixDiagOptions, - BuiltinOptions_QuantizeOptions + BuiltinOptions_QuantizeOptions, + BuiltinOptions_MatrixSetDiagOptions }; return values; } @@ -1096,6 +1104,7 @@ inline const char * const *EnumNamesBuiltinOptions() { "ReverseSequenceOptions", "MatrixDiagOptions", "QuantizeOptions", + "MatrixSetDiagOptions", nullptr }; return names; @@ -1466,6 +1475,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2209,6 +2222,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_QuantizeOptions ? reinterpret_cast(value) : nullptr; } + MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() { + return type == BuiltinOptions_MatrixSetDiagOptions ? + reinterpret_cast(value) : nullptr; + } + const MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const { + return type == BuiltinOptions_MatrixSetDiagOptions ? 
+ reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -7691,6 +7712,46 @@ inline flatbuffers::Offset CreateQuantizeOptions( flatbuffers::Offset CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct MatrixSetDiagOptionsT : public flatbuffers::NativeTable { + typedef MatrixSetDiagOptions TableType; + MatrixSetDiagOptionsT() { + } +}; + +struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MatrixSetDiagOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + MatrixSetDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MatrixSetDiagOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit MatrixSetDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MatrixSetDiagOptionsBuilder &operator=(const MatrixSetDiagOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMatrixSetDiagOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + MatrixSetDiagOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct 
OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -8091,6 +8152,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const QuantizeOptions *builtin_options_as_QuantizeOptions() const { return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast(builtin_options()) : nullptr; } + const MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const { + return builtin_options_type() == BuiltinOptions_MatrixSetDiagOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -8478,6 +8542,10 @@ template<> inline const QuantizeOptions *Operator::builtin_options_as inline const MatrixSetDiagOptions *Operator::builtin_options_as() const { + return builtin_options_as_MatrixSetDiagOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -11338,6 +11406,29 @@ inline flatbuffers::Offset CreateQuantizeOptions(flatbuffers::F _fbb); } +inline MatrixSetDiagOptionsT *MatrixSetDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new MatrixSetDiagOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void MatrixSetDiagOptions::UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset MatrixSetDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateMatrixSetDiagOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const 
MatrixSetDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateMatrixSetDiagOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -11952,6 +12043,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -12326,6 +12421,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -12688,6 +12787,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(value); + return CreateMatrixSetDiagOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -13050,6 +13153,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new QuantizeOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_MatrixSetDiagOptions: { + value = new MatrixSetDiagOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -13502,6 +13609,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/testing/generate_examples_lib.py 
b/tensorflow/lite/testing/generate_examples_lib.py index 36146692250..8ef792ad22d 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -4385,6 +4385,79 @@ def make_matrix_diag_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_matrix_set_diag_tests(zip_path): + """Make a set of tests for tf.matrix_set_diag op.""" + + test_parameters = [ + { + "input_diag_shapes": [([3, 3], [3]), ([2, 3], [2]), ([2, 4, 4], + [2, 4]), + ([3, 4, 5, 6], [3, 4, 5])], + "input_dtype": [tf.int32, tf.float32, tf.uint8], + }, + ] + + def build_graph(parameters): + input_shape = parameters["input_diag_shapes"][0] + diag_shape = parameters["input_diag_shapes"][1] + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], name="input", shape=input_shape) + diag_tensor = tf.placeholder( + dtype=parameters["input_dtype"], name="diagonal", shape=diag_shape) + outs = tf.matrix_set_diag(input_tensor, diag_tensor) + return [input_tensor, diag_tensor], [outs] + + def build_inputs(parameters, sess, inputs, outputs): + input_shape = parameters["input_diag_shapes"][0] + diag_shape = parameters["input_diag_shapes"][1] + input_values = create_tensor_data(parameters["input_dtype"], input_shape) + diag_values = create_tensor_data(parameters["input_dtype"], diag_shape) + return [input_values, diag_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values, diag_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_eye_tests(zip_path): + """Make a set of tests for tf.eye op.""" + + test_parameters = [{ + "num_rows_shape": [[]], + "num_cols_shape": [[]], + "batch_shape": [[3], [2, 4], [4, 5, 6], None], + "use_num_cols": [True, False], + "dtype": [tf.float32, tf.int32], + }] + + def build_graph(parameters): + input_tensor0 = tf.placeholder( + dtype=tf.int32, name="num_rows", shape=parameters["num_rows_shape"]) + input_tensor1 = 
tf.placeholder( + dtype=tf.int32, name="num_columns", shape=parameters["num_cols_shape"]) + if parameters["use_num_cols"]: + outs = tf.eye( + num_rows=input_tensor0, + num_columns=input_tensor1, + batch_shape=parameters["batch_shape"], + dtype=parameters["dtype"]) + return [input_tensor0, input_tensor1], [outs] + else: + outs = tf.eye(num_rows=input_tensor0, dtype=parameters["dtype"]) + return [input_tensor0], [outs] + + def build_inputs(parameters, sess, inputs, outputs): + input_value0 = create_scalar_data(dtype=np.int32, min_value=1) + input_value1 = create_scalar_data(dtype=np.int32, min_value=1) + if parameters["use_num_cols"]: + return [input_value0, input_value1], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value0, input_value1]))) + else: + return [input_value0], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value0]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + @test_util.enable_control_flow_v2 def make_unidirectional_sequence_lstm_tests(zip_path): """Make a set of tests to do unidirectional_sequence_lstm.""" diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index 1cffee785a8..0f67edce9b1 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -293,6 +293,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kMatrixSetDiag: { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git 
a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 33b66dab8cb..870e05094ce 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -2063,21 +2063,40 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) { CHECK_EQ(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); + auto& input_array = model->GetArray(op->inputs[0]); auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { + // The input array must have a shape in order to proceed. Also, + // bail out if the output shape has already been calculated. + if (!input_array.has_shape() || output_array.has_shape()) { // We have already run return; } // Get the input_shape - auto& input_array = model->GetArray(op->inputs[0]); Shape* mutable_shape = input_array.mutable_shape(); std::vector* dims = mutable_shape->mutable_dims(); int dims_size = dims->size(); + // Scalars are not allowed. + CHECK_GT(dims_size, 0); int last_dim = (*dims)[dims_size - 1]; dims->push_back(last_dim); output_array.copy_shape(*mutable_shape); } +void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + auto& input_array = model->GetArray(op->inputs[0]); + auto& output_array = model->GetArray(op->outputs[0]); + // The shape of the input array must be known because that will + // be the shape of the output array. 
+ if (!input_array.has_shape() || !output_array.has_shape()) { + // We have already run + return; + } + + output_array.copy_shape(input_array.shape()); +} + } // namespace ::tensorflow::Status PropagateFixedSizes::Run(Model* model, @@ -2384,6 +2403,10 @@ void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) { case OperatorType::kMatrixDiag: ProcessMatrixDiagOperator(model, static_cast(op)); break; + case OperatorType::kMatrixSetDiag: + ProcessMatrixSetDiagOperator(model, + static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index ee00c334336..6381fc408d1 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -68,7 +68,9 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kResizeNearestNeighbor || type == OperatorType::kPRelu || type == OperatorType::kReduceMax || type == OperatorType::kReduceMin || - type == OperatorType::kTransposeConv; + type == OperatorType::kTransposeConv || + type == OperatorType::kMatrixSetDiag || + type == OperatorType::kMatrixDiag; } // The quantized op allows output arrays of type float using diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index 9e3df6a9b59..be2a71c33b7 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -2472,6 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"LogSoftmax", ConvertSimpleOperator}, {"MatMul", ConvertMatMulOperator}, {"MatrixDiag", ConvertSimpleOperator}, + {"MatrixSetDiag", ConvertSimpleOperator}, {"Max", ConvertReduceOperator}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator}, diff --git 
a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 27a5545a1f9..e7318e7082b 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -169,7 +169,8 @@ enum class OperatorType : uint8 { kWhere, kElu, kReverseSequence, - kMatrixDiag + kMatrixDiag, + kMatrixSetDiag }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -2084,6 +2085,16 @@ struct MatrixDiagOperator : Operator { MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {} }; +// Matrix Set Diag Operator: +// Construct a batched diagonal tensor with given input and diagonal values. +// Input is a rank (k+1) tensor of values. +// diagonal is a rank (k) tensor of values that will be on the diagonal +// of the returned output. Output is rank k+1. +// tensor. +struct MatrixSetDiagOperator : Operator { + MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 2f2a01fde97..bc8f54cb26f 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -2478,6 +2478,8 @@ std::vector> BuildOperatorList( OperatorType::kReverseSequence)); ops.push_back(MakeUnique>( "MATRIX_DIAG", OperatorType::kMatrixDiag)); + ops.push_back(MakeUnique>( + "MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag)); // Custom Operators. 
ops.push_back( MakeUnique("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 8bf843a3040..a5b1efef5a1 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -712,6 +712,13 @@ TEST_F(OperatorTest, BuiltinMatrixDiag) { GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op); } +TEST_F(OperatorTest, BuiltinMatrixSetDiag) { + MatrixSetDiagOperator op; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op); +} + // Test version for a simple Op with 2 versions and the input type controls the // version. template diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index 1503c741b30..31d5f03e17a 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -428,6 +428,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Where) HANDLE_OPERATORTYPENAME_CASE(ReverseSequence) HANDLE_OPERATORTYPENAME_CASE(MatrixDiag) + HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE