diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 379c8134e5a..582ec7144b5 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -307,6 +307,7 @@ def generated_test_models(): "resolve_constant_strided_slice", "reverse_sequence", "reverse_v2", + "round", "rsqrt", "shape", "sigmoid", diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 914fd7fc23c..4e86e4bdf27 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -141,6 +141,7 @@ typedef enum { kTfLiteBuiltinMatrixDiag = 113, kTfLiteBuiltinQuantize = 114, kTfLiteBuiltinMatrixSetDiag = 115, + kTfLiteBuiltinRound = 116, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 0ff207b6b34..2354f000a71 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -728,6 +728,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RELU: case BuiltinOperator_RELU6: case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_ROUND: case BuiltinOperator_RSQRT: case BuiltinOperator_SELECT: case BuiltinOperator_SIN: diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index caf777305fa..2ea105f4127 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -174,6 +174,7 @@ class OpOptionData { op_to_option_["RELU"] = ""; op_to_option_["RELU_N1_TO_1"] = ""; op_to_option_["RELU6"] = ""; + op_to_option_["ROUND"] = ""; op_to_option_["TANH"] = ""; op_to_option_["PRELU"] = ""; op_to_option_["SIN"] = ""; diff --git a/tensorflow/lite/g3doc/guide/ops_compatibility.md b/tensorflow/lite/g3doc/guide/ops_compatibility.md index 56caf7dcc71..e11f34fb665 100644 --- a/tensorflow/lite/g3doc/guide/ops_compatibility.md +++ b/tensorflow/lite/g3doc/guide/ops_compatibility.md @@ -390,10 +390,10 @@ Outputs { **CEIL** ``` -inputs { - 0: tensor +Inputs { + 0: a tensor } -outputs: { +Outputs { 0: result of computing element-wise ceil of the input tensor } ``` @@ -844,6 +844,17 @@ Options { } ``` +**ROUND** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise round of the input tensor +} +``` + **SLICE** ``` diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index a0afd2ca564..f3cf5b79308 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -327,6 +327,7 @@ cc_library( "resize_nearest_neighbor.cc", "reverse.cc", "reverse_sequence.cc", + "round.cc", "select.cc", "shape.cc", "skip_gram.cc", @@ -685,6 +686,21 @@ cc_test( ], ) +cc_test( + name = "round_test", + size = "small", + srcs = ["round_test.cc"], + tags = [ + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "elementwise_test", size = "small", diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 0bf85acdfbb..df114b39b02 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -93,6 +93,7 @@ using reference_ops::RankOneSelect; using reference_ops::Relu1; using reference_ops::Relu6; using reference_ops::ReluX; +using reference_ops::Round; using reference_ops::Select; using reference_ops::SpaceToBatchND; using reference_ops::Split; diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 66e6bc9cc80..1594a0a1199 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -2619,6 +2619,29 @@ inline void Ceil(const RuntimeShape& input_shape, const float* input_data, } } +inline float RoundToNearest(float value) { + auto floor_val = std::floor(value); + auto diff = value - floor_val; + if ((diff < 0.5f) || + ((diff == 0.5f) && (static_cast(floor_val) % 2 == 0))) { + return floor_val; + } else { + return floor_val = floor_val + 1.0f; + } +} + +inline void Round(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; i++) { + // Note that this implementation matches that of tensorFlow tf.round + // and corresponds to the bankers rounding method. + // cfenv (for fesetround) is not yet supported universally on Android, so + // using a work around. + output_data[i] = RoundToNearest(input_data[i]); + } +} + template inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index b527d927812..fb9807b7fa9 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -96,6 +96,7 @@ TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_CEIL(); +TfLiteRegistration* Register_ROUND(); TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); TfLiteRegistration* Register_SUM(); @@ -324,6 +325,7 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version */ 2); AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); AddBuiltin(BuiltinOperator_CEIL, Register_CEIL()); + AddBuiltin(BuiltinOperator_ROUND, Register_ROUND()); AddBuiltin(BuiltinOperator_NEG, Register_NEG()); AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(), /* min_version */ 1, diff --git a/tensorflow/lite/kernels/round.cc b/tensorflow/lite/kernels/round.cc new file mode 100644 index 00000000000..908e355be0a --- /dev/null +++ b/tensorflow/lite/kernels/round.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace round { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + output->type = input->type; + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + optimized_ops::Round(GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + + return kTfLiteOk; +} +} // namespace round + +TfLiteRegistration* Register_ROUND() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, round::Prepare, round::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/round_test.cc b/tensorflow/lite/kernels/round_test.cc new file mode 100644 index 00000000000..37304fb2309 --- /dev/null +++ b/tensorflow/lite/kernels/round_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class RoundOpModel : public SingleOpModel { + public: + RoundOpModel(std::initializer_list input_shape, TensorType input_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_ROUND, BuiltinOptions_NONE, 0); + BuildInterpreter({ + input_shape, + }); + } + + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(RoundOpTest, SingleDim) { + RoundOpModel model({6}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), {8.5, 0.0, 3.5, 4.2, -3.5, -4.5}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0, 4, 4, -4, -4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6})); +} + +TEST(RoundOpTest, MultiDims) { + RoundOpModel model({2, 1, 1, 6}, TensorType_FLOAT32); + model.PopulateTensor( + model.input(), {0.0001, 8.0001, 0.9999, 9.9999, 0.5, -0.0001, -8.0001, + -0.9999, -9.9999, -0.5, -2.5, 1.5}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({0, 8, 1, 10, 0, 0, -8, -1, -10, -0, -2, 2})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 6})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 3dbdacd3832..b5fc0f31be0 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -228,7 +228,8 @@ enum BuiltinOperator : byte { REVERSE_SEQUENCE = 112, MATRIX_DIAG = 113, QUANTIZE = 114, - MATRIX_SET_DIAG = 115 + MATRIX_SET_DIAG = 115, + ROUND = 116, } // Options for the builtin operators. diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 3520eff51d9..6d14eb4dc79 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -566,11 +566,12 @@ enum BuiltinOperator { BuiltinOperator_MATRIX_DIAG = 113, BuiltinOperator_QUANTIZE = 114, BuiltinOperator_MATRIX_SET_DIAG = 115, + BuiltinOperator_ROUND = 116, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_MATRIX_SET_DIAG + BuiltinOperator_MAX = BuiltinOperator_ROUND }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[116] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -686,7 +687,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[115] { BuiltinOperator_REVERSE_SEQUENCE, BuiltinOperator_MATRIX_DIAG, BuiltinOperator_QUANTIZE, - BuiltinOperator_MATRIX_SET_DIAG + BuiltinOperator_MATRIX_SET_DIAG, + BuiltinOperator_ROUND }; return values; } @@ -809,6 +811,7 @@ inline const char * const *EnumNamesBuiltinOperator() { "MATRIX_DIAG", "QUANTIZE", "MATRIX_SET_DIAG", + "ROUND", nullptr }; return names; diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 7934d6e4d11..31daaf2f44f 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -3584,6 +3584,32 @@ def make_ceil_tests(options): make_zip_of_tests(options, test_parameters, build_graph, build_inputs) +@register_make_test_function() +def make_round_tests(options): + """Build the round op testing graph.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] + + def build_graph(parameters): + """Build the round op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = tf.round(input_value) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run(outputs, feed_dict={inputs[0]: input_value}) + + make_zip_of_tests(options, test_parameters, build_graph, build_inputs) + + @register_make_test_function() def make_neg_tests(options): """Make a set of tests to do neg.""" diff --git a/tensorflow/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc index d426a690678..f9a307a3769 100644 --- a/tensorflow/lite/toco/export_tensorflow.cc +++ b/tensorflow/lite/toco/export_tensorflow.cc @@ -1215,6 +1215,16 @@ void ConvertCeilOperator(const Model& model, const CeilOperator& src_op, (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertRoundOperator(const Model& model, const RoundOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* round_op = tensorflow_graph->add_node(); + round_op->set_op("Round"); + round_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *round_op->add_input() = src_op.inputs[0]; + (*round_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); @@ -2210,6 +2220,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kCeil) { ConvertCeilOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kRound) { + ConvertRoundOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kGather) { ConvertGatherOperator(model, static_cast(src_op), tensorflow_graph); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index b748d32f63d..d3d64411376 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -2153,6 +2153,7 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) { case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kCeil: + case OperatorType::kRound: case OperatorType::kExp: case OperatorType::kSin: case OperatorType::kCos: diff --git a/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 98105d384e1..2f935a674f1 100644 --- a/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -37,6 +37,7 @@ bool IsElementwiseOperator(OperatorType optype) { case OperatorType::kRelu: case OperatorType::kRelu1: case OperatorType::kRelu6: + case OperatorType::kRound: case OperatorType::kTanh: case OperatorType::kSqrt: case OperatorType::kSquare: diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index 197a970212e..1c78d35457c 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -1547,6 +1547,20 @@ tensorflow::Status ConvertCeilOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertRoundOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Round"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto data_type = GetDataTypeAttr(node, "T"); + CHECK(data_type == DT_FLOAT); + auto* op = new RoundOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertGatherOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -2511,6 +2525,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator}, {"ReverseSequence", ConvertReverseSequenceOperator}, {"ReverseV2", ConvertSimpleOperator}, + {"Round", ConvertRoundOperator}, {"Rsqrt", ConvertSimpleOperator}, {"Select", ConvertSimpleOperator}, {"Shape", ConvertShapeOperator}, diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 2b0c3e982cc..fcee42c2294 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -82,6 +82,7 @@ enum class OperatorType : uint8 { kTransposeConv, kCast, kFloor, + kRound, kGather, kResizeBilinear, kSin, @@ -1715,6 +1716,16 @@ struct CeilOperator : Operator { CeilOperator() : Operator(OperatorType::kCeil) {} }; +// Round operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Round +struct RoundOperator : Operator { + RoundOperator() : Operator(OperatorType::kRound) {} +}; + // Gather operator. It gathers slices from params according to indices. // Only 1-D indices are supported at the moment. // diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 32ee882ae40..15c4d7457b1 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -2517,6 +2517,8 @@ std::vector> BuildOperatorList( MakeUnique>("CEIL", OperatorType::kCeil)); ops.push_back( MakeUnique>("ELU", OperatorType::kElu)); + ops.push_back( + MakeUnique>("ROUND", OperatorType::kRound)); ops.push_back( MakeUnique>("RELU", OperatorType::kRelu)); ops.push_back(MakeUnique>( diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 0d851973323..937b69e331e 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -113,6 +113,7 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("FLOOR", OperatorType::kFloor); CheckSimpleOperator("CEIL", OperatorType::kCeil); CheckSimpleOperator("ELU", OperatorType::kElu); + CheckSimpleOperator("ROUND", OperatorType::kRound); CheckSimpleOperator("RELU", OperatorType::kRelu); CheckSimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1); CheckSimpleOperator("RELU6", OperatorType::kRelu6); diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index 69ecd5cdc18..626a7befecf 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -389,6 +389,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Cast) HANDLE_OPERATORTYPENAME_CASE(Floor) HANDLE_OPERATORTYPENAME_CASE(Ceil) + HANDLE_OPERATORTYPENAME_CASE(Round) HANDLE_OPERATORTYPENAME_CASE(Gather) HANDLE_OPERATORTYPENAME_CASE(GatherNd) HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)