Implement Matrix Diag

PiperOrigin-RevId: 239061891
This commit is contained in:
A. Unique TensorFlower 2019-03-18 14:29:41 -07:00 committed by TensorFlower Gardener
parent b7a36dec6b
commit d79bb04e2a
20 changed files with 520 additions and 10 deletions
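In short, this change wires a new MATRIX_DIAG builtin through the TensorFlow Lite stack: a CPU kernel and unit tests, registration in the builtin op resolver and the NNAPI "not delegated" fallback path, a MatrixDiagOptions table in the flatbuffer schema (with regenerated headers), a generated-model test, and toco support for import, type/shape propagation, and export. Semantically, the op expands a tensor of shape [..., N] into [..., N, N], placing the input values on the main diagonal and zeros elsewhere; for example, [1, 2, 3] becomes [[1, 0, 0], [0, 2, 0], [0, 0, 3]].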

View File

@ -274,6 +274,7 @@ def generated_test_models():
"logical_or", "logical_or",
"logical_xor", "logical_xor",
"lstm", "lstm",
"matrix_diag",
"max_pool", "max_pool",
"maximum", "maximum",
"mean", "mean",

View File

@ -138,6 +138,7 @@ typedef enum {
  kTfLiteBuiltinRank = 110,
  kTfLiteBuiltinElu = 111,
  kTfLiteBuiltinReverseSequence = 112,
  kTfLiteBuiltinMatrixDiag = 113,
} TfLiteBuiltinOperator;

#ifdef __cplusplus

View File

@ -373,6 +373,10 @@ typedef struct {
  int batch_dim;
} TfLiteReverseSequenceParams;

typedef struct {
  EmptyStructPlaceholder placeholder;
} TfLiteMatrixDiagParams;

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

View File

@ -683,8 +683,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    // Below are the ops with no builtin_data structure.
    case BuiltinOperator_ABS:
    case BuiltinOperator_BATCH_TO_SPACE_ND:
      // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
@ -708,6 +707,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
    case BuiltinOperator_LOG:
    case BuiltinOperator_LOGISTIC:
    case BuiltinOperator_LOG_SOFTMAX:
    case BuiltinOperator_MATRIX_DIAG:
    case BuiltinOperator_MAXIMUM:
    case BuiltinOperator_MINIMUM:
    case BuiltinOperator_NEG:

View File

@ -188,6 +188,7 @@ cc_library(
"logical.cc", "logical.cc",
"lsh_projection.cc", "lsh_projection.cc",
"lstm.cc", "lstm.cc",
"matrix_diag.cc",
"maximum_minimum.cc", "maximum_minimum.cc",
"mfcc.cc", "mfcc.cc",
"mirror_pad.cc", "mirror_pad.cc",
@ -1388,3 +1389,15 @@ cc_test(
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
], ],
) )
cc_test(
name = "matrix_diag_test",
size = "small",
srcs = ["matrix_diag_test.cc"],
deps = [
":builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)

View File

@ -0,0 +1,136 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace matrix_diag {

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteIntArray* input_dims = input->dims;
  int input_dims_size = input_dims->size;
  TF_LITE_ENSURE(context, input_dims_size >= 1);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  // Resize the output tensor.
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
  for (int i = 0; i < input_dims_size; i++) {
    output_shape->data[i] = input_dims->data[i];
  }
  // Last dimension in the output is the same as the last dimension in the
  // input.
  output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
  output->type = input->type;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_shape));
  return kTfLiteOk;
}

// Fill the tensor to make a diagonal matrix in each batch, i.e., when
// row index and column index are the same, fill with the next input value.
// All other entries get zero.
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, T* out, const int batch_size, const int row_size,
                  const int col_size) {
  int idx = 0;
  for (int b = 0; b < batch_size; b++) {
    for (int i = 0; i < row_size; i++) {
      for (int j = 0; j < col_size; ++j) {
        // input values go on the diagonal, 0 elsewhere
        if (i == j) {
          out[i * col_size + j] = in[idx];
          idx++;
        } else {
          out[i * col_size + j] = 0;
        }
      }
    }
    out += row_size * col_size;
  }
}

template <typename T>
void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
              const int batch_size, const int row_size, const int col_size) {
  FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output), batch_size,
                  row_size, col_size);
}

// Fill a tensor with given input on the diagonal, zero elsewhere
void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
  const int num_output_dims = output->dims->size;
  int batch_size = 1;
  for (int i = 0; i < num_output_dims - 2; ++i) {
    batch_size *= output->dims->data[i];
  }
  const int row_size = output->dims->data[num_output_dims - 2];
  const int col_size = output->dims->data[num_output_dims - 1];
  switch (output->type) {
    case kTfLiteInt64: {
      return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteInt32: {
      return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteInt16: {
      return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteInt8: {
      return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteUInt8: {
      return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
    }
    default:
      return FillDiag<float_t>(input, output, batch_size, row_size, col_size);
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  FillDiagHelper(input, output);
  return kTfLiteOk;
}

}  // namespace matrix_diag

TfLiteRegistration* Register_MATRIX_DIAG() {
  static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
                                 matrix_diag::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
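A minimal standalone sketch (illustration only, not part of this commit) of the same diagonal-fill rule as FillDiagImpl above, using plain std::vector instead of TfLiteTensor; the printed values match the kernel tests further down:

// Illustration only: mirrors matrix_diag::FillDiagImpl without the
// TfLiteTensor plumbing. Input shape [..., N] expands to [..., N, N].
#include <cstdio>
#include <vector>

std::vector<int> MatrixDiag(const std::vector<int>& in, int batch_size, int n) {
  std::vector<int> out(batch_size * n * n, 0);
  int idx = 0;
  for (int b = 0; b < batch_size; ++b) {
    for (int i = 0; i < n; ++i) {
      out[(b * n + i) * n + i] = in[idx++];  // diagonal entry; the rest stay 0
    }
  }
  return out;
}

int main() {
  // Shape [3] -> [3, 3]: prints 1 0 0 0 2 0 0 0 3.
  for (int v : MatrixDiag({1, 2, 3}, /*batch_size=*/1, /*n=*/3)) printf("%d ", v);
  printf("\n");
  // Shape [2, 4] -> [2, 4, 4]: each row of 4 becomes its own 4x4 diagonal.
  for (int v : MatrixDiag({1, 2, 3, 4, 5, 6, 7, 8}, /*batch_size=*/2, /*n=*/4))
    printf("%d ", v);
  printf("\n");
  return 0;
}

As in FillDiagHelper, everything except the last two output dimensions is folded into batch_size.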

View File

@ -0,0 +1,110 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {

using ::testing::ElementsAre;
using ::testing::ElementsAreArray;

template <typename T>
class MatrixDiagOpModel : public SingleOpModel {
 public:
  explicit MatrixDiagOpModel(const TensorData& input) {
    input_ = AddInput(input);
    output_ = AddOutput({input.type, {}});
    SetBuiltinOp(BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions,
                 CreateMatrixDiagOptions(builder_).Union());
    BuildInterpreter({GetShape(input_)});
  }

  int input() { return input_; }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
  TfLiteType GetOutputType() {
    TfLiteTensor* t = interpreter_->tensor(output_);
    return t->type;
  }

 private:
  int input_;
  int output_;
};

// Use the machinery of TYPED_TEST_SUITE to test all supported types.
// See
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
// for details.
template <typename T>
class MatrixDiagOpTest : public ::testing::Test {};

using TypesUnderTest =
    ::testing::Types<TypeUnion<int32_t>, TypeUnion<float_t>, TypeUnion<int16_t>,
                     TypeUnion<int8_t>, TypeUnion<uint8_t>>;
TYPED_TEST_SUITE(MatrixDiagOpTest, TypesUnderTest);

TYPED_TEST(MatrixDiagOpTest, ThreeByThreeDiag) {
  MatrixDiagOpModel<typename TypeParam::ScalarType> model(
      {TypeParam::tensor_type, {3}});
  model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
                                                                {1, 2, 3});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0,  //
                                                   0, 2, 0,  //
                                                   0, 0, 3}));
  EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
}

// Additional special cases.
TEST(MatrixDiagTest, Int32TestTwoDimDiag) {
  MatrixDiagOpModel<int32_t> model({TensorType_INT32, {2, 4}});
  model.PopulateTensor<int32_t>(model.input(), {1, 2, 3, 4, 5, 6, 7, 8});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0,  //
                                                   0, 2, 0, 0,  //
                                                   0, 0, 3, 0,  //
                                                   0, 0, 0, 4,  //
                                                   5, 0, 0, 0,  //
                                                   0, 6, 0, 0,  //
                                                   0, 0, 7, 0,  //
                                                   0, 0, 0, 8}));
  EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
}

TEST(MatrixDiagTest, DegenerateCase) {
  MatrixDiagOpModel<uint8_t> model({TensorType_UINT8, {1}});
  model.PopulateTensor<uint8_t>(model.input(), {1});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
  EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteUInt8);
}

}  // namespace
}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

View File

@ -139,6 +139,7 @@ TfLiteRegistration* Register_GATHER_ND();
TfLiteRegistration* Register_WHERE();
TfLiteRegistration* Register_ELU();
TfLiteRegistration* Register_REVERSE_SEQUENCE();
TfLiteRegistration* Register_MATRIX_DIAG();

TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
  context->ReportError(
@ -374,6 +375,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
  AddBuiltin(BuiltinOperator_ELU, Register_ELU());
  AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
  AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.

View File

@ -522,6 +522,56 @@ TensorType GetTensorType() {
// Strings have a special implementation that is in test_util.cc
template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
// The TypeUnion struct specializations hold a collection of related types.
// Each struct holds: 1. a primitive type (e.g. float), 2. a TensorType (e.g.
// TensorType_FLOAT32), and 3. a TfLiteType (e.g. kTfLiteFloat32). The latter
// two are actually enum values and not raw types, but these specializations
// make it easy to use gUnit Typed Test Suite:
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
template <typename T>
struct TypeUnion;

template <>
struct TypeUnion<float_t> {
 public:
  static const TensorType tensor_type = TensorType::TensorType_FLOAT32;
  static const TfLiteType tflite_type = TfLiteType::kTfLiteFloat32;
  typedef float_t ScalarType;
};

template <>
struct TypeUnion<int32_t> {
 public:
  static const TensorType tensor_type = TensorType::TensorType_INT32;
  static const TfLiteType tflite_type = TfLiteType::kTfLiteInt32;
  typedef int32_t ScalarType;
};

template <>
struct TypeUnion<int16_t> {
 public:
  static const TensorType tensor_type = TensorType::TensorType_INT16;
  static const TfLiteType tflite_type = TfLiteType::kTfLiteInt16;
  typedef int16_t ScalarType;
};

template <>
struct TypeUnion<int8_t> {
 public:
  static const TensorType tensor_type = TensorType::TensorType_INT8;
  static const TfLiteType tflite_type = TfLiteType::kTfLiteInt8;
  typedef int8_t ScalarType;
};

template <>
struct TypeUnion<uint8_t> {
 public:
  static const TensorType tensor_type = TensorType::TensorType_UINT8;
  static const TfLiteType tflite_type = TfLiteType::kTfLiteUInt8;
  typedef uint8_t ScalarType;
};
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
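A hypothetical usage sketch (not part of this change; MyOpTest, MyTypes, and the test body are made-up names) of how these TypeUnion specializations plug into a gtest typed test suite, mirroring what matrix_diag_test.cc does above:

#include <cstdint>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"

// One fixture, instantiated once per TypeUnion listed in MyTypes.
template <typename T>
class MyOpTest : public ::testing::Test {};

using MyTypes =
    ::testing::Types<tflite::TypeUnion<int32_t>, tflite::TypeUnion<uint8_t>>;
TYPED_TEST_SUITE(MyOpTest, MyTypes);

TYPED_TEST(MyOpTest, TypesStayInSync) {
  // Build model inputs with TypeParam::ScalarType and TypeParam::tensor_type,
  // then check the runtime tensor against TypeParam::tflite_type, as the
  // MatrixDiag tests do.
  typename TypeParam::ScalarType zero = 0;
  (void)zero;
  EXPECT_NE(TypeParam::tensor_type, tflite::TensorType_STRING);
}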

View File

@ -671,6 +671,7 @@ TfLiteStatus AddOpsAndParams(
    case tflite::BuiltinOperator_RANK:
    case tflite::BuiltinOperator_ELU:
    case tflite::BuiltinOperator_REVERSE_SEQUENCE:
    case tflite::BuiltinOperator_MATRIX_DIAG:
      logError("Op code %d is currently not delegated to NNAPI", builtin);
      return kTfLiteError;
      break;

View File

@ -226,6 +226,7 @@ enum BuiltinOperator : byte {
  RANK = 110,
  ELU = 111,
  REVERSE_SEQUENCE = 112,
  MATRIX_DIAG = 113,
}

// Options for the builtin operators.
@ -317,6 +318,7 @@ union BuiltinOptions {
  WhereOptions,
  RankOptions,
  ReverseSequenceOptions,
  MatrixDiagOptions,
}

enum Padding : byte { SAME, VALID }
@ -756,6 +758,10 @@ table ReverseSequenceOptions {
  seq_dim:int;
  batch_dim:int = 0;
}

table MatrixDiagOptions {
}

// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -292,6 +292,9 @@ struct WhereOptionsT;
struct ReverseSequenceOptions;
struct ReverseSequenceOptionsT;

struct MatrixDiagOptions;
struct MatrixDiagOptionsT;

struct OperatorCode;
struct OperatorCodeT;
@ -554,11 +557,12 @@ enum BuiltinOperator {
BuiltinOperator_RANK = 110,
BuiltinOperator_ELU = 111,
BuiltinOperator_REVERSE_SEQUENCE = 112,
BuiltinOperator_MATRIX_DIAG = 113,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_MATRIX_DIAG
};

inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -671,7 +675,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] {
BuiltinOperator_WHERE,
BuiltinOperator_RANK,
BuiltinOperator_ELU,
BuiltinOperator_REVERSE_SEQUENCE,
BuiltinOperator_MATRIX_DIAG
};
return values;
}
@ -791,6 +796,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"RANK", "RANK",
"ELU", "ELU",
"REVERSE_SEQUENCE", "REVERSE_SEQUENCE",
"MATRIX_DIAG",
nullptr nullptr
}; };
return names; return names;
@ -890,11 +896,12 @@ enum BuiltinOptions {
BuiltinOptions_WhereOptions = 85,
BuiltinOptions_RankOptions = 86,
BuiltinOptions_ReverseSequenceOptions = 87,
BuiltinOptions_MatrixDiagOptions = 88,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_MatrixDiagOptions
};

inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -983,7 +990,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] {
BuiltinOptions_CosOptions,
BuiltinOptions_WhereOptions,
BuiltinOptions_RankOptions,
BuiltinOptions_ReverseSequenceOptions,
BuiltinOptions_MatrixDiagOptions
};
return values;
}
@ -1078,6 +1086,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"WhereOptions", "WhereOptions",
"RankOptions", "RankOptions",
"ReverseSequenceOptions", "ReverseSequenceOptions",
"MatrixDiagOptions",
nullptr nullptr
}; };
return names; return names;
@ -1440,6 +1449,10 @@ template<> struct BuiltinOptionsTraits<ReverseSequenceOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions;
};
template<> struct BuiltinOptionsTraits<MatrixDiagOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2167,6 +2180,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ReverseSequenceOptions ?
reinterpret_cast<const ReverseSequenceOptionsT *>(value) : nullptr;
}
MatrixDiagOptionsT *AsMatrixDiagOptions() {
return type == BuiltinOptions_MatrixDiagOptions ?
reinterpret_cast<MatrixDiagOptionsT *>(value) : nullptr;
}
const MatrixDiagOptionsT *AsMatrixDiagOptions() const {
return type == BuiltinOptions_MatrixDiagOptions ?
reinterpret_cast<const MatrixDiagOptionsT *>(value) : nullptr;
}
};

bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7569,6 +7590,46 @@ inline flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(
flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct MatrixDiagOptionsT : public flatbuffers::NativeTable {
typedef MatrixDiagOptions TableType;
MatrixDiagOptionsT() {
}
};
struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef MatrixDiagOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
MatrixDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<MatrixDiagOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct MatrixDiagOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit MatrixDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
MatrixDiagOptionsBuilder &operator=(const MatrixDiagOptionsBuilder &);
flatbuffers::Offset<MatrixDiagOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<MatrixDiagOptions>(end);
return o;
}
};
inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
MatrixDiagOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -7963,6 +8024,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const {
return builtin_options_type() == BuiltinOptions_ReverseSequenceOptions ? static_cast<const ReverseSequenceOptions *>(builtin_options()) : nullptr;
}
const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const {
return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? static_cast<const MatrixDiagOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -8342,6 +8406,10 @@ template<> inline const ReverseSequenceOptions *Operator::builtin_options_as<Rev
return builtin_options_as_ReverseSequenceOptions();
}
template<> inline const MatrixDiagOptions *Operator::builtin_options_as<MatrixDiagOptions>() const {
return builtin_options_as_MatrixDiagOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -11156,6 +11224,29 @@ inline flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(
_batch_dim);
}
inline MatrixDiagOptionsT *MatrixDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new MatrixDiagOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void MatrixDiagOptions::UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<MatrixDiagOptions> MatrixDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateMatrixDiagOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateMatrixDiagOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -11762,6 +11853,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ReverseSequenceOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -12128,6 +12223,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ReverseSequenceOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -12482,6 +12581,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ReverseSequenceOptionsT *>(value);
return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<const MatrixDiagOptionsT *>(value);
return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -12836,6 +12939,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ReverseSequenceOptionsT(*reinterpret_cast<ReverseSequenceOptionsT *>(u.value));
break;
}
case BuiltinOptions_MatrixDiagOptions: {
value = new MatrixDiagOptionsT(*reinterpret_cast<MatrixDiagOptionsT *>(u.value));
break;
}
default:
break;
}
@ -13278,6 +13385,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<MatrixDiagOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -4361,6 +4361,33 @@ def make_reverse_sequence_tests(zip_path):
  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_matrix_diag_tests(zip_path):
  """Make a set of tests for tf.matrix_diag op."""
  test_parameters = [
      {
          "input_shape": [[3], [2, 3], [3, 4, 5], [2, 4, 6, 8]],
          "input_dtype": [tf.int32, tf.float32],
      },
  ]

  def build_graph(parameters):
    input_tensor = tf.placeholder(
        dtype=parameters["input_dtype"],
        name="input",
        shape=parameters["input_shape"])
    outs = tf.matrix_diag(input_tensor)
    return [input_tensor], [outs]

  def build_inputs(parameters, sess, inputs, outputs):
    input_values = create_tensor_data(parameters["input_dtype"],
                                      parameters["input_shape"])
    return [input_values], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_values])))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)

@test_util.enable_control_flow_v2
def make_unidirectional_sequence_lstm_tests(zip_path):
  """Make a set of tests to do unidirectional_sequence_lstm."""

View File

@ -286,6 +286,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
      // have data type fields for all their arrays.
      break;
    }
    case OperatorType::kMatrixDiag: {
      CHECK_EQ(op->inputs.size(), 1);
      CHECK_EQ(op->outputs.size(), 1);
      const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
      SetDataTypeForAllOutputs(model, op, data_type);
      break;
    }
    default: {
      // These operators produce outputs with the same type as their 1st input
      CHECK_GT(op->inputs.size(), 0);

View File

@ -2060,6 +2060,24 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
  idx_output_array.copy_shape(input_array.shape());
}

void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
  CHECK_EQ(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run
    return;
  }

  // Get the input_shape
  auto& input_array = model->GetArray(op->inputs[0]);
  Shape* mutable_shape = input_array.mutable_shape();
  std::vector<int>* dims = mutable_shape->mutable_dims();
  int dims_size = dims->size();
  int last_dim = (*dims)[dims_size - 1];
  dims->push_back(last_dim);
  output_array.copy_shape(*mutable_shape);
}
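For reference, the shape rule this pass applies is simply "append a copy of the last dimension". A standalone sketch of that rule outside of toco's Model/Array types (illustration only; the function name is made up):

#include <cassert>
#include <cstdio>
#include <vector>

// [..., N] -> [..., N, N], matching ProcessMatrixDiagOperator above.
std::vector<int> MatrixDiagOutputShape(std::vector<int> dims) {
  assert(!dims.empty());
  dims.push_back(dims.back());
  return dims;
}

int main() {
  // e.g. the generated test shape [2, 4, 6, 8] propagates to [2, 4, 6, 8, 8].
  for (int d : MatrixDiagOutputShape({2, 4, 6, 8})) printf("%d ", d);
  printf("\n");
  return 0;
}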
}  // namespace

::tensorflow::Status PropagateFixedSizes::Run(Model* model,
@ -2363,6 +2381,9 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
      // tensor. Ignore shape propagation here and defer that to the
      // interpreter.
      break;
    case OperatorType::kMatrixDiag:
      ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
      break;
    default:
      // Unimplemented, another graph transformation should drop it.
      LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);

View File

@ -2471,6 +2471,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>}, {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>}, {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
{"MatMul", ConvertMatMulOperator}, {"MatMul", ConvertMatMulOperator},
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator}, {"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>}, {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},

View File

@ -168,7 +168,8 @@ enum class OperatorType : uint8 {
  kGatherNd,
  kWhere,
  kElu,
  kReverseSequence,
  kMatrixDiag
};
// Helper to deal with TensorFlow arrays using a different ordering of
@ -2075,6 +2076,14 @@ struct WhereOperator : Operator {
  WhereOperator() : Operator(OperatorType::kWhere) {}
};
// Matrix Diag Operator:
// Construct a batched diagonal tensor with given batched diagonal values.
// Inputs: A tensor of values that will be on the diagonal of the returned
// tensor.
struct MatrixDiagOperator : Operator {
  MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
};
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are

View File

@ -2476,7 +2476,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
  ops.push_back(
      MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
                                  OperatorType::kReverseSequence));
  ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
      "MATRIX_DIAG", OperatorType::kMatrixDiag));

  // Custom Operators.
  ops.push_back(
      MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));

View File

@ -705,6 +705,13 @@ TEST_F(OperatorTest, BuiltinReverseSequence) {
  EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
}
TEST_F(OperatorTest, BuiltinMatrixDiag) {
  MatrixDiagOperator op;
  std::unique_ptr<toco::MatrixDiagOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
}
// Test version for a simple Op with 2 versions and the input type controls the
// version.
template <typename Op>

View File

@ -427,6 +427,7 @@ const char* OperatorTypeName(OperatorType type) {
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE