Implement MatrixSetDiag and Eye
PiperOrigin-RevId: 239662752
This commit is contained in:
parent
08cbe99299
commit
eae92e9d58
@ -248,6 +248,7 @@ def generated_test_models():
|
|||||||
"equal",
|
"equal",
|
||||||
"exp",
|
"exp",
|
||||||
"expand_dims",
|
"expand_dims",
|
||||||
|
"eye",
|
||||||
"fill",
|
"fill",
|
||||||
"floor",
|
"floor",
|
||||||
"floor_div",
|
"floor_div",
|
||||||
@ -275,6 +276,7 @@ def generated_test_models():
|
|||||||
"logical_xor",
|
"logical_xor",
|
||||||
"lstm",
|
"lstm",
|
||||||
"matrix_diag",
|
"matrix_diag",
|
||||||
|
"matrix_set_diag",
|
||||||
"max_pool",
|
"max_pool",
|
||||||
"maximum",
|
"maximum",
|
||||||
"mean",
|
"mean",
|
||||||
|
@ -140,6 +140,7 @@ typedef enum {
|
|||||||
kTfLiteBuiltinReverseSequence = 112,
|
kTfLiteBuiltinReverseSequence = 112,
|
||||||
kTfLiteBuiltinMatrixDiag = 113,
|
kTfLiteBuiltinMatrixDiag = 113,
|
||||||
kTfLiteBuiltinQuantize = 114,
|
kTfLiteBuiltinQuantize = 114,
|
||||||
|
kTfLiteBuiltinMatrixSetDiag = 115,
|
||||||
} TfLiteBuiltinOperator;
|
} TfLiteBuiltinOperator;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
@ -377,6 +377,10 @@ typedef struct {
|
|||||||
EmptyStructPlaceholder placeholder;
|
EmptyStructPlaceholder placeholder;
|
||||||
} TfLiteMatrixDiagParams;
|
} TfLiteMatrixDiagParams;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
EmptyStructPlaceholder placeholder;
|
||||||
|
} TfLiteMatrixSetDiagParams;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -708,6 +708,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
case BuiltinOperator_LOGISTIC:
|
case BuiltinOperator_LOGISTIC:
|
||||||
case BuiltinOperator_LOG_SOFTMAX:
|
case BuiltinOperator_LOG_SOFTMAX:
|
||||||
case BuiltinOperator_MATRIX_DIAG:
|
case BuiltinOperator_MATRIX_DIAG:
|
||||||
|
case BuiltinOperator_MATRIX_SET_DIAG:
|
||||||
case BuiltinOperator_MAXIMUM:
|
case BuiltinOperator_MAXIMUM:
|
||||||
case BuiltinOperator_MINIMUM:
|
case BuiltinOperator_MINIMUM:
|
||||||
case BuiltinOperator_NEG:
|
case BuiltinOperator_NEG:
|
||||||
|
@ -189,6 +189,7 @@ cc_library(
|
|||||||
"lsh_projection.cc",
|
"lsh_projection.cc",
|
||||||
"lstm.cc",
|
"lstm.cc",
|
||||||
"matrix_diag.cc",
|
"matrix_diag.cc",
|
||||||
|
"matrix_set_diag.cc",
|
||||||
"maximum_minimum.cc",
|
"maximum_minimum.cc",
|
||||||
"mfcc.cc",
|
"mfcc.cc",
|
||||||
"mirror_pad.cc",
|
"mirror_pad.cc",
|
||||||
@ -1415,3 +1416,15 @@ cc_test(
|
|||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "matrix_set_diag_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["matrix_set_diag_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":builtin_ops",
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"//tensorflow/lite/kernels:test_util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
147
tensorflow/lite/kernels/matrix_set_diag.cc
Normal file
147
tensorflow/lite/kernels/matrix_set_diag.cc
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <string.h>
|
||||||
|
#include <vector>
|
||||||
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/kernels/op_macros.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace ops {
|
||||||
|
namespace builtin {
|
||||||
|
namespace matrix_set_diag {
|
||||||
|
|
||||||
|
constexpr int kInputTensor = 0;
|
||||||
|
constexpr int kDiagonalTensor = 1;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
|
TfLiteIntArray* input_dims = input->dims;
|
||||||
|
int input_dims_size = input_dims->size;
|
||||||
|
TF_LITE_ENSURE(context, input_dims_size >= 2);
|
||||||
|
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
|
||||||
|
for (int i = 0; i < input_dims_size; i++) {
|
||||||
|
output_shape->data[i] = input_dims->data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize the output tensor to the same size as the input tensor.
|
||||||
|
output->type = input->type;
|
||||||
|
TF_LITE_ENSURE_OK(context,
|
||||||
|
context->ResizeTensor(context, output, output_shape));
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill the tensor to make a diagonal matrix in each batch, i.e., when
|
||||||
|
// row index and column index are the same, fill with the next diagonal value.
|
||||||
|
// All other entries are the same as the input value.
|
||||||
|
// TODO(b/128636574) Move to reference_ops.
|
||||||
|
template <typename T>
|
||||||
|
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
|
||||||
|
const int row_size, const int col_size) {
|
||||||
|
int idx = 0;
|
||||||
|
for (int b = 0; b < batch_size; b++) {
|
||||||
|
for (int i = 0; i < row_size; i++) {
|
||||||
|
for (int j = 0; j < col_size; ++j) {
|
||||||
|
// diag values go on the diagonal, in values elsewhere
|
||||||
|
if (i == j) {
|
||||||
|
out[i * col_size + j] = diag[idx];
|
||||||
|
idx++;
|
||||||
|
} else {
|
||||||
|
out[i * col_size + j] = in[i * col_size + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out += row_size * col_size;
|
||||||
|
in += row_size * col_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
|
||||||
|
TfLiteTensor* output, const int batch_size, const int row_size,
|
||||||
|
const int col_size) {
|
||||||
|
FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag),
|
||||||
|
GetTensorData<T>(output), batch_size, row_size, col_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill a tensor with given "diag" values on the diagonal, input values
|
||||||
|
// elsewhere.
|
||||||
|
void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
|
||||||
|
TfLiteTensor* output) {
|
||||||
|
const int num_output_dims = output->dims->size;
|
||||||
|
int batch_size = 1;
|
||||||
|
for (int i = 0; i < num_output_dims - 2; ++i) {
|
||||||
|
batch_size *= output->dims->data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
const int row_size = output->dims->data[num_output_dims - 2];
|
||||||
|
const int col_size = output->dims->data[num_output_dims - 1];
|
||||||
|
switch (output->type) {
|
||||||
|
case kTfLiteInt64: {
|
||||||
|
return FillDiag<int64_t>(input, diag, output, batch_size, row_size,
|
||||||
|
col_size);
|
||||||
|
}
|
||||||
|
case kTfLiteInt32: {
|
||||||
|
return FillDiag<int32_t>(input, diag, output, batch_size, row_size,
|
||||||
|
col_size);
|
||||||
|
}
|
||||||
|
case kTfLiteInt16: {
|
||||||
|
return FillDiag<int16_t>(input, diag, output, batch_size, row_size,
|
||||||
|
col_size);
|
||||||
|
}
|
||||||
|
case kTfLiteInt8: {
|
||||||
|
return FillDiag<int8_t>(input, diag, output, batch_size, row_size,
|
||||||
|
col_size);
|
||||||
|
}
|
||||||
|
case kTfLiteUInt8: {
|
||||||
|
return FillDiag<uint8_t>(input, diag, output, batch_size, row_size,
|
||||||
|
col_size);
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return FillDiag<float>(input, diag, output, batch_size, row_size,
|
||||||
|
col_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
|
const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor);
|
||||||
|
FillDiagHelper(input, diag, output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace matrix_set_diag
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_MATRIX_SET_DIAG() {
|
||||||
|
static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare,
|
||||||
|
matrix_set_diag::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace builtin
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tflite
|
132
tensorflow/lite/kernels/matrix_set_diag_test.cc
Normal file
132
tensorflow/lite/kernels/matrix_set_diag_test.cc
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class MatrixSetDiagOpModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
explicit MatrixSetDiagOpModel(const TensorData& input,
|
||||||
|
const TensorData& diag) {
|
||||||
|
input_ = AddInput(input);
|
||||||
|
diag_ = AddInput(diag);
|
||||||
|
output_ = AddOutput({input.type, {}});
|
||||||
|
|
||||||
|
SetBuiltinOp(BuiltinOperator_MATRIX_SET_DIAG,
|
||||||
|
BuiltinOptions_MatrixSetDiagOptions,
|
||||||
|
CreateMatrixSetDiagOptions(builder_).Union());
|
||||||
|
BuildInterpreter({GetShape(input_), GetShape(diag_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
int input() { return input_; }
|
||||||
|
int diag() { return diag_; }
|
||||||
|
|
||||||
|
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
TfLiteType GetOutputType() {
|
||||||
|
TfLiteTensor* t = interpreter_->tensor(output_);
|
||||||
|
return t->type;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int input_;
|
||||||
|
int diag_;
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Use the machinery of TYPED_TEST_SUITE to test all supported types.
|
||||||
|
// See
|
||||||
|
// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#typed-tests
|
||||||
|
// for details.
|
||||||
|
template <typename T>
|
||||||
|
class MatrixSetDiagOpTest : public ::testing::Test {};
|
||||||
|
|
||||||
|
using TypesUnderTest =
|
||||||
|
::testing::Types<TypeUnion<int32_t>, TypeUnion<float>, TypeUnion<int16_t>,
|
||||||
|
TypeUnion<int8_t>, TypeUnion<uint8_t>>;
|
||||||
|
|
||||||
|
TYPED_TEST_SUITE(MatrixSetDiagOpTest, TypesUnderTest);
|
||||||
|
|
||||||
|
TYPED_TEST(MatrixSetDiagOpTest, ThreeByThreeDiagScatter) {
|
||||||
|
MatrixSetDiagOpModel<typename TypeParam::ScalarType> model(
|
||||||
|
{TypeParam::tensor_type, {3, 3}}, {TypeParam::tensor_type, {3}});
|
||||||
|
model.template PopulateTensor<typename TypeParam::ScalarType>(model.input(),
|
||||||
|
{7, 1, 2, //
|
||||||
|
3, 8, 4, //
|
||||||
|
5, 6, 9});
|
||||||
|
model.template PopulateTensor<typename TypeParam::ScalarType>(model.diag(),
|
||||||
|
{0, 4, 2});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2, //
|
||||||
|
3, 4, 4, //
|
||||||
|
5, 6, 2}));
|
||||||
|
EXPECT_THAT(model.GetOutputType(), TypeParam::tflite_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MatrixSetDiagTest, Int32TestMoreColumnsThanRows) {
|
||||||
|
MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 3}},
|
||||||
|
{TensorType_INT32, {2}});
|
||||||
|
model.PopulateTensor<int32_t>(model.input(), {0, 0, 0, //
|
||||||
|
9, 9, 9});
|
||||||
|
model.PopulateTensor<int32_t>(model.diag(), {1, 1});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, //
|
||||||
|
9, 1, 9}));
|
||||||
|
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MatrixSetDiagTest, Int32TestTwoDimDiag) {
|
||||||
|
MatrixSetDiagOpModel<int32_t> model({TensorType_INT32, {2, 4, 4}},
|
||||||
|
{TensorType_INT32, {2, 4}});
|
||||||
|
model.PopulateTensor<int32_t>(model.input(), {5, 5, 5, 5, //
|
||||||
|
5, 5, 5, 5, //
|
||||||
|
5, 5, 5, 5, //
|
||||||
|
5, 5, 5, 5, //
|
||||||
|
1, 1, 1, 1, //
|
||||||
|
1, 1, 1, 1, //
|
||||||
|
1, 1, 1, 1, //
|
||||||
|
1, 1, 1, 1});
|
||||||
|
model.PopulateTensor<int32_t>(model.diag(), {1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 4, 4));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 5, 5, 5, //
|
||||||
|
5, 2, 5, 5, //
|
||||||
|
5, 5, 3, 5, //
|
||||||
|
5, 5, 5, 4, //
|
||||||
|
5, 1, 1, 1, //
|
||||||
|
1, 6, 1, 1, //
|
||||||
|
1, 1, 7, 1, //
|
||||||
|
1, 1, 1, 8}));
|
||||||
|
EXPECT_THAT(model.GetOutputType(), TfLiteType::kTfLiteInt32);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -141,6 +141,7 @@ TfLiteRegistration* Register_ELU();
|
|||||||
TfLiteRegistration* Register_REVERSE_SEQUENCE();
|
TfLiteRegistration* Register_REVERSE_SEQUENCE();
|
||||||
TfLiteRegistration* Register_MATRIX_DIAG();
|
TfLiteRegistration* Register_MATRIX_DIAG();
|
||||||
TfLiteRegistration* Register_QUANTIZE();
|
TfLiteRegistration* Register_QUANTIZE();
|
||||||
|
TfLiteRegistration* Register_MATRIX_SET_DIAG();
|
||||||
|
|
||||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||||
context->ReportError(
|
context->ReportError(
|
||||||
@ -378,6 +379,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
|
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
|
||||||
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
|
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
|
||||||
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
|
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
|
||||||
|
AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
|
||||||
|
|
||||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||||
// custom ops aren't always included by default.
|
// custom ops aren't always included by default.
|
||||||
|
@ -673,6 +673,7 @@ TfLiteStatus AddOpsAndParams(
|
|||||||
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
|
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
|
||||||
case tflite::BuiltinOperator_MATRIX_DIAG:
|
case tflite::BuiltinOperator_MATRIX_DIAG:
|
||||||
case tflite::BuiltinOperator_QUANTIZE:
|
case tflite::BuiltinOperator_QUANTIZE:
|
||||||
|
case tflite::BuiltinOperator_MATRIX_SET_DIAG:
|
||||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
break;
|
break;
|
||||||
|
@ -228,6 +228,7 @@ enum BuiltinOperator : byte {
|
|||||||
REVERSE_SEQUENCE = 112,
|
REVERSE_SEQUENCE = 112,
|
||||||
MATRIX_DIAG = 113,
|
MATRIX_DIAG = 113,
|
||||||
QUANTIZE = 114,
|
QUANTIZE = 114,
|
||||||
|
MATRIX_SET_DIAG = 115
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options for the builtin operators.
|
// Options for the builtin operators.
|
||||||
@ -321,6 +322,7 @@ union BuiltinOptions {
|
|||||||
ReverseSequenceOptions,
|
ReverseSequenceOptions,
|
||||||
MatrixDiagOptions,
|
MatrixDiagOptions,
|
||||||
QuantizeOptions,
|
QuantizeOptions,
|
||||||
|
MatrixSetDiagOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Padding : byte { SAME, VALID }
|
enum Padding : byte { SAME, VALID }
|
||||||
@ -767,6 +769,9 @@ table MatrixDiagOptions {
|
|||||||
table QuantizeOptions {
|
table QuantizeOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table MatrixSetDiagOptions {
|
||||||
|
}
|
||||||
|
|
||||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||||
// builtin, or a string if the operator is custom.
|
// builtin, or a string if the operator is custom.
|
||||||
table OperatorCode {
|
table OperatorCode {
|
||||||
|
@ -298,6 +298,9 @@ struct MatrixDiagOptionsT;
|
|||||||
struct QuantizeOptions;
|
struct QuantizeOptions;
|
||||||
struct QuantizeOptionsT;
|
struct QuantizeOptionsT;
|
||||||
|
|
||||||
|
struct MatrixSetDiagOptions;
|
||||||
|
struct MatrixSetDiagOptionsT;
|
||||||
|
|
||||||
struct OperatorCode;
|
struct OperatorCode;
|
||||||
struct OperatorCodeT;
|
struct OperatorCodeT;
|
||||||
|
|
||||||
@ -562,11 +565,12 @@ enum BuiltinOperator {
|
|||||||
BuiltinOperator_REVERSE_SEQUENCE = 112,
|
BuiltinOperator_REVERSE_SEQUENCE = 112,
|
||||||
BuiltinOperator_MATRIX_DIAG = 113,
|
BuiltinOperator_MATRIX_DIAG = 113,
|
||||||
BuiltinOperator_QUANTIZE = 114,
|
BuiltinOperator_QUANTIZE = 114,
|
||||||
|
BuiltinOperator_MATRIX_SET_DIAG = 115,
|
||||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||||
BuiltinOperator_MAX = BuiltinOperator_QUANTIZE
|
BuiltinOperator_MAX = BuiltinOperator_MATRIX_SET_DIAG
|
||||||
};
|
};
|
||||||
|
|
||||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
|
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] {
|
||||||
static const BuiltinOperator values[] = {
|
static const BuiltinOperator values[] = {
|
||||||
BuiltinOperator_ADD,
|
BuiltinOperator_ADD,
|
||||||
BuiltinOperator_AVERAGE_POOL_2D,
|
BuiltinOperator_AVERAGE_POOL_2D,
|
||||||
@ -681,7 +685,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
|
|||||||
BuiltinOperator_ELU,
|
BuiltinOperator_ELU,
|
||||||
BuiltinOperator_REVERSE_SEQUENCE,
|
BuiltinOperator_REVERSE_SEQUENCE,
|
||||||
BuiltinOperator_MATRIX_DIAG,
|
BuiltinOperator_MATRIX_DIAG,
|
||||||
BuiltinOperator_QUANTIZE
|
BuiltinOperator_QUANTIZE,
|
||||||
|
BuiltinOperator_MATRIX_SET_DIAG
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
@ -803,6 +808,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
|||||||
"REVERSE_SEQUENCE",
|
"REVERSE_SEQUENCE",
|
||||||
"MATRIX_DIAG",
|
"MATRIX_DIAG",
|
||||||
"QUANTIZE",
|
"QUANTIZE",
|
||||||
|
"MATRIX_SET_DIAG",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
@ -904,11 +910,12 @@ enum BuiltinOptions {
|
|||||||
BuiltinOptions_ReverseSequenceOptions = 87,
|
BuiltinOptions_ReverseSequenceOptions = 87,
|
||||||
BuiltinOptions_MatrixDiagOptions = 88,
|
BuiltinOptions_MatrixDiagOptions = 88,
|
||||||
BuiltinOptions_QuantizeOptions = 89,
|
BuiltinOptions_QuantizeOptions = 89,
|
||||||
|
BuiltinOptions_MatrixSetDiagOptions = 90,
|
||||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||||
BuiltinOptions_MAX = BuiltinOptions_QuantizeOptions
|
BuiltinOptions_MAX = BuiltinOptions_MatrixSetDiagOptions
|
||||||
};
|
};
|
||||||
|
|
||||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
|
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[91] {
|
||||||
static const BuiltinOptions values[] = {
|
static const BuiltinOptions values[] = {
|
||||||
BuiltinOptions_NONE,
|
BuiltinOptions_NONE,
|
||||||
BuiltinOptions_Conv2DOptions,
|
BuiltinOptions_Conv2DOptions,
|
||||||
@ -999,7 +1006,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
|
|||||||
BuiltinOptions_RankOptions,
|
BuiltinOptions_RankOptions,
|
||||||
BuiltinOptions_ReverseSequenceOptions,
|
BuiltinOptions_ReverseSequenceOptions,
|
||||||
BuiltinOptions_MatrixDiagOptions,
|
BuiltinOptions_MatrixDiagOptions,
|
||||||
BuiltinOptions_QuantizeOptions
|
BuiltinOptions_QuantizeOptions,
|
||||||
|
BuiltinOptions_MatrixSetDiagOptions
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
@ -1096,6 +1104,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
|||||||
"ReverseSequenceOptions",
|
"ReverseSequenceOptions",
|
||||||
"MatrixDiagOptions",
|
"MatrixDiagOptions",
|
||||||
"QuantizeOptions",
|
"QuantizeOptions",
|
||||||
|
"MatrixSetDiagOptions",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
@ -1466,6 +1475,10 @@ template<> struct BuiltinOptionsTraits<QuantizeOptions> {
|
|||||||
static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
|
static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<> struct BuiltinOptionsTraits<MatrixSetDiagOptions> {
|
||||||
|
static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions;
|
||||||
|
};
|
||||||
|
|
||||||
struct BuiltinOptionsUnion {
|
struct BuiltinOptionsUnion {
|
||||||
BuiltinOptions type;
|
BuiltinOptions type;
|
||||||
void *value;
|
void *value;
|
||||||
@ -2209,6 +2222,14 @@ struct BuiltinOptionsUnion {
|
|||||||
return type == BuiltinOptions_QuantizeOptions ?
|
return type == BuiltinOptions_QuantizeOptions ?
|
||||||
reinterpret_cast<const QuantizeOptionsT *>(value) : nullptr;
|
reinterpret_cast<const QuantizeOptionsT *>(value) : nullptr;
|
||||||
}
|
}
|
||||||
|
MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() {
|
||||||
|
return type == BuiltinOptions_MatrixSetDiagOptions ?
|
||||||
|
reinterpret_cast<MatrixSetDiagOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
|
const MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const {
|
||||||
|
return type == BuiltinOptions_MatrixSetDiagOptions ?
|
||||||
|
reinterpret_cast<const MatrixSetDiagOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||||
@ -7691,6 +7712,46 @@ inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(
|
|||||||
|
|
||||||
flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
|
struct MatrixSetDiagOptionsT : public flatbuffers::NativeTable {
|
||||||
|
typedef MatrixSetDiagOptions TableType;
|
||||||
|
MatrixSetDiagOptionsT() {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
|
typedef MatrixSetDiagOptionsT NativeTableType;
|
||||||
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
|
return VerifyTableStart(verifier) &&
|
||||||
|
verifier.EndTable();
|
||||||
|
}
|
||||||
|
MatrixSetDiagOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
void UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
static flatbuffers::Offset<MatrixSetDiagOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MatrixSetDiagOptionsBuilder {
|
||||||
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
|
flatbuffers::uoffset_t start_;
|
||||||
|
explicit MatrixSetDiagOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
|
: fbb_(_fbb) {
|
||||||
|
start_ = fbb_.StartTable();
|
||||||
|
}
|
||||||
|
MatrixSetDiagOptionsBuilder &operator=(const MatrixSetDiagOptionsBuilder &);
|
||||||
|
flatbuffers::Offset<MatrixSetDiagOptions> Finish() {
|
||||||
|
const auto end = fbb_.EndTable(start_);
|
||||||
|
auto o = flatbuffers::Offset<MatrixSetDiagOptions>(end);
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(
|
||||||
|
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||||
|
MatrixSetDiagOptionsBuilder builder_(_fbb);
|
||||||
|
return builder_.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||||
typedef OperatorCode TableType;
|
typedef OperatorCode TableType;
|
||||||
BuiltinOperator builtin_code;
|
BuiltinOperator builtin_code;
|
||||||
@ -8091,6 +8152,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
const QuantizeOptions *builtin_options_as_QuantizeOptions() const {
|
const QuantizeOptions *builtin_options_as_QuantizeOptions() const {
|
||||||
return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast<const QuantizeOptions *>(builtin_options()) : nullptr;
|
return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast<const QuantizeOptions *>(builtin_options()) : nullptr;
|
||||||
}
|
}
|
||||||
|
const MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const {
|
||||||
|
return builtin_options_type() == BuiltinOptions_MatrixSetDiagOptions ? static_cast<const MatrixSetDiagOptions *>(builtin_options()) : nullptr;
|
||||||
|
}
|
||||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||||
}
|
}
|
||||||
@ -8478,6 +8542,10 @@ template<> inline const QuantizeOptions *Operator::builtin_options_as<QuantizeOp
|
|||||||
return builtin_options_as_QuantizeOptions();
|
return builtin_options_as_QuantizeOptions();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline const MatrixSetDiagOptions *Operator::builtin_options_as<MatrixSetDiagOptions>() const {
|
||||||
|
return builtin_options_as_MatrixSetDiagOptions();
|
||||||
|
}
|
||||||
|
|
||||||
struct OperatorBuilder {
|
struct OperatorBuilder {
|
||||||
flatbuffers::FlatBufferBuilder &fbb_;
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
flatbuffers::uoffset_t start_;
|
flatbuffers::uoffset_t start_;
|
||||||
@ -11338,6 +11406,29 @@ inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::F
|
|||||||
_fbb);
|
_fbb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline MatrixSetDiagOptionsT *MatrixSetDiagOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
auto _o = new MatrixSetDiagOptionsT();
|
||||||
|
UnPackTo(_o, _resolver);
|
||||||
|
return _o;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void MatrixSetDiagOptions::UnPackTo(MatrixSetDiagOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
(void)_o;
|
||||||
|
(void)_resolver;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<MatrixSetDiagOptions> MatrixSetDiagOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
return CreateMatrixSetDiagOptions(_fbb, _o, _rehasher);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
(void)_rehasher;
|
||||||
|
(void)_o;
|
||||||
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MatrixSetDiagOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
|
return tflite::CreateMatrixSetDiagOptions(
|
||||||
|
_fbb);
|
||||||
|
}
|
||||||
|
|
||||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
auto _o = new OperatorCodeT();
|
auto _o = new OperatorCodeT();
|
||||||
UnPackTo(_o, _resolver);
|
UnPackTo(_o, _resolver);
|
||||||
@ -11952,6 +12043,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
|||||||
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
|
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
|
||||||
return verifier.VerifyTable(ptr);
|
return verifier.VerifyTable(ptr);
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_MatrixSetDiagOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const MatrixSetDiagOptions *>(obj);
|
||||||
|
return verifier.VerifyTable(ptr);
|
||||||
|
}
|
||||||
default: return false;
|
default: return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -12326,6 +12421,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
|||||||
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
|
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
|
||||||
return ptr->UnPack(resolver);
|
return ptr->UnPack(resolver);
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_MatrixSetDiagOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const MatrixSetDiagOptions *>(obj);
|
||||||
|
return ptr->UnPack(resolver);
|
||||||
|
}
|
||||||
default: return nullptr;
|
default: return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -12688,6 +12787,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
|||||||
auto ptr = reinterpret_cast<const QuantizeOptionsT *>(value);
|
auto ptr = reinterpret_cast<const QuantizeOptionsT *>(value);
|
||||||
return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union();
|
return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union();
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_MatrixSetDiagOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const MatrixSetDiagOptionsT *>(value);
|
||||||
|
return CreateMatrixSetDiagOptions(_fbb, ptr, _rehasher).Union();
|
||||||
|
}
|
||||||
default: return 0;
|
default: return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -13050,6 +13153,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
|||||||
value = new QuantizeOptionsT(*reinterpret_cast<QuantizeOptionsT *>(u.value));
|
value = new QuantizeOptionsT(*reinterpret_cast<QuantizeOptionsT *>(u.value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_MatrixSetDiagOptions: {
|
||||||
|
value = new MatrixSetDiagOptionsT(*reinterpret_cast<MatrixSetDiagOptionsT *>(u.value));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -13502,6 +13609,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
|||||||
delete ptr;
|
delete ptr;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_MatrixSetDiagOptions: {
|
||||||
|
auto ptr = reinterpret_cast<MatrixSetDiagOptionsT *>(value);
|
||||||
|
delete ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
value = nullptr;
|
value = nullptr;
|
||||||
|
@ -4385,6 +4385,79 @@ def make_matrix_diag_tests(zip_path):
|
|||||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_matrix_set_diag_tests(zip_path):
|
||||||
|
"""Make a set of tests for tf.matrix_set_diag op."""
|
||||||
|
|
||||||
|
test_parameters = [
|
||||||
|
{
|
||||||
|
"input_diag_shapes": [([3, 3], [3]), ([2, 3], [2]), ([2, 4, 4],
|
||||||
|
[2, 4]),
|
||||||
|
([3, 4, 5, 6], [3, 4, 5])],
|
||||||
|
"input_dtype": [tf.int32, tf.float32, tf.uint8],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def build_graph(parameters):
|
||||||
|
input_shape = parameters["input_diag_shapes"][0]
|
||||||
|
diag_shape = parameters["input_diag_shapes"][1]
|
||||||
|
input_tensor = tf.placeholder(
|
||||||
|
dtype=parameters["input_dtype"], name="input", shape=input_shape)
|
||||||
|
diag_tensor = tf.placeholder(
|
||||||
|
dtype=parameters["input_dtype"], name="diagonal", shape=diag_shape)
|
||||||
|
outs = tf.matrix_set_diag(input_tensor, diag_tensor)
|
||||||
|
return [input_tensor, diag_tensor], [outs]
|
||||||
|
|
||||||
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
|
input_shape = parameters["input_diag_shapes"][0]
|
||||||
|
diag_shape = parameters["input_diag_shapes"][1]
|
||||||
|
input_values = create_tensor_data(parameters["input_dtype"], input_shape)
|
||||||
|
diag_values = create_tensor_data(parameters["input_dtype"], diag_shape)
|
||||||
|
return [input_values, diag_values], sess.run(
|
||||||
|
outputs, feed_dict=dict(zip(inputs, [input_values, diag_values])))
|
||||||
|
|
||||||
|
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_eye_tests(zip_path):
|
||||||
|
"""Make a set of tests for tf.eye op."""
|
||||||
|
|
||||||
|
test_parameters = [{
|
||||||
|
"num_rows_shape": [[]],
|
||||||
|
"num_cols_shape": [[]],
|
||||||
|
"batch_shape": [[3], [2, 4], [4, 5, 6], None],
|
||||||
|
"use_num_cols": [True, False],
|
||||||
|
"dtype": [tf.float32, tf.int32],
|
||||||
|
}]
|
||||||
|
|
||||||
|
def build_graph(parameters):
|
||||||
|
input_tensor0 = tf.placeholder(
|
||||||
|
dtype=tf.int32, name="num_rows", shape=parameters["num_rows_shape"])
|
||||||
|
input_tensor1 = tf.placeholder(
|
||||||
|
dtype=tf.int32, name="num_columns", shape=parameters["num_cols_shape"])
|
||||||
|
if parameters["use_num_cols"]:
|
||||||
|
outs = tf.eye(
|
||||||
|
num_rows=input_tensor0,
|
||||||
|
num_columns=input_tensor1,
|
||||||
|
batch_shape=parameters["batch_shape"],
|
||||||
|
dtype=parameters["dtype"])
|
||||||
|
return [input_tensor0, input_tensor1], [outs]
|
||||||
|
else:
|
||||||
|
outs = tf.eye(num_rows=input_tensor0, dtype=parameters["dtype"])
|
||||||
|
return [input_tensor0], [outs]
|
||||||
|
|
||||||
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
|
input_value0 = create_scalar_data(dtype=np.int32, min_value=1)
|
||||||
|
input_value1 = create_scalar_data(dtype=np.int32, min_value=1)
|
||||||
|
if parameters["use_num_cols"]:
|
||||||
|
return [input_value0, input_value1], sess.run(
|
||||||
|
outputs, feed_dict=dict(zip(inputs, [input_value0, input_value1])))
|
||||||
|
else:
|
||||||
|
return [input_value0], sess.run(
|
||||||
|
outputs, feed_dict=dict(zip(inputs, [input_value0])))
|
||||||
|
|
||||||
|
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
@test_util.enable_control_flow_v2
|
@test_util.enable_control_flow_v2
|
||||||
def make_unidirectional_sequence_lstm_tests(zip_path):
|
def make_unidirectional_sequence_lstm_tests(zip_path):
|
||||||
"""Make a set of tests to do unidirectional_sequence_lstm."""
|
"""Make a set of tests to do unidirectional_sequence_lstm."""
|
||||||
|
@ -293,6 +293,13 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
|||||||
SetDataTypeForAllOutputs(model, op, data_type);
|
SetDataTypeForAllOutputs(model, op, data_type);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case OperatorType::kMatrixSetDiag: {
|
||||||
|
CHECK_EQ(op->inputs.size(), 2);
|
||||||
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
|
const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
|
||||||
|
SetDataTypeForAllOutputs(model, op, data_type);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// These operators produce outputs with the same type as their 1st input
|
// These operators produce outputs with the same type as their 1st input
|
||||||
CHECK_GT(op->inputs.size(), 0);
|
CHECK_GT(op->inputs.size(), 0);
|
||||||
|
@ -2063,21 +2063,40 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
|||||||
void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
|
void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
|
||||||
CHECK_EQ(op->inputs.size(), 1);
|
CHECK_EQ(op->inputs.size(), 1);
|
||||||
CHECK_EQ(op->outputs.size(), 1);
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
|
auto& input_array = model->GetArray(op->inputs[0]);
|
||||||
auto& output_array = model->GetArray(op->outputs[0]);
|
auto& output_array = model->GetArray(op->outputs[0]);
|
||||||
if (output_array.has_shape()) {
|
// The input array must have a shape in order to proceed. Also,
|
||||||
|
// bail out if the output shape has already been calculated.
|
||||||
|
if (!input_array.has_shape() || output_array.has_shape()) {
|
||||||
// We have already run
|
// We have already run
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Get the input_shape
|
// Get the input_shape
|
||||||
auto& input_array = model->GetArray(op->inputs[0]);
|
|
||||||
Shape* mutable_shape = input_array.mutable_shape();
|
Shape* mutable_shape = input_array.mutable_shape();
|
||||||
std::vector<int>* dims = mutable_shape->mutable_dims();
|
std::vector<int>* dims = mutable_shape->mutable_dims();
|
||||||
int dims_size = dims->size();
|
int dims_size = dims->size();
|
||||||
|
// Scalars are not allowed.
|
||||||
|
CHECK_GT(dims_size, 0);
|
||||||
int last_dim = (*dims)[dims_size - 1];
|
int last_dim = (*dims)[dims_size - 1];
|
||||||
dims->push_back(last_dim);
|
dims->push_back(last_dim);
|
||||||
output_array.copy_shape(*mutable_shape);
|
output_array.copy_shape(*mutable_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
|
||||||
|
CHECK_EQ(op->inputs.size(), 2);
|
||||||
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
|
auto& input_array = model->GetArray(op->inputs[0]);
|
||||||
|
auto& output_array = model->GetArray(op->outputs[0]);
|
||||||
|
// The shape of the input array must be known because that will
|
||||||
|
// be the shape of the output array.
|
||||||
|
if (!input_array.has_shape() || !output_array.has_shape()) {
|
||||||
|
// We have already run
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
output_array.copy_shape(input_array.shape());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
|
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
|
||||||
@ -2384,6 +2403,10 @@ void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
|
|||||||
case OperatorType::kMatrixDiag:
|
case OperatorType::kMatrixDiag:
|
||||||
ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
|
ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
|
||||||
break;
|
break;
|
||||||
|
case OperatorType::kMatrixSetDiag:
|
||||||
|
ProcessMatrixSetDiagOperator(model,
|
||||||
|
static_cast<MatrixSetDiagOperator*>(op));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
// Unimplemented, another graph transformation should drop it.
|
// Unimplemented, another graph transformation should drop it.
|
||||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
||||||
|
@ -68,7 +68,9 @@ bool SupportsQuantization(const Operator& op) {
|
|||||||
type == OperatorType::kResizeNearestNeighbor ||
|
type == OperatorType::kResizeNearestNeighbor ||
|
||||||
type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
|
type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
|
||||||
type == OperatorType::kReduceMin ||
|
type == OperatorType::kReduceMin ||
|
||||||
type == OperatorType::kTransposeConv;
|
type == OperatorType::kTransposeConv ||
|
||||||
|
type == OperatorType::kMatrixSetDiag ||
|
||||||
|
type == OperatorType::kMatrixDiag;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The quantized op allows output arrays of type float using
|
// The quantized op allows output arrays of type float using
|
||||||
|
@ -2472,6 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
|||||||
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
|
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
|
||||||
{"MatMul", ConvertMatMulOperator},
|
{"MatMul", ConvertMatMulOperator},
|
||||||
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
|
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
|
||||||
|
{"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
|
||||||
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
||||||
{"MaxPool", ConvertMaxPoolOperator},
|
{"MaxPool", ConvertMaxPoolOperator},
|
||||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
||||||
|
@ -169,7 +169,8 @@ enum class OperatorType : uint8 {
|
|||||||
kWhere,
|
kWhere,
|
||||||
kElu,
|
kElu,
|
||||||
kReverseSequence,
|
kReverseSequence,
|
||||||
kMatrixDiag
|
kMatrixDiag,
|
||||||
|
kMatrixSetDiag
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||||
@ -2084,6 +2085,16 @@ struct MatrixDiagOperator : Operator {
|
|||||||
MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
|
MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Matrix Set Diag Operator:
|
||||||
|
// Construct a batched diagonal tensor with given input and diagonal values.
|
||||||
|
// Input is a rank (k+1) tensor of values.
|
||||||
|
// diagonal is a rank (k) tensor of values that will be on the diagonal
|
||||||
|
// of the returned output. Output is rank k+1.
|
||||||
|
// tensor.
|
||||||
|
struct MatrixSetDiagOperator : Operator {
|
||||||
|
MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
|
||||||
|
};
|
||||||
|
|
||||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||||
|
@ -2478,6 +2478,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
|||||||
OperatorType::kReverseSequence));
|
OperatorType::kReverseSequence));
|
||||||
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
|
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
|
||||||
"MATRIX_DIAG", OperatorType::kMatrixDiag));
|
"MATRIX_DIAG", OperatorType::kMatrixDiag));
|
||||||
|
ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
|
||||||
|
"MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag));
|
||||||
// Custom Operators.
|
// Custom Operators.
|
||||||
ops.push_back(
|
ops.push_back(
|
||||||
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
|
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
|
||||||
|
@ -712,6 +712,13 @@ TEST_F(OperatorTest, BuiltinMatrixDiag) {
|
|||||||
GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
|
GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, BuiltinMatrixSetDiag) {
|
||||||
|
MatrixSetDiagOperator op;
|
||||||
|
std::unique_ptr<toco::MatrixSetDiagOperator> output_toco_op =
|
||||||
|
SerializeAndDeserialize(
|
||||||
|
GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op);
|
||||||
|
}
|
||||||
|
|
||||||
// Test version for a simple Op with 2 versions and the input type controls the
|
// Test version for a simple Op with 2 versions and the input type controls the
|
||||||
// version.
|
// version.
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
|
@ -428,6 +428,7 @@ const char* OperatorTypeName(OperatorType type) {
|
|||||||
HANDLE_OPERATORTYPENAME_CASE(Where)
|
HANDLE_OPERATORTYPENAME_CASE(Where)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
|
HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
|
HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
|
||||||
|
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unhandled op type";
|
LOG(FATAL) << "Unhandled op type";
|
||||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||||
|
Loading…
x
Reference in New Issue
Block a user