Implement Matrix Diag
PiperOrigin-RevId: 239061891
This commit is contained in:
parent
b7a36dec6b
commit
d79bb04e2a
@ -274,6 +274,7 @@ def generated_test_models():
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"lstm",
|
||||
"matrix_diag",
|
||||
"max_pool",
|
||||
"maximum",
|
||||
"mean",
|
||||
|
@ -138,6 +138,7 @@ typedef enum {
|
||||
kTfLiteBuiltinRank = 110,
|
||||
kTfLiteBuiltinElu = 111,
|
||||
kTfLiteBuiltinReverseSequence = 112,
|
||||
kTfLiteBuiltinMatrixDiag = 113,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -373,6 +373,10 @@ typedef struct {
|
||||
int batch_dim;
|
||||
} TfLiteReverseSequenceParams;
|
||||
|
||||
typedef struct {
|
||||
EmptyStructPlaceholder placeholder;
|
||||
} TfLiteMatrixDiagParams;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -683,8 +683,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
*builtin_data = reinterpret_cast<void*>(params);
|
||||
break;
|
||||
}
|
||||
|
||||
// Below are the ops with no builtin_data strcture.
|
||||
// Below are the ops with no builtin_data structure.
|
||||
case BuiltinOperator_ABS:
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
|
||||
@ -708,6 +707,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_LOG:
|
||||
case BuiltinOperator_LOGISTIC:
|
||||
case BuiltinOperator_LOG_SOFTMAX:
|
||||
case BuiltinOperator_MATRIX_DIAG:
|
||||
case BuiltinOperator_MAXIMUM:
|
||||
case BuiltinOperator_MINIMUM:
|
||||
case BuiltinOperator_NEG:
|
||||
|
@ -188,6 +188,7 @@ cc_library(
|
||||
"logical.cc",
|
||||
"lsh_projection.cc",
|
||||
"lstm.cc",
|
||||
"matrix_diag.cc",
|
||||
"maximum_minimum.cc",
|
||||
"mfcc.cc",
|
||||
"mirror_pad.cc",
|
||||
@ -1388,3 +1389,15 @@ cc_test(
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "matrix_diag_test",
|
||||
size = "small",
|
||||
srcs = ["matrix_diag_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
136
tensorflow/lite/kernels/matrix_diag.cc
Normal file
136
tensorflow/lite/kernels/matrix_diag.cc
Normal file
@ -0,0 +1,136 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <string.h>
|
||||
#include <vector>
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace matrix_diag {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteIntArray* input_dims = input->dims;
|
||||
int input_dims_size = input_dims->size;
|
||||
TF_LITE_ENSURE(context, input_dims_size >= 1);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
// Resize the output tensor.
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
|
||||
for (int i = 0; i < input_dims_size; i++) {
|
||||
output_shape->data[i] = input_dims->data[i];
|
||||
}
|
||||
// Last dimension in the output is the same as the last dimension in the
|
||||
// input.
|
||||
output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
|
||||
output->type = input->type;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(context, output, output_shape));
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Fill the tensor to make a diagonal matrix in each batch, i.e., when
|
||||
// row index and column index are the same, fill with the next input value.
|
||||
// All other entries get zero.
|
||||
// TODO(b/128636574) Move to reference_ops.
|
||||
template <typename T>
|
||||
void FillDiagImpl(const T* in, T* out, const int batch_size, const int row_size,
|
||||
const int col_size) {
|
||||
int idx = 0;
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
for (int i = 0; i < row_size; i++) {
|
||||
for (int j = 0; j < col_size; ++j) {
|
||||
// input values go on the diagonal, 0 elsewhere
|
||||
if (i == j) {
|
||||
out[i * col_size + j] = in[idx];
|
||||
idx++;
|
||||
} else {
|
||||
out[i * col_size + j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
out += row_size * col_size;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
|
||||
const int batch_size, const int row_size, const int col_size) {
|
||||
FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output), batch_size,
|
||||
row_size, col_size);
|
||||
}
|
||||
|
||||
// Fill a tensor with given input on the diagonal, zero elsewhere
|
||||
void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
|
||||
const int num_output_dims = output->dims->size;
|
||||
int batch_size = 1;
|
||||
for (int i = 0; i < num_output_dims - 2; ++i) {
|
||||
batch_size *= output->dims->data[i];
|
||||
}
|
||||
|
||||
const int row_size = output->dims->data[num_output_dims - 2];
|
||||
const int col_size = output->dims->data[num_output_dims - 1];
|
||||
switch (output->type) {
|
||||
case kTfLiteInt64: {
|
||||
return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
|
||||
}
|
||||
case kTfLiteInt32: {
|
||||
return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
|
||||
}
|
||||
case kTfLiteInt16: {
|
||||
return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
|
||||
}
|
||||
default:
|
||||
return FillDiag<float_t>(input, output, batch_size, row_size, col_size);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
FillDiagHelper(input, output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace matrix_diag
|
||||
|
||||
TfLiteRegistration* Register_MATRIX_DIAG() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
|
||||
matrix_diag::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
110
tensorflow/lite/kernels/matrix_diag_test.cc
Normal file
110
tensorflow/lite/kernels/matrix_diag_test.cc
Normal file
@ -0,0 +1,110 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename T>
|
||||
class MatrixDiagOpModel : public SingleOpModel {
|
||||
public:
|
||||
explicit MatrixDiagOpModel(const TensorData& input) {
|
||||
input_ = AddInput(input);
|
||||
output_ = AddOutput({input.type, {}});
|
||||
|
||||
SetBuiltinOp(BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions,
|
||||
CreateMatrixDiagOptions(builder_).Union());
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
TfLiteType GetOutputType() {
|
||||
TfLiteTensor* t = interpreter_->tensor(output_);
|
||||
return t->type;
|
||||
}
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
// Use the machinery of TYPED_TEST_SUITE to test all supported types.
|
||||
// See
|
||||
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
|
||||
// for details.
|
||||
template <typename T>
|
||||
class MatrixDiagOpTest : public ::testing::Test {};
|
||||
|
||||
using TypesUnderTest =
|
||||
::testing::Types<TypeUnion<int32_t>, TypeUnion<float_t>, TypeUnion<int16_t>,
|
||||
TypeUnion<int8_t>, TypeUnion<uint8_t>>;
|
||||
TYPED_TEST_SUITE(MatrixDiagOpTest, TypesUnderTest);
|
||||
|
||||
TYPED_TEST(MatrixDiagOpTest, ThreeByThreeDiag) {
|
||||
MatrixDiagOpModel<typename TypeParam::ScalarType> model(
|
||||
{TypeParam::tensor_type, {3}});
|
||||
model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
|
||||
{1, 2, 3});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, //
|
||||
0, 2, 0, //
|
||||
0, 0, 3}));
|
||||
EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
|
||||
}
|
||||
|
||||
// Additional special cases.
|
||||
TEST(MatrixDiagTest, Int32TestTwoDimDiag) {
|
||||
MatrixDiagOpModel<int32_t> model({TensorType_INT32, {2, 4}});
|
||||
model.PopulateTensor<int32_t>(model.input(), {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, //
|
||||
0, 2, 0, 0, //
|
||||
0, 0, 3, 0, //
|
||||
0, 0, 0, 4, //
|
||||
5, 0, 0, 0, //
|
||||
0, 6, 0, 0, //
|
||||
0, 0, 7, 0, //
|
||||
0, 0, 0, 8}));
|
||||
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
|
||||
}
|
||||
|
||||
TEST(MatrixDiagTest, DegenenerateCase) {
|
||||
MatrixDiagOpModel<uint8_t> model({TensorType_UINT8, {1}});
|
||||
model.PopulateTensor<uint8_t>(model.input(), {1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
|
||||
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteUInt8);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -139,6 +139,7 @@ TfLiteRegistration* Register_GATHER_ND();
|
||||
TfLiteRegistration* Register_WHERE();
|
||||
TfLiteRegistration* Register_ELU();
|
||||
TfLiteRegistration* Register_REVERSE_SEQUENCE();
|
||||
TfLiteRegistration* Register_MATRIX_DIAG();
|
||||
|
||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->ReportError(
|
||||
@ -374,6 +375,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
|
||||
AddBuiltin(BuiltinOperator_ELU, Register_ELU());
|
||||
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
|
||||
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -522,6 +522,56 @@ TensorType GetTensorType() {
|
||||
// Strings have a special implementation that is in test_util.cc
|
||||
template <>
|
||||
std::vector<string> SingleOpModel::ExtractVector(int index);
|
||||
|
||||
// The TypeUnion struct specializations hold a collection of related types.
|
||||
// Each struct holds: 1. a primitive type (e.g. float), 2. a TensorType (e.g.
|
||||
// TensorType_FLOAT32, and 3. a TfLiteType (e.g. kTfLiteFloat32). The latter
|
||||
// two are actually enum values and not raw types, but these specializations
|
||||
// make it easy to use gUnit Typed Test Suite:
|
||||
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
|
||||
template <typename T>
|
||||
struct TypeUnion;
|
||||
|
||||
template <>
|
||||
struct TypeUnion<float_t> {
|
||||
public:
|
||||
static const TensorType tensor_type = TensorType::TensorType_FLOAT32;
|
||||
static const TfLiteType tflite_type = TfLiteType::kTfLiteFloat32;
|
||||
typedef float_t ScalarType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeUnion<int32_t> {
|
||||
public:
|
||||
static const TensorType tensor_type = TensorType::TensorType_INT32;
|
||||
static const TfLiteType tflite_type = TfLiteType::kTfLiteInt32;
|
||||
typedef int32_t ScalarType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeUnion<int16_t> {
|
||||
public:
|
||||
static const TensorType tensor_type = TensorType::TensorType_INT16;
|
||||
static const TfLiteType tflite_type = TfLiteType::kTfLiteInt16;
|
||||
typedef int16_t ScalarType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeUnion<int8_t> {
|
||||
public:
|
||||
static const TensorType tensor_type = TensorType::TensorType_INT8;
|
||||
static const TfLiteType tflite_type = TfLiteType::kTfLiteInt8;
|
||||
typedef int8_t ScalarType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeUnion<uint8_t> {
|
||||
public:
|
||||
static const TensorType tensor_type = TensorType::TensorType_UINT8;
|
||||
static const TfLiteType tflite_type = TfLiteType::kTfLiteUInt8;
|
||||
typedef uint8_t ScalarType;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
||||
|
@ -671,6 +671,7 @@ TfLiteStatus AddOpsAndParams(
|
||||
case tflite::BuiltinOperator_RANK:
|
||||
case tflite::BuiltinOperator_ELU:
|
||||
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
|
||||
case tflite::BuiltinOperator_MATRIX_DIAG:
|
||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||
return kTfLiteError;
|
||||
break;
|
||||
|
@ -226,6 +226,7 @@ enum BuiltinOperator : byte {
|
||||
RANK = 110,
|
||||
ELU = 111,
|
||||
REVERSE_SEQUENCE = 112,
|
||||
MATRIX_DIAG = 113,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -317,6 +318,7 @@ union BuiltinOptions {
|
||||
WhereOptions,
|
||||
RankOptions,
|
||||
ReverseSequenceOptions,
|
||||
MatrixDiagOptions,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -756,6 +758,10 @@ table ReverseSequenceOptions {
|
||||
seq_dim:int;
|
||||
batch_dim:int = 0;
|
||||
}
|
||||
|
||||
table MatrixDiagOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -292,6 +292,9 @@ struct WhereOptionsT;
|
||||
struct ReverseSequenceOptions;
|
||||
struct ReverseSequenceOptionsT;
|
||||
|
||||
struct MatrixDiagOptions;
|
||||
struct MatrixDiagOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -554,11 +557,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_RANK = 110,
|
||||
BuiltinOperator_ELU = 111,
|
||||
BuiltinOperator_REVERSE_SEQUENCE = 112,
|
||||
BuiltinOperator_MATRIX_DIAG = 113,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_REVERSE_SEQUENCE
|
||||
BuiltinOperator_MAX = BuiltinOperator_MATRIX_DIAG
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -671,7 +675,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] {
|
||||
BuiltinOperator_WHERE,
|
||||
BuiltinOperator_RANK,
|
||||
BuiltinOperator_ELU,
|
||||
BuiltinOperator_REVERSE_SEQUENCE
|
||||
BuiltinOperator_REVERSE_SEQUENCE,
|
||||
BuiltinOperator_MATRIX_DIAG
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -791,6 +796,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"RANK",
|
||||
"ELU",
|
||||
"REVERSE_SEQUENCE",
|
||||
"MATRIX_DIAG",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -890,11 +896,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_WhereOptions = 85,
|
||||
BuiltinOptions_RankOptions = 86,
|
||||
BuiltinOptions_ReverseSequenceOptions = 87,
|
||||
BuiltinOptions_MatrixDiagOptions = 88,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_ReverseSequenceOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_MatrixDiagOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -983,7 +990,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] {
|
||||
BuiltinOptions_CosOptions,
|
||||
BuiltinOptions_WhereOptions,
|
||||
BuiltinOptions_RankOptions,
|
||||
BuiltinOptions_ReverseSequenceOptions
|
||||
BuiltinOptions_ReverseSequenceOptions,
|
||||
BuiltinOptions_MatrixDiagOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1078,6 +1086,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"WhereOptions",
|
||||
"RankOptions",
|
||||
"ReverseSequenceOptions",
|
||||
"MatrixDiagOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -1440,6 +1449,10 @@ template<> struct BuiltinOptionsTraits<ReverseSequenceOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<MatrixDiagOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2167,6 +2180,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_ReverseSequenceOptions ?
|
||||
reinterpret_cast<const ReverseSequenceOptionsT *>(value) : nullptr;
|
||||
}
|
||||
MatrixDiagOptionsT *AsMatrixDiagOptions() {
|
||||
return type == BuiltinOptions_MatrixDiagOptions ?
|
||||
reinterpret_cast<MatrixDiagOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const MatrixDiagOptionsT *AsMatrixDiagOptions() const {
|
||||
return type == BuiltinOptions_MatrixDiagOptions ?
|
||||
reinterpret_cast<const MatrixDiagOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -7569,6 +7590,46 @@ inline flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(
|
||||
|
||||
flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct MatrixDiagOptionsT : public flatbuffers::NativeTable {
|
||||
typedef MatrixDiagOptions TableType;
|
||||
MatrixDiagOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef MatrixDiagOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
MatrixDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<MatrixDiagOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct MatrixDiagOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit MatrixDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
MatrixDiagOptionsBuilder &operator=(const MatrixDiagOptionsBuilder &);
|
||||
flatbuffers::Offset<MatrixDiagOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<MatrixDiagOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
MatrixDiagOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -7963,6 +8024,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_ReverseSequenceOptions ? static_cast<const ReverseSequenceOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? static_cast<const MatrixDiagOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -8342,6 +8406,10 @@ template<> inline const ReverseSequenceOptions *Operator::builtin_options_as<Rev
|
||||
return builtin_options_as_ReverseSequenceOptions();
|
||||
}
|
||||
|
||||
template<> inline const MatrixDiagOptions *Operator::builtin_options_as<MatrixDiagOptions>() const {
|
||||
return builtin_options_as_MatrixDiagOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -11156,6 +11224,29 @@ inline flatbuffers::Offset<ReverseSequenceOptions> CreateReverseSequenceOptions(
|
||||
_batch_dim);
|
||||
}
|
||||
|
||||
inline MatrixDiagOptionsT *MatrixDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new MatrixDiagOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void MatrixDiagOptions::UnPackTo(MatrixDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<MatrixDiagOptions> MatrixDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateMatrixDiagOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateMatrixDiagOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -11762,6 +11853,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const ReverseSequenceOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_MatrixDiagOptions: {
|
||||
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -12128,6 +12223,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const ReverseSequenceOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_MatrixDiagOptions: {
|
||||
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -12482,6 +12581,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const ReverseSequenceOptionsT *>(value);
|
||||
return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_MatrixDiagOptions: {
|
||||
auto ptr = reinterpret_cast<const MatrixDiagOptionsT *>(value);
|
||||
return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -12836,6 +12939,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new ReverseSequenceOptionsT(*reinterpret_cast<ReverseSequenceOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_MatrixDiagOptions: {
|
||||
value = new MatrixDiagOptionsT(*reinterpret_cast<MatrixDiagOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -13278,6 +13385,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_MatrixDiagOptions: {
|
||||
auto ptr = reinterpret_cast<MatrixDiagOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -4361,6 +4361,33 @@ def make_reverse_sequence_tests(zip_path):
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_matrix_diag_tests(zip_path):
|
||||
"""Make a set of tests for tf.matrix_diag op."""
|
||||
|
||||
test_parameters = [
|
||||
{
|
||||
"input_shape": [[3], [2, 3], [3, 4, 5], [2, 4, 6, 8]],
|
||||
"input_dtype": [tf.int32, tf.float32],
|
||||
},
|
||||
]
|
||||
|
||||
def build_graph(parameters):
|
||||
input_tensor = tf.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input",
|
||||
shape=parameters["input_shape"])
|
||||
outs = tf.matrix_diag(input_tensor)
|
||||
return [input_tensor], [outs]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_values = create_tensor_data(parameters["input_dtype"],
|
||||
parameters["input_shape"])
|
||||
return [input_values], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_values])))
|
||||
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def make_unidirectional_sequence_lstm_tests(zip_path):
|
||||
"""Make a set of tests to do unidirectional_sequence_lstm."""
|
||||
|
@ -286,6 +286,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
||||
// have data type fields for all their arrays.
|
||||
break;
|
||||
}
|
||||
case OperatorType::kMatrixDiag: {
|
||||
CHECK_EQ(op->inputs.size(), 1);
|
||||
CHECK_EQ(op->outputs.size(), 1);
|
||||
const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
|
||||
SetDataTypeForAllOutputs(model, op, data_type);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// These operators produce outputs with the same type as their 1st input
|
||||
CHECK_GT(op->inputs.size(), 0);
|
||||
|
@ -2060,6 +2060,24 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||
idx_output_array.copy_shape(input_array.shape());
|
||||
}
|
||||
|
||||
void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
|
||||
CHECK_EQ(op->inputs.size(), 1);
|
||||
CHECK_EQ(op->outputs.size(), 1);
|
||||
auto& output_array = model->GetArray(op->outputs[0]);
|
||||
if (output_array.has_shape()) {
|
||||
// We have already run
|
||||
return;
|
||||
}
|
||||
// Get the input_shape
|
||||
auto& input_array = model->GetArray(op->inputs[0]);
|
||||
Shape* mutable_shape = input_array.mutable_shape();
|
||||
std::vector<int>* dims = mutable_shape->mutable_dims();
|
||||
int dims_size = dims->size();
|
||||
int last_dim = (*dims)[dims_size - 1];
|
||||
dims->push_back(last_dim);
|
||||
output_array.copy_shape(*mutable_shape);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
|
||||
@ -2363,6 +2381,9 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||
// tensor. Ignore shape propagation here and defer that to the
|
||||
// interpreter.
|
||||
break;
|
||||
case OperatorType::kMatrixDiag:
|
||||
ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
|
||||
break;
|
||||
default:
|
||||
// Unimplemented, another graph transformation should drop it.
|
||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
||||
|
@ -2471,6 +2471,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
|
||||
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
|
||||
{"MatMul", ConvertMatMulOperator},
|
||||
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
|
||||
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
||||
{"MaxPool", ConvertMaxPoolOperator},
|
||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
||||
|
@ -168,7 +168,8 @@ enum class OperatorType : uint8 {
|
||||
kGatherNd,
|
||||
kWhere,
|
||||
kElu,
|
||||
kReverseSequence
|
||||
kReverseSequence,
|
||||
kMatrixDiag
|
||||
};
|
||||
|
||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||
@ -2075,6 +2076,14 @@ struct WhereOperator : Operator {
|
||||
WhereOperator() : Operator(OperatorType::kWhere) {}
|
||||
};
|
||||
|
||||
// Matrix Diag Operator:
|
||||
// Construct a batched diagonal tensor with given batched diagonal values.
|
||||
// Inputs: A tensor of values that will be on the diagonal of the returned
|
||||
// tensor.
|
||||
struct MatrixDiagOperator : Operator {
|
||||
MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
|
||||
};
|
||||
|
||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||
|
@ -2476,7 +2476,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
ops.push_back(
|
||||
MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
|
||||
OperatorType::kReverseSequence));
|
||||
|
||||
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
|
||||
"MATRIX_DIAG", OperatorType::kMatrixDiag));
|
||||
// Custom Operators.
|
||||
ops.push_back(
|
||||
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
|
||||
|
@ -705,6 +705,13 @@ TEST_F(OperatorTest, BuiltinReverseSequence) {
|
||||
EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinMatrixDiag) {
|
||||
MatrixDiagOperator op;
|
||||
std::unique_ptr<toco::MatrixDiagOperator> output_toco_op =
|
||||
SerializeAndDeserialize(
|
||||
GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
|
||||
}
|
||||
|
||||
// Test version for a simple Op with 2 versions and the input type controls the
|
||||
// version.
|
||||
template <typename Op>
|
||||
|
@ -427,6 +427,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(Cos)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Where)
|
||||
HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
|
||||
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled op type";
|
||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||
|
Loading…
Reference in New Issue
Block a user