Implementation of square.

PiperOrigin-RevId: 212577288
This commit is contained in:
A. Unique TensorFlower 2018-09-11 21:18:05 -07:00 committed by TensorFlower Gardener
parent cadd6b42bf
commit 6a21e1386e
7 changed files with 33 additions and 0 deletions

View File

@ -283,6 +283,7 @@ def generated_test_models():
"sparse_to_dense",
"split",
"sqrt",
"square",
"squeeze",
"strided_slice",
"strided_slice_1d_exhaustive",

View File

@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
}
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return f * f; });
}
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}
@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
return &r;
}
TfLiteRegistration* Register_SQUARE() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SquareEval};
return &r;
}
TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,

View File

@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
TEST(ElementWise, Square) {
ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
TEST(ElementWise, LogicalNot) {
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
m.PopulateTensor<bool>(m.input(), {true, false, true, false});

View File

@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
TfLiteRegistration* Register_UNPACK();
TfLiteRegistration* Register_FLOOR_DIV();
TfLiteRegistration* Register_SQUARE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path):
return _make_elementwise_tests(tf.rsqrt)(zip_path)
def make_square_tests(zip_path):
"""Make a set of tests to do square."""
return _make_elementwise_tests(tf.square)(zip_path)
def make_where_tests(zip_path):
"""Make a set of tests to do where."""

View File

@ -1488,6 +1488,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"SQRT", OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
"RSQRT", OperatorType::kRsqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
"SQUARE", OperatorType::kSquare));
return ops;
}

View File

@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot);
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
OperatorType::kSquare);
}
TEST_F(OperatorTest, BuiltinAdd) {