Implementation of square.
PiperOrigin-RevId: 212577288
This commit is contained in:
parent
cadd6b42bf
commit
6a21e1386e
@ -283,6 +283,7 @@ def generated_test_models():
|
|||||||
"sparse_to_dense",
|
"sparse_to_dense",
|
||||||
"split",
|
"split",
|
||||||
"sqrt",
|
"sqrt",
|
||||||
|
"square",
|
||||||
"squeeze",
|
"squeeze",
|
||||||
"strided_slice",
|
"strided_slice",
|
||||||
"strided_slice_1d_exhaustive",
|
"strided_slice_1d_exhaustive",
|
||||||
|
@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
|
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
return EvalNumeric(context, node, [](float f) { return f * f; });
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
return EvalLogical(context, node, [](bool v) { return !v; });
|
return EvalLogical(context, node, [](bool v) { return !v; });
|
||||||
}
|
}
|
||||||
@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
|
|||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_SQUARE() {
|
||||||
|
static TfLiteRegistration r = {
|
||||||
|
/*init=*/nullptr, /*free=*/nullptr,
|
||||||
|
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||||
|
elementwise::SquareEval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_LOGICAL_NOT() {
|
TfLiteRegistration* Register_LOGICAL_NOT() {
|
||||||
static TfLiteRegistration r = {
|
static TfLiteRegistration r = {
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
/*init=*/nullptr, /*free=*/nullptr,
|
||||||
|
@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
|
|||||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ElementWise, Square) {
|
||||||
|
ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
|
||||||
|
m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||||
|
ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
|
||||||
|
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ElementWise, LogicalNot) {
|
TEST(ElementWise, LogicalNot) {
|
||||||
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
|
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
|
||||||
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
|
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
|
||||||
|
@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND();
|
|||||||
TfLiteRegistration* Register_LOGICAL_NOT();
|
TfLiteRegistration* Register_LOGICAL_NOT();
|
||||||
TfLiteRegistration* Register_UNPACK();
|
TfLiteRegistration* Register_UNPACK();
|
||||||
TfLiteRegistration* Register_FLOOR_DIV();
|
TfLiteRegistration* Register_FLOOR_DIV();
|
||||||
|
TfLiteRegistration* Register_SQUARE();
|
||||||
|
|
||||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||||
context->ReportError(
|
context->ReportError(
|
||||||
@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
|
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
|
||||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
||||||
|
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
|
||||||
|
|
||||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||||
// custom ops aren't always included by default.
|
// custom ops aren't always included by default.
|
||||||
|
@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path):
|
|||||||
return _make_elementwise_tests(tf.rsqrt)(zip_path)
|
return _make_elementwise_tests(tf.rsqrt)(zip_path)
|
||||||
|
|
||||||
|
|
||||||
|
def make_square_tests(zip_path):
|
||||||
|
"""Make a set of tests to do square."""
|
||||||
|
return _make_elementwise_tests(tf.square)(zip_path)
|
||||||
|
|
||||||
|
|
||||||
def make_where_tests(zip_path):
|
def make_where_tests(zip_path):
|
||||||
"""Make a set of tests to do where."""
|
"""Make a set of tests to do where."""
|
||||||
|
|
||||||
|
@ -1488,6 +1488,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
|||||||
"SQRT", OperatorType::kSqrt));
|
"SQRT", OperatorType::kSqrt));
|
||||||
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
|
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
|
||||||
"RSQRT", OperatorType::kRsqrt));
|
"RSQRT", OperatorType::kRsqrt));
|
||||||
|
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
|
||||||
|
"SQUARE", OperatorType::kSquare));
|
||||||
|
|
||||||
return ops;
|
return ops;
|
||||||
}
|
}
|
||||||
|
@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) {
|
|||||||
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
|
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
|
||||||
OperatorType::kLogicalNot);
|
OperatorType::kLogicalNot);
|
||||||
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
|
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
|
||||||
|
CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
|
||||||
|
OperatorType::kSquare);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, BuiltinAdd) {
|
TEST_F(OperatorTest, BuiltinAdd) {
|
||||||
|
Loading…
Reference in New Issue
Block a user