PR #24918: Tflite round
Imported from GitHub PR #24918 ROUND operator for TFLITE Copybara import of the project: - 8180d7ffa144023bbc52fdcc34f6aeb1a8ab5cc4 Kernel code added by Siju Samuel <siju.samuel@huawei.com> - 4550e73ed595f835e515e0bad195edbc8826b417 Schema changes added by Siju Samuel <siju.samuel@huawei.com> - 4b430a470ee2abf6e4214bd3e4a19acaf4732cf9 Core api updated by Siju Samuel <siju.samuel@huawei.com> - bcef36a07cd4d417e696979a4418c924886801a2 Toco tflite operator updated by Siju Samuel <siju.samuel@huawei.com> - 25c239f04d3214137064b78667c0ad0b5b575fc3 Toco changes updated by Siju Samuel <siju.samuel@huawei.com> - 3d3d63c257be62ed3ec4d31bd04d95095edb599e build_def.bzl updated by Siju Samuel <siju.samuel@huawei.com> - 73ab7a0d79004d52b117027e986752680642b90f builtin ops and option_writer updated by Siju Samuel <siju.samuel@huawei.com> - b61d41ef3ce5a5f90dce5dfb6139423a6c5a4750 Documentation updated by Siju Samuel <siju.samuel@huawei.com> - 853355de6a3c7827fa46e4b0ac3f0ff8ea8b738e bugfix, changed round method to match tf.round using bank... by Siju Samuel <siju.samuel@huawei.com> - 7f8fc937ebae14b3bef6dc0bf8ec1992a8b58754 Set rounding metod to Nearest by Siju Samuel <siju.samuel@huawei.com> - f5dd6316b92ab25ac5dd476b96c00cd4e53f6ead used rint and updated testcase by Siju Samuel <siju.samuel@huawei.com> - 9cf66cbe37cc32dc8fe0eac069f040befd2a8e6d Alphabetize by Siju Samuel <siju.samuel@huawei.com> - 3ae29805c6570a2f03c606f5ecc3ea7d34924a2b make_round_test updated based on latest code by Siju Samuel <siju.samuel@huawei.com> - 838707c0a68aae905212962aa109fd72aec9928c Merge 3ae29805c6570a2f03c606f5ecc3ea7d34924a2b into 13aaa... by Siju <sijusamuel@gmail.com> PiperOrigin-RevId: 246902956
This commit is contained in:
parent
3b66ec3a38
commit
650a877042
@ -307,6 +307,7 @@ def generated_test_models():
|
||||
"resolve_constant_strided_slice",
|
||||
"reverse_sequence",
|
||||
"reverse_v2",
|
||||
"round",
|
||||
"rsqrt",
|
||||
"shape",
|
||||
"sigmoid",
|
||||
|
@ -141,6 +141,7 @@ typedef enum {
|
||||
kTfLiteBuiltinMatrixDiag = 113,
|
||||
kTfLiteBuiltinQuantize = 114,
|
||||
kTfLiteBuiltinMatrixSetDiag = 115,
|
||||
kTfLiteBuiltinRound = 116,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -728,6 +728,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_RELU:
|
||||
case BuiltinOperator_RELU6:
|
||||
case BuiltinOperator_RELU_N1_TO_1:
|
||||
case BuiltinOperator_ROUND:
|
||||
case BuiltinOperator_RSQRT:
|
||||
case BuiltinOperator_SELECT:
|
||||
case BuiltinOperator_SIN:
|
||||
|
@ -174,6 +174,7 @@ class OpOptionData {
|
||||
op_to_option_["RELU"] = "";
|
||||
op_to_option_["RELU_N1_TO_1"] = "";
|
||||
op_to_option_["RELU6"] = "";
|
||||
op_to_option_["ROUND"] = "";
|
||||
op_to_option_["TANH"] = "";
|
||||
op_to_option_["PRELU"] = "";
|
||||
op_to_option_["SIN"] = "";
|
||||
|
@ -390,10 +390,10 @@ Outputs {
|
||||
**CEIL**
|
||||
|
||||
```
|
||||
inputs {
|
||||
0: tensor
|
||||
Inputs {
|
||||
0: a tensor
|
||||
}
|
||||
outputs: {
|
||||
Outputs {
|
||||
0: result of computing element-wise ceil of the input tensor
|
||||
}
|
||||
```
|
||||
@ -844,6 +844,17 @@ Options {
|
||||
}
|
||||
```
|
||||
|
||||
**ROUND**
|
||||
|
||||
```
|
||||
Inputs {
|
||||
0: a tensor
|
||||
}
|
||||
Outputs {
|
||||
0: result of computing element-wise round of the input tensor
|
||||
}
|
||||
```
|
||||
|
||||
**SLICE**
|
||||
|
||||
```
|
||||
|
@ -327,6 +327,7 @@ cc_library(
|
||||
"resize_nearest_neighbor.cc",
|
||||
"reverse.cc",
|
||||
"reverse_sequence.cc",
|
||||
"round.cc",
|
||||
"select.cc",
|
||||
"shape.cc",
|
||||
"skip_gram.cc",
|
||||
@ -685,6 +686,21 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "round_test",
|
||||
size = "small",
|
||||
srcs = ["round_test.cc"],
|
||||
tags = [
|
||||
"tflite_not_portable_ios",
|
||||
],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elementwise_test",
|
||||
size = "small",
|
||||
|
@ -93,6 +93,7 @@ using reference_ops::RankOneSelect;
|
||||
using reference_ops::Relu1;
|
||||
using reference_ops::Relu6;
|
||||
using reference_ops::ReluX;
|
||||
using reference_ops::Round;
|
||||
using reference_ops::Select;
|
||||
using reference_ops::SpaceToBatchND;
|
||||
using reference_ops::Split;
|
||||
|
@ -2619,6 +2619,29 @@ inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
|
||||
}
|
||||
}
|
||||
|
||||
inline float RoundToNearest(float value) {
|
||||
auto floor_val = std::floor(value);
|
||||
auto diff = value - floor_val;
|
||||
if ((diff < 0.5f) ||
|
||||
((diff == 0.5f) && (static_cast<int>(floor_val) % 2 == 0))) {
|
||||
return floor_val;
|
||||
} else {
|
||||
return floor_val = floor_val + 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
inline void Round(const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
for (int i = 0; i < flat_size; i++) {
|
||||
// Note that this implementation matches that of tensorFlow tf.round
|
||||
// and corresponds to the bankers rounding method.
|
||||
// cfenv (for fesetround) is not yet supported universally on Android, so
|
||||
// using a work around.
|
||||
output_data[i] = RoundToNearest(input_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename CoordsT = int32>
|
||||
inline void Gather(const tflite::GatherParams& op_params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
|
@ -96,6 +96,7 @@ TfLiteRegistration* Register_LESS();
|
||||
TfLiteRegistration* Register_LESS_EQUAL();
|
||||
TfLiteRegistration* Register_FLOOR();
|
||||
TfLiteRegistration* Register_CEIL();
|
||||
TfLiteRegistration* Register_ROUND();
|
||||
TfLiteRegistration* Register_TILE();
|
||||
TfLiteRegistration* Register_NEG();
|
||||
TfLiteRegistration* Register_SUM();
|
||||
@ -324,6 +325,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
|
||||
AddBuiltin(BuiltinOperator_CEIL, Register_CEIL());
|
||||
AddBuiltin(BuiltinOperator_ROUND, Register_ROUND());
|
||||
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
|
||||
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
|
||||
/* min_version */ 1,
|
||||
|
59
tensorflow/lite/kernels/round.cc
Normal file
59
tensorflow/lite/kernels/round.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace round {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||
output->type = input->type;
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
optimized_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace round
|
||||
|
||||
TfLiteRegistration* Register_ROUND() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
/*free=*/nullptr, round::Prepare, round::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
74
tensorflow/lite/kernels/round_test.cc
Normal file
74
tensorflow/lite/kernels/round_test.cc
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class RoundOpModel : public SingleOpModel {
|
||||
public:
|
||||
RoundOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
SetBuiltinOp(BuiltinOperator_ROUND, BuiltinOptions_NONE, 0);
|
||||
BuildInterpreter({
|
||||
input_shape,
|
||||
});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(RoundOpTest, SingleDim) {
|
||||
RoundOpModel model({6}, TensorType_FLOAT32);
|
||||
model.PopulateTensor<float>(model.input(), {8.5, 0.0, 3.5, 4.2, -3.5, -4.5});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0, 4, 4, -4, -4}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6}));
|
||||
}
|
||||
|
||||
TEST(RoundOpTest, MultiDims) {
|
||||
RoundOpModel model({2, 1, 1, 6}, TensorType_FLOAT32);
|
||||
model.PopulateTensor<float>(
|
||||
model.input(), {0.0001, 8.0001, 0.9999, 9.9999, 0.5, -0.0001, -8.0001,
|
||||
-0.9999, -9.9999, -0.5, -2.5, 1.5});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({0, 8, 1, 10, 0, 0, -8, -1, -10, -0, -2, 2}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 6}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -228,7 +228,8 @@ enum BuiltinOperator : byte {
|
||||
REVERSE_SEQUENCE = 112,
|
||||
MATRIX_DIAG = 113,
|
||||
QUANTIZE = 114,
|
||||
MATRIX_SET_DIAG = 115
|
||||
MATRIX_SET_DIAG = 115,
|
||||
ROUND = 116,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
|
@ -566,11 +566,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_MATRIX_DIAG = 113,
|
||||
BuiltinOperator_QUANTIZE = 114,
|
||||
BuiltinOperator_MATRIX_SET_DIAG = 115,
|
||||
BuiltinOperator_ROUND = 116,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_MATRIX_SET_DIAG
|
||||
BuiltinOperator_MAX = BuiltinOperator_ROUND
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[116] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -686,7 +687,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] {
|
||||
BuiltinOperator_REVERSE_SEQUENCE,
|
||||
BuiltinOperator_MATRIX_DIAG,
|
||||
BuiltinOperator_QUANTIZE,
|
||||
BuiltinOperator_MATRIX_SET_DIAG
|
||||
BuiltinOperator_MATRIX_SET_DIAG,
|
||||
BuiltinOperator_ROUND
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -809,6 +811,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"MATRIX_DIAG",
|
||||
"QUANTIZE",
|
||||
"MATRIX_SET_DIAG",
|
||||
"ROUND",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
|
@ -3584,6 +3584,32 @@ def make_ceil_tests(options):
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_round_tests(options):
|
||||
"""Build the round op testing graph."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the round op testing graph."""
|
||||
input_value = tf.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input1",
|
||||
shape=parameters["input_shape"])
|
||||
out = tf.round(input_value)
|
||||
return [input_value], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value = create_tensor_data(parameters["input_dtype"],
|
||||
parameters["input_shape"])
|
||||
return [input_value], sess.run(outputs, feed_dict={inputs[0]: input_value})
|
||||
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_neg_tests(options):
|
||||
"""Make a set of tests to do neg."""
|
||||
|
@ -1215,6 +1215,16 @@ void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
|
||||
(*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
|
||||
}
|
||||
|
||||
void ConvertRoundOperator(const Model& model, const RoundOperator& src_op,
|
||||
GraphDef* tensorflow_graph) {
|
||||
tensorflow::NodeDef* round_op = tensorflow_graph->add_node();
|
||||
round_op->set_op("Round");
|
||||
round_op->set_name(src_op.outputs[0]);
|
||||
CHECK_EQ(src_op.inputs.size(), 1);
|
||||
*round_op->add_input() = src_op.inputs[0];
|
||||
(*round_op->mutable_attr())["T"].set_type(DT_FLOAT);
|
||||
}
|
||||
|
||||
void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
|
||||
GraphDef* tensorflow_graph) {
|
||||
tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
|
||||
@ -2210,6 +2220,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
} else if (src_op.type == OperatorType::kCeil) {
|
||||
ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
} else if (src_op.type == OperatorType::kRound) {
|
||||
ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
} else if (src_op.type == OperatorType::kGather) {
|
||||
ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
|
@ -2153,6 +2153,7 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
|
||||
case OperatorType::kCast:
|
||||
case OperatorType::kFloor:
|
||||
case OperatorType::kCeil:
|
||||
case OperatorType::kRound:
|
||||
case OperatorType::kExp:
|
||||
case OperatorType::kSin:
|
||||
case OperatorType::kCos:
|
||||
|
@ -37,6 +37,7 @@ bool IsElementwiseOperator(OperatorType optype) {
|
||||
case OperatorType::kRelu:
|
||||
case OperatorType::kRelu1:
|
||||
case OperatorType::kRelu6:
|
||||
case OperatorType::kRound:
|
||||
case OperatorType::kTanh:
|
||||
case OperatorType::kSqrt:
|
||||
case OperatorType::kSquare:
|
||||
|
@ -1547,6 +1547,20 @@ tensorflow::Status ConvertCeilOperator(
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertRoundOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
CHECK_EQ(node.op(), "Round");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto data_type = GetDataTypeAttr(node, "T");
|
||||
CHECK(data_type == DT_FLOAT);
|
||||
auto* op = new RoundOperator;
|
||||
op->inputs.push_back(node.input(0));
|
||||
op->outputs.push_back(node.name());
|
||||
model->operators.emplace_back(op);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertGatherOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
@ -2511,6 +2525,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
|
||||
{"ReverseSequence", ConvertReverseSequenceOperator},
|
||||
{"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
|
||||
{"Round", ConvertRoundOperator},
|
||||
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
|
||||
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
|
||||
{"Shape", ConvertShapeOperator},
|
||||
|
@ -82,6 +82,7 @@ enum class OperatorType : uint8 {
|
||||
kTransposeConv,
|
||||
kCast,
|
||||
kFloor,
|
||||
kRound,
|
||||
kGather,
|
||||
kResizeBilinear,
|
||||
kSin,
|
||||
@ -1715,6 +1716,16 @@ struct CeilOperator : Operator {
|
||||
CeilOperator() : Operator(OperatorType::kCeil) {}
|
||||
};
|
||||
|
||||
// Round operator.
|
||||
//
|
||||
// Inputs:
|
||||
// inputs[0]: required: the input array
|
||||
//
|
||||
// TensorFlow equivalent: Round
|
||||
struct RoundOperator : Operator {
|
||||
RoundOperator() : Operator(OperatorType::kRound) {}
|
||||
};
|
||||
|
||||
// Gather operator. It gathers slices from params according to indices.
|
||||
// Only 1-D indices are supported at the moment.
|
||||
//
|
||||
|
@ -2517,6 +2517,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<RoundOperator>>("ROUND", OperatorType::kRound));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu));
|
||||
ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
|
||||
|
@ -113,6 +113,7 @@ TEST_F(OperatorTest, SimpleOperators) {
|
||||
CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
|
||||
CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil);
|
||||
CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu);
|
||||
CheckSimpleOperator<RoundOperator>("ROUND", OperatorType::kRound);
|
||||
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
|
||||
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
|
||||
CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
|
||||
|
@ -389,6 +389,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(Cast)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Floor)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Ceil)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Round)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Gather)
|
||||
HANDLE_OPERATORTYPENAME_CASE(GatherNd)
|
||||
HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
|
||||
|
Loading…
x
Reference in New Issue
Block a user