Implementation of square.
PiperOrigin-RevId: 212577288
This commit is contained in:
parent
cadd6b42bf
commit
6a21e1386e
@ -283,6 +283,7 @@ def generated_test_models():
|
||||
"sparse_to_dense",
|
||||
"split",
|
||||
"sqrt",
|
||||
"square",
|
||||
"squeeze",
|
||||
"strided_slice",
|
||||
"strided_slice_1d_exhaustive",
|
||||
|
@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
|
||||
}
|
||||
|
||||
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, [](float f) { return f * f; });
|
||||
}
|
||||
|
||||
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalLogical(context, node, [](bool v) { return !v; });
|
||||
}
|
||||
@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_SQUARE() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::SquareEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_LOGICAL_NOT() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
|
@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
||||
}
|
||||
|
||||
TEST(ElementWise, Square) {
|
||||
ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
|
||||
m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||
ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
||||
}
|
||||
|
||||
TEST(ElementWise, LogicalNot) {
|
||||
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
|
||||
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
|
||||
|
@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND();
|
||||
TfLiteRegistration* Register_LOGICAL_NOT();
|
||||
TfLiteRegistration* Register_UNPACK();
|
||||
TfLiteRegistration* Register_FLOOR_DIV();
|
||||
TfLiteRegistration* Register_SQUARE();
|
||||
|
||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->ReportError(
|
||||
@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
|
||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
||||
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path):
|
||||
return _make_elementwise_tests(tf.rsqrt)(zip_path)
|
||||
|
||||
|
||||
def make_square_tests(zip_path):
|
||||
"""Make a set of tests to do square."""
|
||||
return _make_elementwise_tests(tf.square)(zip_path)
|
||||
|
||||
|
||||
def make_where_tests(zip_path):
|
||||
"""Make a set of tests to do where."""
|
||||
|
||||
|
@ -1488,6 +1488,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
"SQRT", OperatorType::kSqrt));
|
||||
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
|
||||
"RSQRT", OperatorType::kRsqrt));
|
||||
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
|
||||
"SQUARE", OperatorType::kSquare));
|
||||
|
||||
return ops;
|
||||
}
|
||||
|
@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) {
|
||||
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
|
||||
OperatorType::kLogicalNot);
|
||||
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
|
||||
CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
|
||||
OperatorType::kSquare);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinAdd) {
|
||||
|
Loading…
Reference in New Issue
Block a user