Implementation of square.

PiperOrigin-RevId: 212577288
This commit is contained in:
A. Unique TensorFlower 2018-09-11 21:18:05 -07:00 committed by TensorFlower Gardener
parent cadd6b42bf
commit 6a21e1386e
7 changed files with 33 additions and 0 deletions

View File

@ -283,6 +283,7 @@ def generated_test_models():
"sparse_to_dense", "sparse_to_dense",
"split", "split",
"sqrt", "sqrt",
"square",
"squeeze", "squeeze",
"strided_slice", "strided_slice",
"strided_slice_1d_exhaustive", "strided_slice_1d_exhaustive",

View File

@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
} }
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return f * f; });
}
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; }); return EvalLogical(context, node, [](bool v) { return !v; });
} }
@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
return &r; return &r;
} }
TfLiteRegistration* Register_SQUARE() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SquareEval};
return &r;
}
TfLiteRegistration* Register_LOGICAL_NOT() { TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,

View File

@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
} }
TEST(ElementWise, Square) {
ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
TEST(ElementWise, LogicalNot) { TEST(ElementWise, LogicalNot) {
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
m.PopulateTensor<bool>(m.input(), {true, false, true, false}); m.PopulateTensor<bool>(m.input(), {true, false, true, false});

View File

@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT(); TfLiteRegistration* Register_LOGICAL_NOT();
TfLiteRegistration* Register_UNPACK(); TfLiteRegistration* Register_UNPACK();
TfLiteRegistration* Register_FLOOR_DIV(); TfLiteRegistration* Register_FLOOR_DIV();
TfLiteRegistration* Register_SQUARE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError( context->ReportError(
@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV()); AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default. // custom ops aren't always included by default.

View File

@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path):
return _make_elementwise_tests(tf.rsqrt)(zip_path) return _make_elementwise_tests(tf.rsqrt)(zip_path)
def make_square_tests(zip_path):
"""Make a set of tests to do square."""
return _make_elementwise_tests(tf.square)(zip_path)
def make_where_tests(zip_path): def make_where_tests(zip_path):
"""Make a set of tests to do where.""" """Make a set of tests to do where."""

View File

@ -1488,6 +1488,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"SQRT", OperatorType::kSqrt)); "SQRT", OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>( ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
"RSQRT", OperatorType::kRsqrt)); "RSQRT", OperatorType::kRsqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
"SQUARE", OperatorType::kSquare));
return ops; return ops;
} }

View File

@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT", CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot); OperatorType::kLogicalNot);
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv); CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
OperatorType::kSquare);
} }
TEST_F(OperatorTest, BuiltinAdd) { TEST_F(OperatorTest, BuiltinAdd) {