From 6a21e1386e3e68cf752af861b9b1b950bda8a130 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 21:18:05 -0700 Subject: [PATCH] Implementation of square. PiperOrigin-RevId: 212577288 --- tensorflow/contrib/lite/build_def.bzl | 1 + tensorflow/contrib/lite/kernels/elementwise.cc | 12 ++++++++++++ tensorflow/contrib/lite/kernels/elementwise_test.cc | 9 +++++++++ tensorflow/contrib/lite/kernels/register.cc | 2 ++ tensorflow/contrib/lite/testing/generate_examples.py | 5 +++++ tensorflow/contrib/lite/toco/tflite/operator.cc | 2 ++ tensorflow/contrib/lite/toco/tflite/operator_test.cc | 2 ++ 7 files changed, 33 insertions(+) diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 0210428026e..e9c02cdbee7 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -283,6 +283,7 @@ def generated_test_models(): "sparse_to_dense", "split", "sqrt", + "square", "squeeze", "strided_slice", "strided_slice_1d_exhaustive", diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 04995d70dd7..8c624b32080 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); } +TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) { + return EvalNumeric(context, node, [](float f) { return f * f; }); +} + TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { return EvalLogical(context, node, [](bool v) { return !v; }); } @@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() { return &r; } +TfLiteRegistration* Register_SQUARE() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SquareEval}; + return &r; +} + TfLiteRegistration* Register_LOGICAL_NOT() { static TfLiteRegistration r = { /*init=*/nullptr, /*free=*/nullptr, diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index b9d7d73c528..5dd89a0eaec 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Square) { + ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 2, 0.5, -3.0}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + TEST(ElementWise, LogicalNot) { ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {true, false, true, false}); diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index c66959fdf4b..14296d3a9f7 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND(); TfLiteRegistration* Register_LOGICAL_NOT(); TfLiteRegistration* Register_UNPACK(); TfLiteRegistration* Register_FLOOR_DIV(); +TfLiteRegistration* Register_SQUARE(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV()); + AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 812385e7067..5d0895c72fc 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path): return _make_elementwise_tests(tf.rsqrt)(zip_path) +def make_square_tests(zip_path): + """Make a set of tests to do square.""" + return _make_elementwise_tests(tf.square)(zip_path) + + def make_where_tests(zip_path): """Make a set of tests to do where.""" diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index eb0f7c443a8..54860121763 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1488,6 +1488,8 @@ std::vector> BuildOperatorList( "SQRT", OperatorType::kSqrt)); ops.push_back(MakeUnique>( "RSQRT", OperatorType::kRsqrt)); + ops.push_back(MakeUnique>( + "SQUARE", OperatorType::kSquare)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 519a3a4e015..72e50a9aed1 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("LOGICAL_NOT", OperatorType::kLogicalNot); CheckSimpleOperator("FLOOR_DIV", OperatorType::kFloorDiv); + CheckSimpleOperator("SQUARE", + OperatorType::kSquare); } TEST_F(OperatorTest, BuiltinAdd) {