162 lines
5.3 KiB
C++
162 lines
5.3 KiB
C++
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <complex>
|
|
#include <vector>
|
|
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
#include "tensorflow/lite/interpreter.h"
|
|
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
|
#include "tensorflow/lite/kernels/test_util.h"
|
|
#include "tensorflow/lite/schema/schema_generated.h"
|
|
#include "tensorflow/lite/testing/util.h"
|
|
|
|
namespace tflite {
|
|
namespace ops {
|
|
namespace custom {
|
|
|
|
// Registration hook passed to SetCustomOp() for the custom "Real" op in the
// tests below. Only declared here; the definition is not part of this file.
TfLiteRegistration* Register_REAL();
|
|
// Registration hook passed to SetCustomOp() for the custom "Imag" op in the
// tests below. Only declared here; the definition is not part of this file.
TfLiteRegistration* Register_IMAG();
|
|
|
|
namespace {
|
|
|
|
template <typename T>
|
|
class RealOpModel : public SingleOpModel {
|
|
public:
|
|
RealOpModel(const TensorData& input, const TensorData& output) {
|
|
input_ = AddInput(input);
|
|
|
|
output_ = AddOutput(output);
|
|
|
|
const std::vector<uint8_t> custom_option;
|
|
SetCustomOp("Real", custom_option, Register_REAL);
|
|
|
|
BuildInterpreter({GetShape(input_)});
|
|
}
|
|
|
|
int input() { return input_; }
|
|
|
|
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
|
|
|
private:
|
|
int input_;
|
|
int output_;
|
|
};
|
|
|
|
// Runs the "Real" op on a {2, 4} complex64 tensor and expects the output to
// hold the real component of every input element, in input order.
TEST(RealOpTest, SimpleFloatTest) {
  const TensorData input_spec = {TensorType_COMPLEX64, {2, 4}};
  const TensorData output_spec = {TensorType_FLOAT32, {}};
  RealOpModel<float> model(input_spec, output_spec);

  // Eight (real, imag) pairs filling the {2, 4} input.
  model.PopulateTensor<std::complex<float>>(
      model.input(),
      {{75, 0}, {-6, -1}, {9, 0}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(),
              testing::ElementsAreArray(
                  ArrayFloatNear({75, -6, 9, -10, -3, -6, 0, 22.1f})));
}
|
|
|
|
// Runs the "Real" op on a {2, 4} complex128 tensor and expects the output to
// hold the real component of every input element, in input order.
TEST(RealOpTest, SimpleDoubleTest) {
  RealOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
                        {TensorType_FLOAT64, {}});

  // Eight (real, imag) pairs filling the {2, 4} input.
  m.PopulateTensor<std::complex<double>>(
      m.input(),
      {{75, 0}, {-6, -1}, {9, 0}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  m.Invoke();

  // Compare at double precision. Float matchers (as produced by
  // ArrayFloatNear) would narrow each double output to float before
  // comparing, so a result that had been rounded through float would still
  // pass. The expected values are the exact doubles that were fed in.
  const std::vector<double> expected_real = {75, -6, 9, -10, -3, -6, 0, 22.1};
  EXPECT_THAT(m.GetOutput(),
              ::testing::Pointwise(::testing::DoubleEq(), expected_real));
}
|
|
|
|
template <typename T>
|
|
class ImagOpModel : public SingleOpModel {
|
|
public:
|
|
ImagOpModel(const TensorData& input, const TensorData& output) {
|
|
input_ = AddInput(input);
|
|
|
|
output_ = AddOutput(output);
|
|
|
|
const std::vector<uint8_t> custom_option;
|
|
SetCustomOp("Imag", custom_option, Register_IMAG);
|
|
|
|
BuildInterpreter({GetShape(input_)});
|
|
}
|
|
|
|
int input() { return input_; }
|
|
|
|
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
|
|
|
private:
|
|
int input_;
|
|
int output_;
|
|
};
|
|
|
|
// Runs the "Imag" op on a {2, 4} complex64 tensor and expects the output to
// hold the imaginary component of every input element, in input order.
TEST(ImagOpTest, SimpleFloatTest) {
  const TensorData input_spec = {TensorType_COMPLEX64, {2, 4}};
  const TensorData output_spec = {TensorType_FLOAT32, {}};
  ImagOpModel<float> model(input_spec, output_spec);

  // Eight (real, imag) pairs filling the {2, 4} input.
  model.PopulateTensor<std::complex<float>>(
      model.input(),
      {{75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(),
              testing::ElementsAreArray(
                  ArrayFloatNear({7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
}
|
|
|
|
// Runs the "Imag" op on a {2, 4} complex128 tensor and expects the output to
// hold the imaginary component of every input element, in input order.
TEST(ImagOpTest, SimpleDoubleTest) {
  ImagOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
                        {TensorType_FLOAT64, {}});

  // Eight (real, imag) pairs filling the {2, 4} input.
  m.PopulateTensor<std::complex<double>>(
      m.input(),
      {{75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  m.Invoke();

  // Compare at double precision. Float matchers (as produced by
  // ArrayFloatNear) would narrow each double output to float before
  // comparing, so a result that had been rounded through float would still
  // pass. The expected values are the exact doubles that were fed in.
  const std::vector<double> expected_imag = {7, -1, 3.5, 5, 2, 11, 0, 33.3};
  EXPECT_THAT(m.GetOutput(),
              ::testing::Pointwise(::testing::DoubleEq(), expected_imag));
}
|
|
|
|
} // namespace
|
|
} // namespace custom
|
|
} // namespace ops
|
|
} // namespace tflite
|