PR #24918: TFLite round

Imported from GitHub PR #24918

Adds the ROUND operator to TFLite.

Copybara import of the project:

  - 8180d7ffa144023bbc52fdcc34f6aeb1a8ab5cc4 Kernel code added by Siju Samuel <siju.samuel@huawei.com>
  - 4550e73ed595f835e515e0bad195edbc8826b417 Schema changes added by Siju Samuel <siju.samuel@huawei.com>
  - 4b430a470ee2abf6e4214bd3e4a19acaf4732cf9 Core api updated by Siju Samuel <siju.samuel@huawei.com>
  - bcef36a07cd4d417e696979a4418c924886801a2 Toco tflite operator updated by Siju Samuel <siju.samuel@huawei.com>
  - 25c239f04d3214137064b78667c0ad0b5b575fc3 Toco changes updated by Siju Samuel <siju.samuel@huawei.com>
  - 3d3d63c257be62ed3ec4d31bd04d95095edb599e build_def.bzl updated by Siju Samuel <siju.samuel@huawei.com>
  - 73ab7a0d79004d52b117027e986752680642b90f builtin ops and option_writer updated by Siju Samuel <siju.samuel@huawei.com>
  - b61d41ef3ce5a5f90dce5dfb6139423a6c5a4750 Documentation updated by Siju Samuel <siju.samuel@huawei.com>
  - 853355de6a3c7827fa46e4b0ac3f0ff8ea8b738e bugfix, changed round method to match tf.round using bank... by Siju Samuel <siju.samuel@huawei.com>
  - 7f8fc937ebae14b3bef6dc0bf8ec1992a8b58754 Set rounding method to Nearest by Siju Samuel <siju.samuel@huawei.com>
  - f5dd6316b92ab25ac5dd476b96c00cd4e53f6ead used rint and updated testcase by Siju Samuel <siju.samuel@huawei.com>
  - 9cf66cbe37cc32dc8fe0eac069f040befd2a8e6d Alphabetize by Siju Samuel <siju.samuel@huawei.com>
  - 3ae29805c6570a2f03c606f5ecc3ea7d34924a2b make_round_test updated based on latest code by Siju Samuel <siju.samuel@huawei.com>
  - 838707c0a68aae905212962aa109fd72aec9928c Merge 3ae29805c6570a2f03c606f5ecc3ea7d34924a2b into 13aaa... by Siju <sijusamuel@gmail.com>

PiperOrigin-RevId: 246902956
Jared Duke 2019-05-06 14:56:32 -07:00 committed by TensorFlower Gardener
parent 3b66ec3a38
commit 650a877042
22 changed files with 272 additions and 7 deletions

View File

@ -307,6 +307,7 @@ def generated_test_models():
"resolve_constant_strided_slice",
"reverse_sequence",
"reverse_v2",
"round",
"rsqrt",
"shape",
"sigmoid",

View File

@ -141,6 +141,7 @@ typedef enum {
kTfLiteBuiltinMatrixDiag = 113,
kTfLiteBuiltinQuantize = 114,
kTfLiteBuiltinMatrixSetDiag = 115,
kTfLiteBuiltinRound = 116,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -728,6 +728,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_RELU:
case BuiltinOperator_RELU6:
case BuiltinOperator_RELU_N1_TO_1:
case BuiltinOperator_ROUND:
case BuiltinOperator_RSQRT:
case BuiltinOperator_SELECT:
case BuiltinOperator_SIN:

View File

@ -174,6 +174,7 @@ class OpOptionData {
op_to_option_["RELU"] = "";
op_to_option_["RELU_N1_TO_1"] = "";
op_to_option_["RELU6"] = "";
op_to_option_["ROUND"] = "";
op_to_option_["TANH"] = "";
op_to_option_["PRELU"] = "";
op_to_option_["SIN"] = "";

View File

@ -390,10 +390,10 @@ Outputs {
**CEIL**
```
Inputs {
0: a tensor
}
Outputs {
0: result of computing element-wise ceil of the input tensor
}
```
@ -844,6 +844,17 @@ Options {
}
```
**ROUND**
```
Inputs {
0: a tensor
}
Outputs {
0: result of computing element-wise round of the input tensor
}
```
**SLICE**
```

View File

@ -327,6 +327,7 @@ cc_library(
"resize_nearest_neighbor.cc",
"reverse.cc",
"reverse_sequence.cc",
"round.cc",
"select.cc",
"shape.cc",
"skip_gram.cc",
@ -685,6 +686,21 @@ cc_test(
],
)
cc_test(
name = "round_test",
size = "small",
srcs = ["round_test.cc"],
tags = [
"tflite_not_portable_ios",
],
deps = [
":builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "elementwise_test",
size = "small",

View File

@ -93,6 +93,7 @@ using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Round;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;

View File

@ -2619,6 +2619,29 @@ inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
}
}
inline float RoundToNearest(float value) {
auto floor_val = std::floor(value);
auto diff = value - floor_val;
if ((diff < 0.5f) ||
((diff == 0.5f) && (static_cast<int>(floor_val) % 2 == 0))) {
return floor_val;
} else {
return floor_val + 1.0f;
}
}
inline void Round(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// Note that this implementation matches TensorFlow's tf.round and
// corresponds to the banker's rounding method.
// <cfenv> (for fesetround) is not yet universally supported on Android, so
// this serves as a workaround.
output_data[i] = RoundToNearest(input_data[i]);
}
}
template <typename T, typename CoordsT = int32>
inline void Gather(const tflite::GatherParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
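The tie-breaking above is round-half-to-even (banker's rounding), matching tf.round rather than std::round. A minimal standalone sketch, assuming only the C++ standard library (RoundHalfToEven is a hypothetical stand-in for the reference helper above), that contrasts the two behaviors:

```c++
#include <cmath>
#include <cstdio>

// Mirrors the reference RoundToNearest above: ties go to the nearest even value.
float RoundHalfToEven(float value) {
  float floor_val = std::floor(value);
  float diff = value - floor_val;
  if (diff < 0.5f || (diff == 0.5f && static_cast<int>(floor_val) % 2 == 0)) {
    return floor_val;
  }
  return floor_val + 1.0f;
}

int main() {
  const float samples[] = {0.5f, 1.5f, 2.5f, -3.5f, -4.5f};
  for (float v : samples) {
    // std::round breaks ties away from zero; the kernel rounds ties to even.
    std::printf("%5.1f -> half-to-even %5.1f, std::round %5.1f\n", v,
                RoundHalfToEven(v), std::round(v));
  }
  return 0;
}
```

For example, 0.5 and 2.5 map to 0 and 2 under half-to-even but to 1 and 3 under std::round, and -4.5 maps to -4 rather than -5.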

View File

@ -96,6 +96,7 @@ TfLiteRegistration* Register_LESS();
TfLiteRegistration* Register_LESS_EQUAL();
TfLiteRegistration* Register_FLOOR();
TfLiteRegistration* Register_CEIL();
TfLiteRegistration* Register_ROUND();
TfLiteRegistration* Register_TILE();
TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_SUM();
@ -324,6 +325,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 2);
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
AddBuiltin(BuiltinOperator_CEIL, Register_CEIL());
AddBuiltin(BuiltinOperator_ROUND, Register_ROUND());
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
/* min_version */ 1,
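A small sketch of how the new registration could be sanity-checked through the public resolver API, assuming the standard BuiltinOpResolver from register.h (RoundIsRegistered is a hypothetical helper name):

```c++
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/schema/schema_generated.h"

// After the AddBuiltin call above, the standard resolver should return a
// non-null registration for ROUND (builtin version 1).
bool RoundIsRegistered() {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  const TfLiteRegistration* registration =
      resolver.FindOp(tflite::BuiltinOperator_ROUND, /*version=*/1);
  return registration != nullptr;
}
```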

View File

@ -0,0 +1,59 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace round {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
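// ROUND takes a single float32 tensor and produces a float32 output of the
// same shape; Prepare validates that contract and sizes the output, while
// Eval applies the element-wise Round routine.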
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
output->type = input->type;
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
optimized_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace round
TfLiteRegistration* Register_ROUND() {
static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr, round::Prepare, round::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
class RoundOpModel : public SingleOpModel {
public:
RoundOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
input_ = AddInput(input_type);
output_ = AddOutput(input_type);
SetBuiltinOp(BuiltinOperator_ROUND, BuiltinOptions_NONE, 0);
BuildInterpreter({input_shape});
}
int input() { return input_; }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int output_;
};
TEST(RoundOpTest, SingleDim) {
RoundOpModel model({6}, TensorType_FLOAT32);
model.PopulateTensor<float>(model.input(), {8.5, 0.0, 3.5, 4.2, -3.5, -4.5});
model.Invoke();
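// Ties round to the nearest even value (banker's rounding), matching
// tf.round: 8.5 -> 8, 3.5 -> 4, -3.5 -> -4, -4.5 -> -4.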
EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0, 4, 4, -4, -4}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6}));
}
TEST(RoundOpTest, MultiDims) {
RoundOpModel model({2, 1, 1, 6}, TensorType_FLOAT32);
model.PopulateTensor<float>(
model.input(), {0.0001, 8.0001, 0.9999, 9.9999, 0.5, -0.0001, -8.0001,
-0.9999, -9.9999, -0.5, -2.5, 1.5});
model.Invoke();
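// Halfway cases again round to even: 0.5 -> 0, -0.5 -> -0, -2.5 -> -2,
// 1.5 -> 2.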
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({0, 8, 1, 10, 0, 0, -8, -1, -10, -0, -2, 2}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 6}));
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -228,7 +228,8 @@ enum BuiltinOperator : byte {
REVERSE_SEQUENCE = 112,
MATRIX_DIAG = 113,
QUANTIZE = 114,
MATRIX_SET_DIAG = 115,
ROUND = 116,
}
// Options for the builtin operators.

View File

@ -566,11 +566,12 @@ enum BuiltinOperator {
BuiltinOperator_MATRIX_DIAG = 113,
BuiltinOperator_QUANTIZE = 114,
BuiltinOperator_MATRIX_SET_DIAG = 115,
BuiltinOperator_ROUND = 116,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_ROUND
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[116] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -686,7 +687,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] {
BuiltinOperator_REVERSE_SEQUENCE,
BuiltinOperator_MATRIX_DIAG,
BuiltinOperator_QUANTIZE,
BuiltinOperator_MATRIX_SET_DIAG,
BuiltinOperator_ROUND
};
return values;
}
@ -809,6 +811,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"MATRIX_DIAG",
"QUANTIZE",
"MATRIX_SET_DIAG",
"ROUND",
nullptr
};
return names;

View File

@ -3584,6 +3584,32 @@ def make_ceil_tests(options):
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
@register_make_test_function()
def make_round_tests(options):
"""Build the round op testing graph."""
test_parameters = [{
"input_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
def build_graph(parameters):
"""Build the round op testing graph."""
input_value = tf.placeholder(
dtype=parameters["input_dtype"],
name="input1",
shape=parameters["input_shape"])
out = tf.round(input_value)
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
return [input_value], sess.run(outputs, feed_dict={inputs[0]: input_value})
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
@register_make_test_function()
def make_neg_tests(options):
"""Make a set of tests to do neg."""

View File

@ -1215,6 +1215,16 @@ void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
(*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
void ConvertRoundOperator(const Model& model, const RoundOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* round_op = tensorflow_graph->add_node();
round_op->set_op("Round");
round_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
*round_op->add_input() = src_op.inputs[0];
(*round_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
@ -2210,6 +2220,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kCeil) {
ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRound) {
ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kGather) {
ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
tensorflow_graph);

View File

@ -2153,6 +2153,7 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
case OperatorType::kCast:
case OperatorType::kFloor:
case OperatorType::kCeil:
case OperatorType::kRound:
case OperatorType::kExp:
case OperatorType::kSin:
case OperatorType::kCos:

View File

@ -37,6 +37,7 @@ bool IsElementwiseOperator(OperatorType optype) {
case OperatorType::kRelu:
case OperatorType::kRelu1:
case OperatorType::kRelu6:
case OperatorType::kRound:
case OperatorType::kTanh:
case OperatorType::kSqrt:
case OperatorType::kSquare:

View File

@ -1547,6 +1547,20 @@ tensorflow::Status ConvertCeilOperator(
return tensorflow::Status::OK();
}
tensorflow::Status ConvertRoundOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "Round");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T");
CHECK(data_type == DT_FLOAT);
auto* op = new RoundOperator;
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
tensorflow::Status ConvertGatherOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@ -2511,6 +2525,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
{"ReverseSequence", ConvertReverseSequenceOperator},
{"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
{"Round", ConvertRoundOperator},
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
{"Shape", ConvertShapeOperator},

View File

@ -82,6 +82,7 @@ enum class OperatorType : uint8 {
kTransposeConv,
kCast,
kFloor,
kRound,
kGather,
kResizeBilinear,
kSin,
@ -1715,6 +1716,16 @@ struct CeilOperator : Operator {
CeilOperator() : Operator(OperatorType::kCeil) {}
};
// Round operator.
//
// Inputs:
// inputs[0]: required: the input array
//
// TensorFlow equivalent: Round
struct RoundOperator : Operator {
RoundOperator() : Operator(OperatorType::kRound) {}
};
// Gather operator. It gathers slices from params according to indices.
// Only 1-D indices are supported at the moment.
//

View File

@ -2517,6 +2517,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil));
ops.push_back(
MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu));
ops.push_back(
MakeUnique<SimpleOperator<RoundOperator>>("ROUND", OperatorType::kRound));
ops.push_back(
MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu));
ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(

View File

@ -113,6 +113,7 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil);
CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu);
CheckSimpleOperator<RoundOperator>("ROUND", OperatorType::kRound);
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);

View File

@ -389,6 +389,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Cast)
HANDLE_OPERATORTYPENAME_CASE(Floor)
HANDLE_OPERATORTYPENAME_CASE(Ceil)
HANDLE_OPERATORTYPENAME_CASE(Round)
HANDLE_OPERATORTYPENAME_CASE(Gather)
HANDLE_OPERATORTYPENAME_CASE(GatherNd)
HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)