[2] Review comments handled
This commit is contained in:
parent
1e6e5a3867
commit
0b930951cd
@ -67,7 +67,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
switch (type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteInt32:
|
||||
case kTfLiteUInt8:
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context, "Type '%s' is not supported by floor_div.",
|
||||
@ -133,10 +132,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
|
||||
output);
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
return EvalImpl<uint8_t>(context, data->requires_broadcast, input1,
|
||||
input2, output);
|
||||
}
|
||||
default: {
|
||||
context->ReportError(context, "Type '%s' is not supported by floor_div.",
|
||||
TfLiteTypeGetName(input1->type));
|
||||
|
@ -112,28 +112,6 @@ TEST(FloorDivModel, BroadcastFloorDivFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(-4.0, 2.0, 3.0, -3.0));
|
||||
}
|
||||
|
||||
TEST(FloorDivModel, SimpleUInt8) {
|
||||
FloorDivModel<uint8_t> model({TensorType_UINT8, {1, 2, 2, 1}},
|
||||
{TensorType_UINT8, {1, 2, 2, 1}},
|
||||
{TensorType_UINT8, {}});
|
||||
model.PopulateTensor<uint8_t>(model.input1(), {10, 9, 11, 7});
|
||||
model.PopulateTensor<uint8_t>(model.input2(), {2, 2, 3, 4});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(5, 4, 3, 1));
|
||||
}
|
||||
|
||||
TEST(FloorDivModel, BroadcastFloorDivUInt8) {
|
||||
FloorDivModel<uint8_t> model({TensorType_UINT8, {1, 2, 2, 1}},
|
||||
{TensorType_UINT8, {1}}, {TensorType_UINT8, {}});
|
||||
model.PopulateTensor<uint8_t>(model.input1(), {10, 9, 11, 7});
|
||||
model.PopulateTensor<uint8_t>(model.input2(), {3});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(3, 3, 3, 2));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -367,7 +367,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
|
||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
|
||||
AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
|
||||
AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD());
|
||||
|
@ -2347,6 +2347,20 @@ class Select : public SimpleOperator<SelectOperator> {
|
||||
}
|
||||
};
|
||||
|
||||
class FloorDiv : public SimpleOperator<FloorDivOperator> {
|
||||
public:
|
||||
explicit FloorDiv() : SimpleOperator("FLOOR_DIV", OperatorType::kFloorDiv) {}
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// Version 2 supports float input types.
|
||||
if (input_array.data_type == ArrayDataType::kFloat) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
// Build a vector containing all the known operators.
|
||||
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
@ -2551,8 +2565,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
"LOGICAL_AND", OperatorType::kLogicalAnd));
|
||||
ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
|
||||
"LOGICAL_NOT", OperatorType::kLogicalNot));
|
||||
ops.emplace_back(new SimpleOperator<FloorDivOperator>(
|
||||
"FLOOR_DIV", OperatorType::kFloorDiv));
|
||||
ops.push_back(MakeUnique<FloorDiv>());
|
||||
ops.emplace_back(new SimpleOperator<FloorModOperator>(
|
||||
"FLOOR_MOD", OperatorType::kFloorMod));
|
||||
ops.emplace_back(
|
||||
|
@ -968,6 +968,29 @@ TEST_F(OperatorTest, VersioningConv2DTest) {
|
||||
EXPECT_EQ(op->GetVersion(float_signature), 2);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningFloorDivOperatorTest) {
|
||||
FloorDivOperator floordiv_op;
|
||||
floordiv_op.inputs = {"input1"};
|
||||
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||
const BaseOperator* op = operator_by_type_map.at(floordiv_op.type).get();
|
||||
|
||||
Model int32_model;
|
||||
Array& input_int32_array =
|
||||
int32_model.GetOrCreateArray(floordiv_op.inputs[0]);
|
||||
input_int32_array.data_type = ArrayDataType::kInt32;
|
||||
OperatorSignature int32_signature = {.op = &floordiv_op,
|
||||
.model = &int32_model};
|
||||
EXPECT_EQ(op->GetVersion(int32_signature), 1);
|
||||
|
||||
Model float_model;
|
||||
Array& input_float_array =
|
||||
float_model.GetOrCreateArray(floordiv_op.inputs[0]);
|
||||
input_float_array.data_type = ArrayDataType::kFloat;
|
||||
OperatorSignature float_signature = {.op = &floordiv_op,
|
||||
.model = &float_model};
|
||||
EXPECT_EQ(op->GetVersion(float_signature), 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user