Implement MatrixSetDiag and Eye

PiperOrigin-RevId: 239662752

parent 08cbe99299
commit eae92e9d58
@@ -248,6 +248,7 @@ def generated_test_models():
        "equal",
        "exp",
        "expand_dims",
+       "eye",
        "fill",
        "floor",
        "floor_div",
@@ -275,6 +276,7 @@ def generated_test_models():
        "logical_xor",
        "lstm",
        "matrix_diag",
+       "matrix_set_diag",
        "max_pool",
        "maximum",
        "mean",
@@ -140,6 +140,7 @@ typedef enum {
   kTfLiteBuiltinReverseSequence = 112,
   kTfLiteBuiltinMatrixDiag = 113,
   kTfLiteBuiltinQuantize = 114,
+  kTfLiteBuiltinMatrixSetDiag = 115,
 } TfLiteBuiltinOperator;

 #ifdef __cplusplus
@@ -377,6 +377,10 @@ typedef struct {
   EmptyStructPlaceholder placeholder;
 } TfLiteMatrixDiagParams;

+typedef struct {
+  EmptyStructPlaceholder placeholder;
+} TfLiteMatrixSetDiagParams;
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
@@ -708,6 +708,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_LOGISTIC:
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_MATRIX_DIAG:
+    case BuiltinOperator_MATRIX_SET_DIAG:
     case BuiltinOperator_MAXIMUM:
     case BuiltinOperator_MINIMUM:
     case BuiltinOperator_NEG:
@@ -189,6 +189,7 @@ cc_library(
         "lsh_projection.cc",
         "lstm.cc",
         "matrix_diag.cc",
+        "matrix_set_diag.cc",
         "maximum_minimum.cc",
         "mfcc.cc",
         "mirror_pad.cc",
@@ -1415,3 +1416,15 @@ cc_test(
         "@com_google_googletest//:gtest",
     ],
 )
+
+cc_test(
+    name = "matrix_set_diag_test",
+    size = "small",
+    srcs = ["matrix_set_diag_test.cc"],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
tensorflow/lite/kernels/matrix_set_diag.cc (new file, 147 lines)
@@ -0,0 +1,147 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace matrix_set_diag {

constexpr int kInputTensor = 0;
constexpr int kDiagonalTensor = 1;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteIntArray* input_dims = input->dims;
  int input_dims_size = input_dims->size;
  TF_LITE_ENSURE(context, input_dims_size >= 2);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
  for (int i = 0; i < input_dims_size; i++) {
    output_shape->data[i] = input_dims->data[i];
  }

  // Resize the output tensor to the same size as the input tensor.
  output->type = input->type;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_shape));

  return kTfLiteOk;
}

// Fill the tensor to make a diagonal matrix in each batch, i.e., when
// row index and column index are the same, fill with the next diagonal value.
// All other entries are the same as the input value.
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
                  const int row_size, const int col_size) {
  int idx = 0;
  for (int b = 0; b < batch_size; b++) {
    for (int i = 0; i < row_size; i++) {
      for (int j = 0; j < col_size; ++j) {
        // diag values go on the diagonal, in values elsewhere
        if (i == j) {
          out[i * col_size + j] = diag[idx];
          idx++;
        } else {
          out[i * col_size + j] = in[i * col_size + j];
        }
      }
    }
    out += row_size * col_size;
    in += row_size * col_size;
  }
}

template <typename T>
void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
              TfLiteTensor* output, const int batch_size, const int row_size,
              const int col_size) {
  FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag),
                  GetTensorData<T>(output), batch_size, row_size, col_size);
}

// Fill a tensor with given "diag" values on the diagonal, input values
// elsewhere.
void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
                    TfLiteTensor* output) {
  const int num_output_dims = output->dims->size;
  int batch_size = 1;
  for (int i = 0; i < num_output_dims - 2; ++i) {
    batch_size *= output->dims->data[i];
  }

  const int row_size = output->dims->data[num_output_dims - 2];
  const int col_size = output->dims->data[num_output_dims - 1];
  switch (output->type) {
    case kTfLiteInt64: {
      return FillDiag<int64_t>(input, diag, output, batch_size, row_size,
                               col_size);
    }
    case kTfLiteInt32: {
      return FillDiag<int32_t>(input, diag, output, batch_size, row_size,
                               col_size);
    }
    case kTfLiteInt16: {
      return FillDiag<int16_t>(input, diag, output, batch_size, row_size,
                               col_size);
    }
    case kTfLiteInt8: {
      return FillDiag<int8_t>(input, diag, output, batch_size, row_size,
                               col_size);
    }
    case kTfLiteUInt8: {
      return FillDiag<uint8_t>(input, diag, output, batch_size, row_size,
                               col_size);
    }
    default:
      return FillDiag<float>(input, diag, output, batch_size, row_size,
                             col_size);
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor);
  FillDiagHelper(input, diag, output);
  return kTfLiteOk;
}

}  // namespace matrix_set_diag

TfLiteRegistration* Register_MATRIX_SET_DIAG() {
  static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare,
                                 matrix_set_diag::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
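The fill logic above is compact enough to restate outside the TFLite plumbing. Below is an illustrative NumPy sketch of the same batched diagonal scatter (not part of the commit; the helper name matrix_set_diag_ref is made up here). It reproduces the ThreeByThreeDiagScatter case from the new unit test that follows.

import numpy as np

def matrix_set_diag_ref(inputs, diag):
    # Illustrative re-statement of FillDiagImpl: copy the input, then
    # overwrite the main diagonal of each innermost matrix with `diag`.
    out = np.array(inputs, copy=True)
    rows, cols = out.shape[-2], out.shape[-1]
    batched = out.reshape(-1, rows, cols)             # view: writes land in `out`
    diag_flat = np.asarray(diag).reshape(batched.shape[0], -1)
    n = min(rows, cols)                               # diagonal length
    for b in range(batched.shape[0]):
        batched[b, np.arange(n), np.arange(n)] = diag_flat[b, :n]
    return out

# Mirrors ThreeByThreeDiagScatter below: diag values land on the diagonal,
# input values everywhere else.
print(matrix_set_diag_ref([[7, 1, 2], [3, 8, 4], [5, 6, 9]], [0, 4, 2]))
# [[0 1 2]
#  [3 4 4]
#  [5 6 2]]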
tensorflow/lite/kernels/matrix_set_diag_test.cc (new file, 132 lines)
@@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace {

using ::testing::ElementsAre;
using ::testing::ElementsAreArray;

template <typename T>
class MatrixSetDiagOpModel : public SingleOpModel {
 public:
  explicit MatrixSetDiagOpModel(const TensorData& input,
                                const TensorData& diag) {
    input_ = AddInput(input);
    diag_ = AddInput(diag);
    output_ = AddOutput({input.type, {}});

    SetBuiltinOp(BuiltinOperator_MATRIX_SET_DIAG,
                 BuiltinOptions_MatrixSetDiagOptions,
                 CreateMatrixSetDiagOptions(builder_).Union());
    BuildInterpreter({GetShape(input_), GetShape(diag_)});
  }

  int input() { return input_; }
  int diag() { return diag_; }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
  TfLiteType GetOutputType() {
    TfLiteTensor* t = interpreter_->tensor(output_);
    return t->type;
  }

 private:
  int input_;
  int diag_;
  int output_;
};

// Use the machinery of TYPED_TEST_SUITE to test all supported types.
// See
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
// for details.
template <typename T>
class MatrixSetDiagOpTest : public ::testing::Test {};

using TypesUnderTest =
    ::testing::Types<TypeUnion<int32_t>, TypeUnion<float>, TypeUnion<int16_t>,
                     TypeUnion<int8_t>, TypeUnion<uint8_t>>;

TYPED_TEST_SUITE(MatrixSetDiagOpTest, TypesUnderTest);

TYPED_TEST(MatrixSetDiagOpTest, ThreeByThreeDiagScatter) {
  MatrixSetDiagOpModel<typename TypeParam::ScalarType> model(
      {TypeParam::tensor_type, {3, 3}}, {TypeParam::tensor_type, {3}});
  model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
                                                                {7, 1, 2,  //
                                                                 3, 8, 4,  //
                                                                 5, 6, 9});
  model.template PopulateTensor<typename TypeParam::ScalarType>(model.diag(),
                                                                {0, 4, 2});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2,  //
                                                   3, 4, 4,  //
                                                   5, 6, 2}));
  EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
}

TEST(MatrixSetDiagTest, Int32TestMoreColumnsThanRows) {
  MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 3}},
                                      {TensorType_INT32, {2}});
  model.PopulateTensor<int32_t>(model.input(), {0, 0, 0,  //
                                                9, 9, 9});
  model.PopulateTensor<int32_t>(model.diag(), {1, 1});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0,  //
                                                   9, 1, 9}));
  EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
}

TEST(MatrixSetDiagTest, Int32TestTwoDimDiag) {
  MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 4, 4}},
                                      {TensorType_INT32, {2, 4}});
  model.PopulateTensor<int32_t>(model.input(), {5, 5, 5, 5,  //
                                                5, 5, 5, 5,  //
                                                5, 5, 5, 5,  //
                                                5, 5, 5, 5,  //
                                                1, 1, 1, 1,  //
                                                1, 1, 1, 1,  //
                                                1, 1, 1, 1,  //
                                                1, 1, 1, 1});
  model.PopulateTensor<int32_t>(model.diag(), {1, 2, 3, 4, 5, 6, 7, 8});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 5, 5, 5,  //
                                                   5, 2, 5, 5,  //
                                                   5, 5, 3, 5,  //
                                                   5, 5, 5, 4,  //
                                                   5, 1, 1, 1,  //
                                                   1, 6, 1, 1,  //
                                                   1, 1, 7, 1,  //
                                                   1, 1, 1, 8}));
  EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
}

}  // namespace
}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
@@ -141,6 +141,7 @@ TfLiteRegistration* Register_ELU();
TfLiteRegistration* Register_REVERSE_SEQUENCE();
TfLiteRegistration* Register_MATRIX_DIAG();
TfLiteRegistration* Register_QUANTIZE();
+TfLiteRegistration* Register_MATRIX_SET_DIAG();

TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
  context->ReportError(
@@ -378,6 +379,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
   AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
   AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
+  AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());

   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
@@ -673,6 +673,7 @@ TfLiteStatus AddOpsAndParams(
      case tflite::BuiltinOperator_REVERSE_SEQUENCE:
      case tflite::BuiltinOperator_MATRIX_DIAG:
      case tflite::BuiltinOperator_QUANTIZE:
+     case tflite::BuiltinOperator_MATRIX_SET_DIAG:
        logError("Op code %d is currently not delegated to NNAPI", builtin);
        return kTfLiteError;
        break;
@@ -228,6 +228,7 @@ enum BuiltinOperator : byte {
   REVERSE_SEQUENCE = 112,
   MATRIX_DIAG = 113,
   QUANTIZE = 114,
+  MATRIX_SET_DIAG = 115
 }

 // Options for the builtin operators.
@@ -321,6 +322,7 @@ union BuiltinOptions {
   ReverseSequenceOptions,
   MatrixDiagOptions,
   QuantizeOptions,
+  MatrixSetDiagOptions
 }

 enum Padding : byte { SAME, VALID }
@@ -767,6 +769,9 @@ table MatrixDiagOptions {
 table QuantizeOptions {
 }

+table MatrixSetDiagOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
@@ -298,6 +298,9 @@ struct MatrixDiagOptionsT;
 struct QuantizeOptions;
 struct QuantizeOptionsT;

+struct MatrixSetDiagOptions;
+struct MatrixSetDiagOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
@@ -562,11 +565,12 @@ enum BuiltinOperator {
   BuiltinOperator_REVERSE_SEQUENCE = 112,
   BuiltinOperator_MATRIX_DIAG = 113,
   BuiltinOperator_QUANTIZE = 114,
+  BuiltinOperator_MATRIX_SET_DIAG = 115,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_QUANTIZE
+  BuiltinOperator_MAX = BuiltinOperator_MATRIX_SET_DIAG
 };

-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] {
   static const BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -681,7 +685,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
     BuiltinOperator_ELU,
     BuiltinOperator_REVERSE_SEQUENCE,
     BuiltinOperator_MATRIX_DIAG,
-    BuiltinOperator_QUANTIZE
+    BuiltinOperator_QUANTIZE,
+    BuiltinOperator_MATRIX_SET_DIAG
   };
   return values;
 }
@@ -803,6 +808,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
     "REVERSE_SEQUENCE",
     "MATRIX_DIAG",
     "QUANTIZE",
+    "MATRIX_SET_DIAG",
     nullptr
   };
   return names;
@@ -904,11 +910,12 @@ enum BuiltinOptions {
   BuiltinOptions_ReverseSequenceOptions = 87,
   BuiltinOptions_MatrixDiagOptions = 88,
   BuiltinOptions_QuantizeOptions = 89,
+  BuiltinOptions_MatrixSetDiagOptions = 90,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_QuantizeOptions
+  BuiltinOptions_MAX = BuiltinOptions_MatrixSetDiagOptions
 };

-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[91] {
   static const BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
@@ -999,7 +1006,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
     BuiltinOptions_RankOptions,
     BuiltinOptions_ReverseSequenceOptions,
     BuiltinOptions_MatrixDiagOptions,
-    BuiltinOptions_QuantizeOptions
+    BuiltinOptions_QuantizeOptions,
+    BuiltinOptions_MatrixSetDiagOptions
   };
   return values;
 }
@@ -1096,6 +1104,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
     "ReverseSequenceOptions",
     "MatrixDiagOptions",
     "QuantizeOptions",
+    "MatrixSetDiagOptions",
     nullptr
   };
   return names;
@@ -1466,6 +1475,10 @@ template<> struct BuiltinOptionsTraits<QuantizeOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
 };

+template<> struct BuiltinOptionsTraits<MatrixSetDiagOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -2209,6 +2222,14 @@ struct BuiltinOptionsUnion {
     return type == BuiltinOptions_QuantizeOptions ?
       reinterpret_cast<const QuantizeOptionsT *>(value) : nullptr;
   }
+  MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() {
+    return type == BuiltinOptions_MatrixSetDiagOptions ?
+      reinterpret_cast<MatrixSetDiagOptionsT *>(value) : nullptr;
+  }
+  const MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const {
+    return type == BuiltinOptions_MatrixSetDiagOptions ?
+      reinterpret_cast<const MatrixSetDiagOptionsT *>(value) : nullptr;
+  }
 };

 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -7691,6 +7712,46 @@ inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(

 flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);

+struct MatrixSetDiagOptionsT : public flatbuffers::NativeTable {
+  typedef MatrixSetDiagOptions TableType;
+  MatrixSetDiagOptionsT() {
+  }
+};
+
+struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef MatrixSetDiagOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  MatrixSetDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<MatrixSetDiagOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct MatrixSetDiagOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit MatrixSetDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  MatrixSetDiagOptionsBuilder &operator=(const MatrixSetDiagOptionsBuilder &);
+  flatbuffers::Offset<MatrixSetDiagOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<MatrixSetDiagOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  MatrixSetDiagOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   BuiltinOperator builtin_code;
@@ -8091,6 +8152,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const QuantizeOptions *builtin_options_as_QuantizeOptions() const {
     return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast<const QuantizeOptions *>(builtin_options()) : nullptr;
   }
+  const MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const {
+    return builtin_options_type() == BuiltinOptions_MatrixSetDiagOptions ? static_cast<const MatrixSetDiagOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -8478,6 +8542,10 @@ template<> inline const QuantizeOptions *Operator::builtin_options_as<QuantizeOp
   return builtin_options_as_QuantizeOptions();
 }

+template<> inline const MatrixSetDiagOptions *Operator::builtin_options_as<MatrixSetDiagOptions>() const {
+  return builtin_options_as_MatrixSetDiagOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -11338,6 +11406,29 @@ inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::F
       _fbb);
 }

+inline MatrixSetDiagOptionsT *MatrixSetDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new MatrixSetDiagOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void MatrixSetDiagOptions::UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<MatrixSetDiagOptions> MatrixSetDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateMatrixSetDiagOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixSetDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateMatrixSetDiagOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -11952,6 +12043,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
       auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_MatrixSetDiagOptions: {
+      auto ptr = reinterpret_cast<const MatrixSetDiagOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return false;
   }
 }
@@ -12326,6 +12421,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
       auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_MatrixSetDiagOptions: {
+      auto ptr = reinterpret_cast<const MatrixSetDiagOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -12688,6 +12787,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
       auto ptr = reinterpret_cast<const QuantizeOptionsT *>(value);
       return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_MatrixSetDiagOptions: {
+      auto ptr = reinterpret_cast<const MatrixSetDiagOptionsT *>(value);
+      return CreateMatrixSetDiagOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -13050,6 +13153,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
       value = new QuantizeOptionsT(*reinterpret_cast<QuantizeOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_MatrixSetDiagOptions: {
+      value = new MatrixSetDiagOptionsT(*reinterpret_cast<MatrixSetDiagOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -13502,6 +13609,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_MatrixSetDiagOptions: {
+      auto ptr = reinterpret_cast<MatrixSetDiagOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
@@ -4385,6 +4385,79 @@ def make_matrix_diag_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


+def make_matrix_set_diag_tests(zip_path):
+  """Make a set of tests for tf.matrix_set_diag op."""
+
+  test_parameters = [
+      {
+          "input_diag_shapes": [([3, 3], [3]), ([2, 3], [2]), ([2, 4, 4],
+                                                               [2, 4]),
+                                ([3, 4, 5, 6], [3, 4, 5])],
+          "input_dtype": [tf.int32, tf.float32, tf.uint8],
+      },
+  ]
+
+  def build_graph(parameters):
+    input_shape = parameters["input_diag_shapes"][0]
+    diag_shape = parameters["input_diag_shapes"][1]
+    input_tensor = tf.placeholder(
+        dtype=parameters["input_dtype"], name="input", shape=input_shape)
+    diag_tensor = tf.placeholder(
+        dtype=parameters["input_dtype"], name="diagonal", shape=diag_shape)
+    outs = tf.matrix_set_diag(input_tensor, diag_tensor)
+    return [input_tensor, diag_tensor], [outs]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_shape = parameters["input_diag_shapes"][0]
+    diag_shape = parameters["input_diag_shapes"][1]
+    input_values = create_tensor_data(parameters["input_dtype"], input_shape)
+    diag_values = create_tensor_data(parameters["input_dtype"], diag_shape)
+    return [input_values, diag_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values, diag_values])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_eye_tests(zip_path):
+  """Make a set of tests for tf.eye op."""
+
+  test_parameters = [{
+      "num_rows_shape": [[]],
+      "num_cols_shape": [[]],
+      "batch_shape": [[3], [2, 4], [4, 5, 6], None],
+      "use_num_cols": [True, False],
+      "dtype": [tf.float32, tf.int32],
+  }]
+
+  def build_graph(parameters):
+    input_tensor0 = tf.placeholder(
+        dtype=tf.int32, name="num_rows", shape=parameters["num_rows_shape"])
+    input_tensor1 = tf.placeholder(
+        dtype=tf.int32, name="num_columns", shape=parameters["num_cols_shape"])
+    if parameters["use_num_cols"]:
+      outs = tf.eye(
+          num_rows=input_tensor0,
+          num_columns=input_tensor1,
+          batch_shape=parameters["batch_shape"],
+          dtype=parameters["dtype"])
+      return [input_tensor0, input_tensor1], [outs]
+    else:
+      outs = tf.eye(num_rows=input_tensor0, dtype=parameters["dtype"])
+      return [input_tensor0], [outs]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_value0 = create_scalar_data(dtype=np.int32, min_value=1)
+    input_value1 = create_scalar_data(dtype=np.int32, min_value=1)
+    if parameters["use_num_cols"]:
+      return [input_value0, input_value1], sess.run(
+          outputs, feed_dict=dict(zip(inputs, [input_value0, input_value1])))
+    else:
+      return [input_value0], sess.run(
+          outputs, feed_dict=dict(zip(inputs, [input_value0])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 @test_util.enable_control_flow_v2
 def make_unidirectional_sequence_lstm_tests(zip_path):
   """Make a set of tests to do unidirectional_sequence_lstm."""
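As an aside (not part of the commit), the tf.eye cases exercised by make_eye_tests reduce to a simple rule: an identity matrix, optionally rectangular, optionally repeated over leading batch dimensions. A minimal NumPy sketch, using the hypothetical helper name eye_ref:

import numpy as np

def eye_ref(num_rows, num_cols=None, batch_shape=None, dtype=np.float32):
    # Illustrative equivalent of the tf.eye calls built in make_eye_tests.
    num_cols = num_rows if num_cols is None else num_cols
    eye = np.eye(num_rows, num_cols, dtype=dtype)    # possibly rectangular identity
    if batch_shape:
        # Leading batch dimensions simply repeat the same identity matrix.
        eye = np.broadcast_to(eye, tuple(batch_shape) + eye.shape).copy()
    return eye

print(eye_ref(2, 3, batch_shape=[4, 5, 6]).shape)  # (4, 5, 6, 2, 3)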
@@ -293,6 +293,7 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
       SetDataTypeForAllOutputs(model, op, data_type);
       break;
     }
+    case OperatorType::kMatrixSetDiag: {
+      CHECK_EQ(op->inputs.size(), 2);
+      CHECK_EQ(op->outputs.size(), 1);
+      const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+      SetDataTypeForAllOutputs(model, op, data_type);
+      break;
+    }
     default: {
       // These operators produce outputs with the same type as their 1st input
       CHECK_GT(op->inputs.size(), 0);
@@ -2063,21 +2063,40 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
 void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
   CHECK_EQ(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
+  auto& input_array = model->GetArray(op->inputs[0]);
   auto& output_array = model->GetArray(op->outputs[0]);
-  if (output_array.has_shape()) {
+  // The input array must have a shape in order to proceed. Also,
+  // bail out if the output shape has already been calculated.
+  if (!input_array.has_shape() || output_array.has_shape()) {
     // We have already run
     return;
   }
   // Get the input_shape
-  auto& input_array = model->GetArray(op->inputs[0]);
   Shape* mutable_shape = input_array.mutable_shape();
   std::vector<int>* dims = mutable_shape->mutable_dims();
   int dims_size = dims->size();
   // Scalars are not allowed.
   CHECK_GT(dims_size, 0);
   int last_dim = (*dims)[dims_size - 1];
   dims->push_back(last_dim);
   output_array.copy_shape(*mutable_shape);
 }

+void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
+  CHECK_EQ(op->inputs.size(), 2);
+  CHECK_EQ(op->outputs.size(), 1);
+  auto& input_array = model->GetArray(op->inputs[0]);
+  auto& output_array = model->GetArray(op->outputs[0]);
+  // The shape of the input array must be known because that will
+  // be the shape of the output array.
+  if (!input_array.has_shape() || !output_array.has_shape()) {
+    // We have already run
+    return;
+  }
+
+  output_array.copy_shape(input_array.shape());
+}
+
 }  // namespace

 ::tensorflow::Status PropagateFixedSizes::Run(Model* model,
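For readers tracking the shape propagation above: MatrixDiag appends a copy of the last dimension, while MatrixSetDiag simply forwards the input shape to the output. A small illustrative sketch of those two rules (the function names are made up, not toco APIs):

def matrix_diag_output_shape(diag_shape):
    # Rule from ProcessMatrixDiagOperator: repeat the last dimension.
    assert len(diag_shape) > 0  # scalars are not allowed
    return list(diag_shape) + [diag_shape[-1]]

def matrix_set_diag_output_shape(input_shape):
    # Rule from ProcessMatrixSetDiagOperator: output mirrors the input.
    return list(input_shape)

assert matrix_diag_output_shape([2, 4]) == [2, 4, 4]
assert matrix_set_diag_output_shape([2, 4, 4]) == [2, 4, 4]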
@@ -2384,6 +2403,10 @@ void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
     case OperatorType::kMatrixDiag:
       ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
       break;
+    case OperatorType::kMatrixSetDiag:
+      ProcessMatrixSetDiagOperator(model,
+                                   static_cast<MatrixSetDiagOperator*>(op));
+      break;
     default:
       // Unimplemented, another graph transformation should drop it.
       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
@@ -68,7 +68,9 @@ bool SupportsQuantization(const Operator& op) {
          type == OperatorType::kResizeNearestNeighbor ||
          type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
          type == OperatorType::kReduceMin ||
-         type == OperatorType::kTransposeConv;
+         type == OperatorType::kTransposeConv ||
+         type == OperatorType::kMatrixSetDiag ||
+         type == OperatorType::kMatrixDiag;
 }

 // The quantized op allows output arrays of type float using
@@ -2472,6 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
       {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
       {"MatMul", ConvertMatMulOperator},
       {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
+      {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
       {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
       {"MaxPool", ConvertMaxPoolOperator},
       {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
@@ -169,7 +169,8 @@ enum class OperatorType : uint8 {
   kWhere,
   kElu,
   kReverseSequence,
-  kMatrixDiag
+  kMatrixDiag,
+  kMatrixSetDiag
 };

 // Helper to deal with TensorFlow arrays using a different ordering of
@@ -2084,6 +2085,16 @@ struct MatrixDiagOperator : Operator {
   MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
 };

+// Matrix Set Diag Operator:
+// Construct a batched diagonal tensor with given input and diagonal values.
+// Input is a rank (k+1) tensor of values.
+// diagonal is a rank (k) tensor of values that will be on the diagonal
+// of the returned output. Output is a rank (k+1)
+// tensor.
+struct MatrixSetDiagOperator : Operator {
+  MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
+};
+
 // Alloc's are used for transient arrays only. An Alloc specifies which interval
 // of the "transient_data" workspace buffer passed to inference functions, is to
 // be used for the transient array at hand. The 'start' and 'end' values are
@@ -2478,6 +2478,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
                                                  OperatorType::kReverseSequence));
   ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
       "MATRIX_DIAG", OperatorType::kMatrixDiag));
+  ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
+      "MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag));
   // Custom Operators.
   ops.push_back(
       MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
@@ -712,6 +712,13 @@ TEST_F(OperatorTest, BuiltinMatrixDiag) {
       GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
 }

+TEST_F(OperatorTest, BuiltinMatrixSetDiag) {
+  MatrixSetDiagOperator op;
+  std::unique_ptr<toco::MatrixSetDiagOperator> output_toco_op =
+      SerializeAndDeserialize(
+          GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op);
+}
+
 // Test version for a simple Op with 2 versions and the input type controls the
 // version.
 template <typename Op>
@@ -428,6 +428,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(Where)
     HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
     HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
+    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
     default:
       LOG(FATAL) << "Unhandled op type";
 #undef HANDLE_OPERATORTYPENAME_CASE