Implement MatrixSetDiag and Eye

PiperOrigin-RevId: 239662752
This commit is contained in:
A. Unique TensorFlower 2019-03-21 13:40:57 -07:00 committed by TensorFlower Gardener
parent 08cbe99299
commit eae92e9d58
20 changed files with 557 additions and 10 deletions

View File

@ -248,6 +248,7 @@ def generated_test_models():
"equal",
"exp",
"expand_dims",
"eye",
"fill",
"floor",
"floor_div",
@ -275,6 +276,7 @@ def generated_test_models():
"logical_xor",
"lstm",
"matrix_diag",
"matrix_set_diag",
"max_pool",
"maximum",
"mean",

View File

@ -140,6 +140,7 @@ typedef enum {
kTfLiteBuiltinReverseSequence = 112,
kTfLiteBuiltinMatrixDiag = 113,
kTfLiteBuiltinQuantize = 114,
kTfLiteBuiltinMatrixSetDiag = 115,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -377,6 +377,10 @@ typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteMatrixDiagParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteMatrixSetDiagParams;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -708,6 +708,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_MATRIX_DIAG:
case BuiltinOperator_MATRIX_SET_DIAG:
case BuiltinOperator_MAXIMUM:
case BuiltinOperator_MINIMUM:
case BuiltinOperator_NEG:

View File

@ -189,6 +189,7 @@ cc_library(
"lsh_projection.cc",
"lstm.cc",
"matrix_diag.cc",
"matrix_set_diag.cc",
"maximum_minimum.cc",
"mfcc.cc",
"mirror_pad.cc",
@ -1415,3 +1416,15 @@ cc_test(
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "matrix_set_diag_test",
size = "small",
srcs = ["matrix_set_diag_test.cc"],
deps = [
":builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)

View File

@ -0,0 +1,147 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace matrix_set_diag {
constexpr int kInputTensor = 0;
constexpr int kDiagonalTensor = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteIntArray* input_dims = input->dims;
int input_dims_size = input_dims->size;
TF_LITE_ENSURE(context, input_dims_size >= 2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
for (int i = 0; i < input_dims_size; i++) {
output_shape->data[i] = input_dims->data[i];
}
// Resize the output tensor to the same size as the input tensor.
output->type = input->type;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_shape));
return kTfLiteOk;
}
// Fill the tensor to make a diagonal matrix in each batch, i.e., when
// row index and column index are the same, fill with the next diagonal value.
// All other entries are the same as the input value.
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
const int row_size, const int col_size) {
int idx = 0;
for (int b = 0; b < batch_size; b++) {
for (int i = 0; i < row_size; i++) {
for (int j = 0; j < col_size; ++j) {
// diag values go on the diagonal, in values elsewhere
if (i == j) {
out[i * col_size + j] = diag[idx];
idx++;
} else {
out[i * col_size + j] = in[i * col_size + j];
}
}
}
out += row_size * col_size;
in += row_size * col_size;
}
}
template <typename T>
void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
TfLiteTensor* output, const int batch_size, const int row_size,
const int col_size) {
FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag),
GetTensorData<T>(output), batch_size, row_size, col_size);
}
// Fill a tensor with given "diag" values on the diagonal, input values
// elsewhere.
void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
TfLiteTensor* output) {
const int num_output_dims = output->dims->size;
int batch_size = 1;
for (int i = 0; i < num_output_dims - 2; ++i) {
batch_size *= output->dims->data[i];
}
const int row_size = output->dims->data[num_output_dims - 2];
const int col_size = output->dims->data[num_output_dims - 1];
switch (output->type) {
case kTfLiteInt64: {
return FillDiag<int64_t>(input, diag, output, batch_size, row_size,
col_size);
}
case kTfLiteInt32: {
return FillDiag<int32_t>(input, diag, output, batch_size, row_size,
col_size);
}
case kTfLiteInt16: {
return FillDiag<int16_t>(input, diag, output, batch_size, row_size,
col_size);
}
case kTfLiteInt8: {
return FillDiag<int8_t>(input, diag, output, batch_size, row_size,
col_size);
}
case kTfLiteUInt8: {
return FillDiag<uint8_t>(input, diag, output, batch_size, row_size,
col_size);
}
default:
return FillDiag<float>(input, diag, output, batch_size, row_size,
col_size);
}
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor);
FillDiagHelper(input, diag, output);
return kTfLiteOk;
}
} // namespace matrix_set_diag
TfLiteRegistration* Register_MATRIX_SET_DIAG() {
static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare,
matrix_set_diag::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
template <typename T>
class MatrixSetDiagOpModel : public SingleOpModel {
public:
explicit MatrixSetDiagOpModel(const TensorData& input,
const TensorData& diag) {
input_ = AddInput(input);
diag_ = AddInput(diag);
output_ = AddOutput({input.type, {}});
SetBuiltinOp(BuiltinOperator_MATRIX_SET_DIAG,
BuiltinOptions_MatrixSetDiagOptions,
CreateMatrixSetDiagOptions(builder_).Union());
BuildInterpreter({GetShape(input_), GetShape(diag_)});
}
int input() { return input_; }
int diag() { return diag_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
TfLiteType GetOutputType() {
TfLiteTensor* t = interpreter_->tensor(output_);
return t->type;
}
private:
int input_;
int diag_;
int output_;
};
// Use the machinery of TYPED_TEST_SUITE to test all supported types.
// See
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
// for details.
template <typename T>
class MatrixSetDiagOpTest : public ::testing::Test {};
using TypesUnderTest =
::testing::Types<TypeUnion<int32_t>, TypeUnion<float>, TypeUnion<int16_t>,
TypeUnion<int8_t>, TypeUnion<uint8_t>>;
TYPED_TEST_SUITE(MatrixSetDiagOpTest, TypesUnderTest);
TYPED_TEST(MatrixSetDiagOpTest, ThreeByThreeDiagScatter) {
MatrixSetDiagOpModel<typename TypeParam::ScalarType> model(
{TypeParam::tensor_type, {3, 3}}, {TypeParam::tensor_type, {3}});
model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
{7, 1, 2, //
3, 8, 4, //
5, 6, 9});
model.template PopulateTensor<typename TypeParam::ScalarType>(model.diag(),
{0, 4, 2});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2, //
3, 4, 4, //
5, 6, 2}));
EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
}
TEST(MatrixSetDiagTest, Int32TestMoreColumnsThanRows) {
MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 3}},
{TensorType_INT32, {2}});
model.PopulateTensor<int32_t>(model.input(), {0, 0, 0, //
9, 9, 9});
model.PopulateTensor<int32_t>(model.diag(), {1, 1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, //
9, 1, 9}));
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
}
TEST(MatrixSetDiagTest, Int32TestTwoDimDiag) {
MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 4, 4}},
{TensorType_INT32, {2, 4}});
model.PopulateTensor<int32_t>(model.input(), {5, 5, 5, 5, //
5, 5, 5, 5, //
5, 5, 5, 5, //
5, 5, 5, 5, //
1, 1, 1, 1, //
1, 1, 1, 1, //
1, 1, 1, 1, //
1, 1, 1, 1});
model.PopulateTensor<int32_t>(model.diag(), {1, 2, 3, 4, 5, 6, 7, 8});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 5, 5, 5, //
5, 2, 5, 5, //
5, 5, 3, 5, //
5, 5, 5, 4, //
5, 1, 1, 1, //
1, 6, 1, 1, //
1, 1, 7, 1, //
1, 1, 1, 8}));
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -141,6 +141,7 @@ TfLiteRegistration* Register_ELU();
TfLiteRegistration* Register_REVERSE_SEQUENCE();
TfLiteRegistration* Register_MATRIX_DIAG();
TfLiteRegistration* Register_QUANTIZE();
TfLiteRegistration* Register_MATRIX_SET_DIAG();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@ -378,6 +379,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -673,6 +673,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
case tflite::BuiltinOperator_MATRIX_DIAG:
case tflite::BuiltinOperator_QUANTIZE:
case tflite::BuiltinOperator_MATRIX_SET_DIAG:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;

View File

@ -228,6 +228,7 @@ enum BuiltinOperator : byte {
REVERSE_SEQUENCE = 112,
MATRIX_DIAG = 113,
QUANTIZE = 114,
MATRIX_SET_DIAG = 115
}
// Options for the builtin operators.
@ -321,6 +322,7 @@ union BuiltinOptions {
ReverseSequenceOptions,
MatrixDiagOptions,
QuantizeOptions,
MatrixSetDiagOptions
}
enum Padding : byte { SAME, VALID }
@ -767,6 +769,9 @@ table MatrixDiagOptions {
table QuantizeOptions {
}
table MatrixSetDiagOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -298,6 +298,9 @@ struct MatrixDiagOptionsT;
struct QuantizeOptions;
struct QuantizeOptionsT;
struct MatrixSetDiagOptions;
struct MatrixSetDiagOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -562,11 +565,12 @@ enum BuiltinOperator {
BuiltinOperator_REVERSE_SEQUENCE = 112,
BuiltinOperator_MATRIX_DIAG = 113,
BuiltinOperator_QUANTIZE = 114,
BuiltinOperator_MATRIX_SET_DIAG = 115,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_QUANTIZE
BuiltinOperator_MAX = BuiltinOperator_MATRIX_SET_DIAG
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -681,7 +685,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
BuiltinOperator_ELU,
BuiltinOperator_REVERSE_SEQUENCE,
BuiltinOperator_MATRIX_DIAG,
BuiltinOperator_QUANTIZE
BuiltinOperator_QUANTIZE,
BuiltinOperator_MATRIX_SET_DIAG
};
return values;
}
@ -803,6 +808,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"REVERSE_SEQUENCE",
"MATRIX_DIAG",
"QUANTIZE",
"MATRIX_SET_DIAG",
nullptr
};
return names;
@ -904,11 +910,12 @@ enum BuiltinOptions {
BuiltinOptions_ReverseSequenceOptions = 87,
BuiltinOptions_MatrixDiagOptions = 88,
BuiltinOptions_QuantizeOptions = 89,
BuiltinOptions_MatrixSetDiagOptions = 90,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_QuantizeOptions
BuiltinOptions_MAX = BuiltinOptions_MatrixSetDiagOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[91] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -999,7 +1006,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
BuiltinOptions_RankOptions,
BuiltinOptions_ReverseSequenceOptions,
BuiltinOptions_MatrixDiagOptions,
BuiltinOptions_QuantizeOptions
BuiltinOptions_QuantizeOptions,
BuiltinOptions_MatrixSetDiagOptions
};
return values;
}
@ -1096,6 +1104,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"ReverseSequenceOptions",
"MatrixDiagOptions",
"QuantizeOptions",
"MatrixSetDiagOptions",
nullptr
};
return names;
@ -1466,6 +1475,10 @@ template<> struct BuiltinOptionsTraits<QuantizeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
};
template<> struct BuiltinOptionsTraits<MatrixSetDiagOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2209,6 +2222,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_QuantizeOptions ?
reinterpret_cast<const QuantizeOptionsT *>(value) : nullptr;
}
MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() {
return type == BuiltinOptions_MatrixSetDiagOptions ?
reinterpret_cast<MatrixSetDiagOptionsT *>(value) : nullptr;
}
const MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const {
return type == BuiltinOptions_MatrixSetDiagOptions ?
reinterpret_cast<const MatrixSetDiagOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7691,6 +7712,46 @@ inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(
flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct MatrixSetDiagOptionsT : public flatbuffers::NativeTable {
typedef MatrixSetDiagOptions TableType;
MatrixSetDiagOptionsT() {
}
};
struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef MatrixSetDiagOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
MatrixSetDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<MatrixSetDiagOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct MatrixSetDiagOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit MatrixSetDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
MatrixSetDiagOptionsBuilder &operator=(const MatrixSetDiagOptionsBuilder &);
flatbuffers::Offset<MatrixSetDiagOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<MatrixSetDiagOptions>(end);
return o;
}
};
inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
MatrixSetDiagOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -8091,6 +8152,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const QuantizeOptions *builtin_options_as_QuantizeOptions() const {
return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast<const QuantizeOptions *>(builtin_options()) : nullptr;
}
const MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const {
return builtin_options_type() == BuiltinOptions_MatrixSetDiagOptions ? static_cast<const MatrixSetDiagOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -8478,6 +8542,10 @@ template<> inline const QuantizeOptions *Operator::builtin_options_as<QuantizeOp
return builtin_options_as_QuantizeOptions();
}
template<> inline const MatrixSetDiagOptions *Operator::builtin_options_as<MatrixSetDiagOptions>() const {
return builtin_options_as_MatrixSetDiagOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -11338,6 +11406,29 @@ inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::F
_fbb);
}
inline MatrixSetDiagOptionsT *MatrixSetDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new MatrixSetDiagOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void MatrixSetDiagOptions::UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<MatrixSetDiagOptions> MatrixSetDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateMatrixSetDiagOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixSetDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateMatrixSetDiagOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -11952,6 +12043,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_MatrixSetDiagOptions: {
auto ptr = reinterpret_cast<const MatrixSetDiagOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -12326,6 +12421,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_MatrixSetDiagOptions: {
auto ptr = reinterpret_cast<const MatrixSetDiagOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -12688,6 +12787,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const QuantizeOptionsT *>(value);
return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_MatrixSetDiagOptions: {
auto ptr = reinterpret_cast<const MatrixSetDiagOptionsT *>(value);
return CreateMatrixSetDiagOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -13050,6 +13153,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new QuantizeOptionsT(*reinterpret_cast<QuantizeOptionsT *>(u.value));
break;
}
case BuiltinOptions_MatrixSetDiagOptions: {
value = new MatrixSetDiagOptionsT(*reinterpret_cast<MatrixSetDiagOptionsT *>(u.value));
break;
}
default:
break;
}
@ -13502,6 +13609,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_MatrixSetDiagOptions: {
auto ptr = reinterpret_cast<MatrixSetDiagOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -4385,6 +4385,79 @@ def make_matrix_diag_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_matrix_set_diag_tests(zip_path):
"""Make a set of tests for tf.matrix_set_diag op."""
test_parameters = [
{
"input_diag_shapes": [([3, 3], [3]), ([2, 3], [2]), ([2, 4, 4],
[2, 4]),
([3, 4, 5, 6], [3, 4, 5])],
"input_dtype": [tf.int32, tf.float32, tf.uint8],
},
]
def build_graph(parameters):
input_shape = parameters["input_diag_shapes"][0]
diag_shape = parameters["input_diag_shapes"][1]
input_tensor = tf.placeholder(
dtype=parameters["input_dtype"], name="input", shape=input_shape)
diag_tensor = tf.placeholder(
dtype=parameters["input_dtype"], name="diagonal", shape=diag_shape)
outs = tf.matrix_set_diag(input_tensor, diag_tensor)
return [input_tensor, diag_tensor], [outs]
def build_inputs(parameters, sess, inputs, outputs):
input_shape = parameters["input_diag_shapes"][0]
diag_shape = parameters["input_diag_shapes"][1]
input_values = create_tensor_data(parameters["input_dtype"], input_shape)
diag_values = create_tensor_data(parameters["input_dtype"], diag_shape)
return [input_values, diag_values], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_values, diag_values])))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_eye_tests(zip_path):
"""Make a set of tests for tf.eye op."""
test_parameters = [{
"num_rows_shape": [[]],
"num_cols_shape": [[]],
"batch_shape": [[3], [2, 4], [4, 5, 6], None],
"use_num_cols": [True, False],
"dtype": [tf.float32, tf.int32],
}]
def build_graph(parameters):
input_tensor0 = tf.placeholder(
dtype=tf.int32, name="num_rows", shape=parameters["num_rows_shape"])
input_tensor1 = tf.placeholder(
dtype=tf.int32, name="num_columns", shape=parameters["num_cols_shape"])
if parameters["use_num_cols"]:
outs = tf.eye(
num_rows=input_tensor0,
num_columns=input_tensor1,
batch_shape=parameters["batch_shape"],
dtype=parameters["dtype"])
return [input_tensor0, input_tensor1], [outs]
else:
outs = tf.eye(num_rows=input_tensor0, dtype=parameters["dtype"])
return [input_tensor0], [outs]
def build_inputs(parameters, sess, inputs, outputs):
input_value0 = create_scalar_data(dtype=np.int32, min_value=1)
input_value1 = create_scalar_data(dtype=np.int32, min_value=1)
if parameters["use_num_cols"]:
return [input_value0, input_value1], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value0, input_value1])))
else:
return [input_value0], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value0])))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@test_util.enable_control_flow_v2
def make_unidirectional_sequence_lstm_tests(zip_path):
"""Make a set of tests to do unidirectional_sequence_lstm."""

View File

@ -293,6 +293,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
case OperatorType::kMatrixSetDiag: {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);

View File

@ -2063,21 +2063,40 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
auto& input_array = model->GetArray(op->inputs[0]);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// The input array must have a shape in order to proceed. Also,
// bail out if the output shape has already been calculated.
if (!input_array.has_shape() || output_array.has_shape()) {
// We have already run
return;
}
// Get the input_shape
auto& input_array = model->GetArray(op->inputs[0]);
Shape* mutable_shape = input_array.mutable_shape();
std::vector<int>* dims = mutable_shape->mutable_dims();
int dims_size = dims->size();
// Scalars are not allowed.
CHECK_GT(dims_size, 0);
int last_dim = (*dims)[dims_size - 1];
dims->push_back(last_dim);
output_array.copy_shape(*mutable_shape);
}
void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
auto& input_array = model->GetArray(op->inputs[0]);
auto& output_array = model->GetArray(op->outputs[0]);
// The shape of the input array must be known because that will
// be the shape of the output array.
if (!input_array.has_shape() || !output_array.has_shape()) {
// We have already run
return;
}
output_array.copy_shape(input_array.shape());
}
} // namespace
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
@ -2384,6 +2403,10 @@ void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
case OperatorType::kMatrixDiag:
ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
break;
case OperatorType::kMatrixSetDiag:
ProcessMatrixSetDiagOperator(model,
static_cast<MatrixSetDiagOperator*>(op));
break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);

View File

@ -68,7 +68,9 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kResizeNearestNeighbor ||
type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
type == OperatorType::kReduceMin ||
type == OperatorType::kTransposeConv;
type == OperatorType::kTransposeConv ||
type == OperatorType::kMatrixSetDiag ||
type == OperatorType::kMatrixDiag;
}
// The quantized op allows output arrays of type float using

View File

@ -2472,6 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
{"MatMul", ConvertMatMulOperator},
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
{"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},

View File

@ -169,7 +169,8 @@ enum class OperatorType : uint8 {
kWhere,
kElu,
kReverseSequence,
kMatrixDiag
kMatrixDiag,
kMatrixSetDiag
};
// Helper to deal with TensorFlow arrays using a different ordering of
@ -2084,6 +2085,16 @@ struct MatrixDiagOperator : Operator {
MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
};
// Matrix Set Diag Operator:
// Construct a batched diagonal tensor with given input and diagonal values.
// Input is a rank (k+1) tensor of values.
// diagonal is a rank (k) tensor of values that will be on the diagonal
// of the returned output. Output is rank k+1.
// tensor.
struct MatrixSetDiagOperator : Operator {
MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
};
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are

View File

@ -2478,6 +2478,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
OperatorType::kReverseSequence));
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
"MATRIX_DIAG", OperatorType::kMatrixDiag));
ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
"MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag));
// Custom Operators.
ops.push_back(
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));

View File

@ -712,6 +712,13 @@ TEST_F(OperatorTest, BuiltinMatrixDiag) {
GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
}
TEST_F(OperatorTest, BuiltinMatrixSetDiag) {
MatrixSetDiagOperator op;
std::unique_ptr<toco::MatrixSetDiagOperator> output_toco_op =
SerializeAndDeserialize(
GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op);
}
// Test version for a simple Op with 2 versions and the input type controls the
// version.
template <typename Op>

View File

@ -428,6 +428,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Where)
HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE