[2] Review comments handled

This commit is contained in:
ANSHUMAN TRIPATHY 2019-04-25 10:56:39 +05:30
parent 1e6e5a3867
commit 0b930951cd
5 changed files with 41 additions and 30 deletions

View File

@ -67,7 +67,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
switch (type) {
case kTfLiteFloat32:
case kTfLiteInt32:
case kTfLiteUInt8:
break;
default:
context->ReportError(context, "Type '%s' is not supported by floor_div.",
@ -133,10 +132,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
output);
}
case kTfLiteUInt8: {
return EvalImpl<uint8_t>(context, data->requires_broadcast, input1,
input2, output);
}
default: {
context->ReportError(context, "Type '%s' is not supported by floor_div.",
TfLiteTypeGetName(input1->type));

View File

@ -112,28 +112,6 @@ TEST(FloorDivModel, BroadcastFloorDivFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(-4.0, 2.0, 3.0, -3.0));
}
TEST(FloorDivModel, SimpleUInt8) {
FloorDivModel<uint8_t> model({TensorType_UINT8, {1, 2, 2, 1}},
{TensorType_UINT8, {1, 2, 2, 1}},
{TensorType_UINT8, {}});
model.PopulateTensor<uint8_t>(model.input1(), {10, 9, 11, 7});
model.PopulateTensor<uint8_t>(model.input2(), {2, 2, 3, 4});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(5, 4, 3, 1));
}
TEST(FloorDivModel, BroadcastFloorDivUInt8) {
FloorDivModel<uint8_t> model({TensorType_UINT8, {1, 2, 2, 1}},
{TensorType_UINT8, {1}}, {TensorType_UINT8, {}});
model.PopulateTensor<uint8_t>(model.input1(), {10, 9, 11, 7});
model.PopulateTensor<uint8_t>(model.input2(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(3, 3, 3, 2));
}
} // namespace
} // namespace tflite

View File

@ -367,7 +367,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD());

View File

@ -2347,6 +2347,20 @@ class Select : public SimpleOperator<SelectOperator> {
}
};
class FloorDiv : public SimpleOperator<FloorDivOperator> {
public:
explicit FloorDiv() : SimpleOperator("FLOOR_DIV", OperatorType::kFloorDiv) {}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// Version 2 supports float input types.
if (input_array.data_type == ArrayDataType::kFloat) {
return 2;
}
return 1;
}
};
namespace {
// Build a vector containing all the known operators.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
@ -2551,8 +2565,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"LOGICAL_AND", OperatorType::kLogicalAnd));
ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
"LOGICAL_NOT", OperatorType::kLogicalNot));
ops.emplace_back(new SimpleOperator<FloorDivOperator>(
"FLOOR_DIV", OperatorType::kFloorDiv));
ops.push_back(MakeUnique<FloorDiv>());
ops.emplace_back(new SimpleOperator<FloorModOperator>(
"FLOOR_MOD", OperatorType::kFloorMod));
ops.emplace_back(

View File

@ -968,6 +968,29 @@ TEST_F(OperatorTest, VersioningConv2DTest) {
EXPECT_EQ(op->GetVersion(float_signature), 2);
}
TEST_F(OperatorTest, VersioningFloorDivOperatorTest) {
FloorDivOperator floordiv_op;
floordiv_op.inputs = {"input1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* op = operator_by_type_map.at(floordiv_op.type).get();
Model int32_model;
Array& input_int32_array =
int32_model.GetOrCreateArray(floordiv_op.inputs[0]);
input_int32_array.data_type = ArrayDataType::kInt32;
OperatorSignature int32_signature = {.op = &floordiv_op,
.model = &int32_model};
EXPECT_EQ(op->GetVersion(int32_signature), 1);
Model float_model;
Array& input_float_array =
float_model.GetOrCreateArray(floordiv_op.inputs[0]);
input_float_array.data_type = ArrayDataType::kFloat;
OperatorSignature float_signature = {.op = &floordiv_op,
.model = &float_model};
EXPECT_EQ(op->GetVersion(float_signature), 2);
}
} // namespace
} // namespace tflite