Implement Matrix Diag

PiperOrigin-RevId: 239061891
This commit is contained in:
A. Unique TensorFlower 2019-03-18 14:29:41 -07:00 committed by TensorFlower Gardener
parent b7a36dec6b
commit d79bb04e2a
20 changed files with 520 additions and 10 deletions

View File

@ -274,6 +274,7 @@ def generated_test_models():
"logical_or",
"logical_xor",
"lstm",
"matrix_diag",
"max_pool",
"maximum",
"mean",

View File

@ -138,6 +138,7 @@ typedef enum {
kTfLiteBuiltinRank = 110,
kTfLiteBuiltinElu = 111,
kTfLiteBuiltinReverseSequence = 112,
kTfLiteBuiltinMatrixDiag = 113,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -373,6 +373,10 @@ typedef struct {
int batch_dim;
} TfLiteReverseSequenceParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteMatrixDiagParams;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -683,8 +683,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
// Below are the ops with no builtin_data strcture.
// Below are the ops with no builtin_data structure.
case BuiltinOperator_ABS:
case BuiltinOperator_BATCH_TO_SPACE_ND:
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
@ -708,6 +707,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOG:
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_MATRIX_DIAG:
case BuiltinOperator_MAXIMUM:
case BuiltinOperator_MINIMUM:
case BuiltinOperator_NEG:

View File

@ -188,6 +188,7 @@ cc_library(
"logical.cc",
"lsh_projection.cc",
"lstm.cc",
"matrix_diag.cc",
"maximum_minimum.cc",
"mfcc.cc",
"mirror_pad.cc",
@ -1388,3 +1389,15 @@ cc_test(
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "matrix_diag_test",
size = "small",
srcs = ["matrix_diag_test.cc"],
deps = [
":builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)

View File

@ -0,0 +1,136 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace matrix_diag {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteIntArray* input_dims = input->dims;
int input_dims_size = input_dims->size;
TF_LITE_ENSURE(context, input_dims_size >= 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Resize the output tensor.
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
for (int i = 0; i < input_dims_size; i++) {
output_shape->data[i] = input_dims->data[i];
}
// Last dimension in the output is the same as the last dimension in the
// input.
output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
output->type = input->type;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_shape));
return kTfLiteOk;
}
// Fill the tensor to make a diagonal matrix in each batch, i.e., when
// row index and column index are the same, fill with the next input value.
// All other entries get zero.
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, T* out, const int batch_size, const int row_size,
const int col_size) {
int idx = 0;
for (int b = 0; b < batch_size; b++) {
for (int i = 0; i < row_size; i++) {
for (int j = 0; j < col_size; ++j) {
// input values go on the diagonal, 0 elsewhere
if (i == j) {
out[i * col_size + j] = in[idx];
idx++;
} else {
out[i * col_size + j] = 0;
}
}
}
out += row_size * col_size;
}
}
template <typename T>
void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size, const int row_size, const int col_size) {
FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output), batch_size,
row_size, col_size);
}
// Fill a tensor with given input on the diagonal, zero elsewhere
void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
const int num_output_dims = output->dims->size;
int batch_size = 1;
for (int i = 0; i < num_output_dims - 2; ++i) {
batch_size *= output->dims->data[i];
}
const int row_size = output->dims->data[num_output_dims - 2];
const int col_size = output->dims->data[num_output_dims - 1];
switch (output->type) {
case kTfLiteInt64: {
return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
}
case kTfLiteInt32: {
return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
}
case kTfLiteInt16: {
return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
}
case kTfLiteInt8: {
return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
}
case kTfLiteUInt8: {
return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
}
default:
return FillDiag<float_t>(input, output, batch_size, row_size, col_size);
}
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
FillDiagHelper(input, output);
return kTfLiteOk;
}
} // namespace matrix_diag
TfLiteRegistration* Register_MATRIX_DIAG() {
static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
matrix_diag::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,110 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
template <typename T>
class MatrixDiagOpModel : public SingleOpModel {
public:
explicit MatrixDiagOpModel(const TensorData& input) {
input_ = AddInput(input);
output_ = AddOutput({input.type, {}});
SetBuiltinOp(BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions,
CreateMatrixDiagOptions(builder_).Union());
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
TfLiteType GetOutputType() {
TfLiteTensor* t = interpreter_->tensor(output_);
return t->type;
}
private:
int input_;
int output_;
};
// Use the machinery of TYPED_TEST_SUITE to test all supported types.
// See
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
// for details.
template <typename T>
class MatrixDiagOpTest : public ::testing::Test {};
using TypesUnderTest =
::testing::Types<TypeUnion<int32_t>, TypeUnion<float_t>, TypeUnion<int16_t>,
TypeUnion<int8_t>, TypeUnion<uint8_t>>;
TYPED_TEST_SUITE(MatrixDiagOpTest, TypesUnderTest);
TYPED_TEST(MatrixDiagOpTest, ThreeByThreeDiag) {
MatrixDiagOpModel<typename TypeParam::ScalarType> model(
{TypeParam::tensor_type, {3}});
model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
{1, 2, 3});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, //
0, 2, 0, //
0, 0, 3}));
EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
}
// Additional special cases.
TEST(MatrixDiagTest, Int32TestTwoDimDiag) {
MatrixDiagOpModel<int32_t> model({TensorType_INT32, {2, 4}});
model.PopulateTensor<int32_t>(model.input(), {1, 2, 3, 4, 5, 6, 7, 8});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, //
0, 2, 0, 0, //
0, 0, 3, 0, //
0, 0, 0, 4, //
5, 0, 0, 0, //
0, 6, 0, 0, //
0, 0, 7, 0, //
0, 0, 0, 8}));
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
}
TEST(MatrixDiagTest, DegenenerateCase) {
MatrixDiagOpModel<uint8_t> model({TensorType_UINT8, {1}});
model.PopulateTensor<uint8_t>(model.input(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteUInt8);
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -139,6 +139,7 @@ TfLiteRegistration* Register_GATHER_ND();
TfLiteRegistration* Register_WHERE();
TfLiteRegistration* Register_ELU();
TfLiteRegistration* Register_REVERSE_SEQUENCE();
TfLiteRegistration* Register_MATRIX_DIAG();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@ -374,6 +375,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
AddBuiltin(BuiltinOperator_ELU, Register_ELU());
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -522,6 +522,56 @@ TensorType GetTensorType() {
// Strings have a special implementation that is in test_util.cc
template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
// The TypeUnion struct specializations hold a collection of related types.
// Each struct holds: 1. a primitive type (e.g. float), 2. a TensorType (e.g.
// TensorType_FLOAT32, and 3. a TfLiteType (e.g. kTfLiteFloat32). The latter
// two are actually enum values and not raw types, but these specializations
// make it easy to use gUnit Typed Test Suite:
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
template <typename T>
struct TypeUnion;
template <>
struct TypeUnion<float_t> {
public:
static const TensorType tensor_type = TensorType::TensorType_FLOAT32;
static const TfLiteType tflite_type = TfLiteType::kTfLiteFloat32;
typedef float_t ScalarType;
};
template <>
struct TypeUnion<int32_t> {
public:
static const TensorType tensor_type = TensorType::TensorType_INT32;
static const TfLiteType tflite_type = TfLiteType::kTfLiteInt32;
typedef int32_t ScalarType;
};
template <>
struct TypeUnion<int16_t> {
public:
static const TensorType tensor_type = TensorType::TensorType_INT16;
static const TfLiteType tflite_type = TfLiteType::kTfLiteInt16;
typedef int16_t ScalarType;
};
template <>
struct TypeUnion<int8_t> {
public:
static const TensorType tensor_type = TensorType::TensorType_INT8;
static const TfLiteType tflite_type = TfLiteType::kTfLiteInt8;
typedef int8_t ScalarType;
};
template <>
struct TypeUnion<uint8_t> {
public:
static const TensorType tensor_type = TensorType::TensorType_UINT8;
static const TfLiteType tflite_type = TfLiteType::kTfLiteUInt8;
typedef uint8_t ScalarType;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_

View File

@ -671,6 +671,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_RANK:
case tflite::BuiltinOperator_ELU:
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
case tflite::BuiltinOperator_MATRIX_DIAG:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;

View File

@ -226,6 +226,7 @@ enum BuiltinOperator : byte {
RANK = 110,
ELU = 111,
REVERSE_SEQUENCE = 112,
MATRIX_DIAG = 113,
}
// Options for the builtin operators.
@ -317,6 +318,7 @@ union BuiltinOptions {
WhereOptions,
RankOptions,
ReverseSequenceOptions,
MatrixDiagOptions,
}
enum Padding : byte { SAME, VALID }
@ -756,6 +758,10 @@ table ReverseSequenceOptions {
seq_dim:int;
batch_dim:int = 0;
}
table MatrixDiagOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -292,6 +292,9 @@ struct WhereOptionsT;
struct ReverseSequenceOptions;
struct ReverseSequenceOptionsT;
struct MatrixDiagOptions;
struct MatrixDiagOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -554,11 +557,12 @@ enum BuiltinOperator {
BuiltinOperator_RANK = 110,
BuiltinOperator_ELU = 111,
BuiltinOperator_REVERSE_SEQUENCE = 112,
BuiltinOperator_MATRIX_DIAG = 113,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_REVERSE_SEQUENCE
BuiltinOperator_MAX = BuiltinOperator_MATRIX_DIAG
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -671,7 +675,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] {
BuiltinOperator_WHERE,
BuiltinOperator_RANK,
BuiltinOperator_ELU,
BuiltinOperator_REVERSE_SEQUENCE
BuiltinOperator_REVERSE_SEQUENCE,
BuiltinOperator_MATRIX_DIAG
};
return values;
}
@ -791,6 +796,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"RANK",
"ELU",
"REVERSE_SEQUENCE",
"MATRIX_DIAG",
nullptr
};
return names;
@ -890,11 +896,12 @@ enum BuiltinOptions {
BuiltinOptions_WhereOptions = 85,
BuiltinOptions_RankOptions = 86,
BuiltinOptions_ReverseSequenceOptions = 87,
BuiltinOptions_MatrixDiagOptions = 88,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_ReverseSequenceOptions
BuiltinOptions_MAX = BuiltinOptions_MatrixDiagOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -983,7 +990,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] {
BuiltinOptions_CosOptions,
BuiltinOptions_WhereOptions,
BuiltinOptions_RankOptions,
BuiltinOptions_ReverseSequenceOptions
BuiltinOptions_ReverseSequenceOptions,
BuiltinOptions_MatrixDiagOptions
};
return values;
}
@ -1078,6 +1086,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"WhereOptions",
"RankOptions",
"ReverseSequenceOptions",
"MatrixDiagOptions",
nullptr
};
return names;
@ -1440,6 +1449,10 @@ template<> struct BuiltinOptionsTraits<ReverseSequenceOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions;
};
template<> struct BuiltinOptionsTraits<MatrixDiagOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2167,6 +2180,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ReverseSequenceOptions ?
reinterpret_cast<const ReverseSequenceOptionsT *>(value) : nullptr;
}
MatrixDiagOptionsT *AsMatrixDiagOptions() {
return type == BuiltinOptions_MatrixDiagOptions ?
reinterpret_cast<MatrixDiagOptionsT *>(value) : nullptr;
}
const MatrixDiagOptionsT *AsMatrixDiagOptions() const {
return type == BuiltinOptions_MatrixDiagOptions ?
reinterpret_cast<const MatrixDiagOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7569,6 +7590,46 @@ inline flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(
flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct MatrixDiagOptionsT : public flatbuffers::NativeTable {
typedef MatrixDiagOptions TableType;
MatrixDiagOptionsT() {
}
};
struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef MatrixDiagOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
MatrixDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<MatrixDiagOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct MatrixDiagOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit MatrixDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
MatrixDiagOptionsBuilder &operator=(const MatrixDiagOptionsBuilder &);
flatbuffers::Offset<MatrixDiagOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<MatrixDiagOptions>(end);
return o;
}
};
inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
MatrixDiagOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -7963,6 +8024,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const {
return builtin_options_type() == BuiltinOptions_ReverseSequenceOptions ? static_cast<const ReverseSequenceOptions *>(builtin_options()) : nullptr;
}
const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const {
return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? static_cast<const MatrixDiagOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -8342,6 +8406,10 @@ template<> inline const ReverseSequenceOptions *Operator::builtin_options_as<Rev
return builtin_options_as_ReverseSequenceOptions();
}
template<> inline const MatrixDiagOptions *Operator::builtin_options_as<MatrixDiagOptions>() const {
return builtin_options_as_MatrixDiagOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -11156,6 +11224,29 @@ inline flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(
_batch_dim);
}
inline MatrixDiagOptionsT *MatrixDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new MatrixDiagOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void MatrixDiagOptions::UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<MatrixDiagOptions> MatrixDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateMatrixDiagOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateMatrixDiagOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -11762,6 +11853,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ReverseSequenceOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -12128,6 +12223,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ReverseSequenceOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -12482,6 +12581,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ReverseSequenceOptionsT *>(value);
return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<const MatrixDiagOptionsT *>(value);
return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -12836,6 +12939,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ReverseSequenceOptionsT(*reinterpret_cast<ReverseSequenceOptionsT *>(u.value));
break;
}
case BuiltinOptions_MatrixDiagOptions: {
value = new MatrixDiagOptionsT(*reinterpret_cast<MatrixDiagOptionsT *>(u.value));
break;
}
default:
break;
}
@ -13278,6 +13385,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_MatrixDiagOptions: {
auto ptr = reinterpret_cast<MatrixDiagOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -4361,6 +4361,33 @@ def make_reverse_sequence_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_matrix_diag_tests(zip_path):
"""Make a set of tests for tf.matrix_diag op."""
test_parameters = [
{
"input_shape": [[3], [2, 3], [3, 4, 5], [2, 4, 6, 8]],
"input_dtype": [tf.int32, tf.float32],
},
]
def build_graph(parameters):
input_tensor = tf.placeholder(
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
outs = tf.matrix_diag(input_tensor)
return [input_tensor], [outs]
def build_inputs(parameters, sess, inputs, outputs):
input_values = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
return [input_values], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_values])))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@test_util.enable_control_flow_v2
def make_unidirectional_sequence_lstm_tests(zip_path):
"""Make a set of tests to do unidirectional_sequence_lstm."""

View File

@ -286,6 +286,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
// have data type fields for all their arrays.
break;
}
case OperatorType::kMatrixDiag: {
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);

View File

@ -2060,6 +2060,24 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
idx_output_array.copy_shape(input_array.shape());
}
void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
// Get the input_shape
auto& input_array = model->GetArray(op->inputs[0]);
Shape* mutable_shape = input_array.mutable_shape();
std::vector<int>* dims = mutable_shape->mutable_dims();
int dims_size = dims->size();
int last_dim = (*dims)[dims_size - 1];
dims->push_back(last_dim);
output_array.copy_shape(*mutable_shape);
}
} // namespace
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
@ -2363,6 +2381,9 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
// tensor. Ignore shape propagation here and defer that to the
// interpreter.
break;
case OperatorType::kMatrixDiag:
ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);

View File

@ -2471,6 +2471,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
{"MatMul", ConvertMatMulOperator},
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},

View File

@ -168,7 +168,8 @@ enum class OperatorType : uint8 {
kGatherNd,
kWhere,
kElu,
kReverseSequence
kReverseSequence,
kMatrixDiag
};
// Helper to deal with TensorFlow arrays using a different ordering of
@ -2075,6 +2076,14 @@ struct WhereOperator : Operator {
WhereOperator() : Operator(OperatorType::kWhere) {}
};
// Matrix Diag Operator:
// Construct a batched diagonal tensor with given batched diagonal values.
// Inputs: A tensor of values that will be on the diagonal of the returned
// tensor.
struct MatrixDiagOperator : Operator {
MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
};
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are

View File

@ -2476,7 +2476,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
ops.push_back(
MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
OperatorType::kReverseSequence));
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
"MATRIX_DIAG", OperatorType::kMatrixDiag));
// Custom Operators.
ops.push_back(
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));

View File

@ -705,6 +705,13 @@ TEST_F(OperatorTest, BuiltinReverseSequence) {
EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
}
TEST_F(OperatorTest, BuiltinMatrixDiag) {
MatrixDiagOperator op;
std::unique_ptr<toco::MatrixDiagOperator> output_toco_op =
SerializeAndDeserialize(
GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
}
// Test version for a simple Op with 2 versions and the input type controls the
// version.
template <typename Op>

View File

@ -427,6 +427,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Cos)
HANDLE_OPERATORTYPENAME_CASE(Where)
HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE