diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 53ed7709630..c74e27b2982 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1596,17 +1596,16 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] def TFL_FillOp: TFL_Op<"fill", [ NoSideEffect, PredOpTrait<"input and result must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 1>>, - NoQuantizableResult]> { + TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "Fill the tensor with given value."; let description = [{ Fill the tensor with given value. }]; let arguments = (ins TFL_I32OrI64Tensor:$dims, - TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$input); + TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$input); - let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$result); + let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$result); let hasOptions = 0; } diff --git a/tensorflow/lite/kernels/fill.cc b/tensorflow/lite/kernels/fill.cc index 6a04ded6dcd..6af77546038 100644 --- a/tensorflow/lite/kernels/fill.cc +++ b/tensorflow/lite/kernels/fill.cc @@ -92,6 +92,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { GetOutputSafe(context, node, kOutputTensor, &output)); output->type = value->type; + TF_LITE_ENSURE_EQ(context, output->params.scale, value->params.scale); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + value->params.zero_point); + + if (value->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, value->params.zero_point, 0); + } + if (IsConstantTensor(dims)) { TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output)); } else { @@ -132,6 +140,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(output), \ GetTensorData<data_type>(output)) switch (output->type) { + case kTfLiteInt8: + TF_LITE_FILL(int8_t); + break; + case kTfLiteInt16: + TF_LITE_FILL(int16_t); + break; case kTfLiteInt32:
TF_LITE_FILL(int32_t); break; @@ -147,14 +161,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteString: FillString(value, output); break; - case kTfLiteInt8: - TF_LITE_FILL(int8_t); - break; default: context->ReportError( context, - "Fill only currently supports int32, int64, float32, bool, string " - "for input 1, got %d.", + "Fill only currently supports int8, int16, int32, int64, float32, " + "bool, string for input 1, got %d.", value->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/fill_test.cc b/tensorflow/lite/kernels/fill_test.cc index ccb8638a28c..1cb38016ad7 100644 --- a/tensorflow/lite/kernels/fill_test.cc +++ b/tensorflow/lite/kernels/fill_test.cc @@ -73,6 +73,42 @@ class FillOpModel : public SingleOpModel { int output_; }; +template <typename QuantType> +class QuantizedFillOpModel : public SingleOpModel { + public: + explicit QuantizedFillOpModel(TensorType dims_tensor_type, + std::initializer_list<int> dims_shape, + std::initializer_list<int> dims_data, + const TensorData& tensor_data, + float value) { + dims_ = AddInput(dims_tensor_type); + value_ = AddInput(tensor_data); + output_ = AddOutput(tensor_data); + SetBuiltinOp(BuiltinOperator_FILL, BuiltinOptions_FillOptions, + CreateFillOptions(builder_).Union()); + BuildInterpreter({dims_shape, {}}); + + if (dims_data.size() > 0) { + PopulateTensor(dims_, dims_data); + } + QuantizeAndPopulate<QuantType>(value_, {value}); + } + + std::vector<QuantType> GetOutput() { + return ExtractVector<QuantType>(output_); + } + std::vector<float> GetDequantizedOutput() { + TfLiteTensor* t = interpreter_->tensor(output_); + return Dequantize(GetOutput(), t->params.scale, t->params.zero_point); + } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + protected: + int dims_; + int value_; + int output_; +}; + class FillOpTest : public ::testing::TestWithParam<TestType> {}; TEST_P(FillOpTest, FillInt32) { @@ -144,6 +180,42 @@ TEST_P(FillOpTest, FillInt8) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); } +template <typename QuantType> +void
QuantizedFill(float value) { + // Prepare TensorData for quantization of value + const float kMin = -1; + // Workaround to get a zero-point of 0 + const float kMax = + std::numeric_limits<QuantType>::max() / + static_cast<float>(std::numeric_limits<QuantType>::max() + 1); + const TensorData tensor_data(GetTensorType<QuantType>(), {}, + std::abs(value) * kMin, std::abs(value) * kMax); + + QuantizedFillOpModel<QuantType> m(TensorType_INT32, {2}, {2, 3}, + tensor_data, value); + m.Invoke(); + + constexpr float epsilon = 0.01f; + const float min_value = tensor_data.min - epsilon; + const float max_value = tensor_data.max + epsilon; + const float kQuantizedTolerance = + (max_value - min_value) / (std::numeric_limits<QuantType>::max() - + std::numeric_limits<QuantType>::min()); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {value, value, value, value, value, value}, kQuantizedTolerance))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); +} + +TEST(FillOpTest, QuantizedFillInt8) { + QuantizedFill<int8_t>(3.14f); +} + +TEST(FillOpTest, QuantizedFillInt16) { + QuantizedFill<int16_t>(3.14f); +} + INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest, ::testing::Values(TestType::kConst, TestType::kDynamic)); diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index e52e067ad4b..752f83f1791 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -270,7 +270,7 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version = */ 2); AddBuiltin(BuiltinOperator_FILL, Register_FILL(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD(), /* min_version = */ 1, /* max_version = */ 2); diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 5a0c068e724..d055daaf994 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -432,7 +432,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { /*
max_version = */ 2); AddBuiltin(BuiltinOperator_FILL, Register_FILL(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD(), /* min_version = */ 1, /* max_version = */ 2); diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index cf390ed9bfa..ee04f281cd5 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -192,6 +192,13 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) { property.outputs = {{0, {}}}; property.version = 1; break; + case BuiltinOperator_FILL: { + property.inputs = {{1, {}}}; + property.outputs = {{0, {}}}; + property.restrict_same_input_output_scale = true; + property.version = 3; + break; + } case BuiltinOperator_FULLY_CONNECTED: { TensorProperty tensor_property; tensor_property.symmetric = true; diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 3e479143490..b20dea5e4ec 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -538,10 +538,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_FILL: - if (op_sig.input_types.size() >= 2 && - (op_sig.input_types.at(1) == TensorType_BOOL || - op_sig.input_types.at(1) == TensorType_STRING)) { - return 2; + if (op_sig.input_types.size() >= 2) { + if (op_sig.input_types.at(1) == TensorType_INT8 || + op_sig.input_types.at(1) == TensorType_INT16) { + return 3; + } else if ((op_sig.input_types.at(1) == TensorType_BOOL || + op_sig.input_types.at(1) == TensorType_STRING)) { + return 2; + } } return 1; diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index 905ccb09140..f6c2581139f 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ 
b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -700,7 +700,15 @@ TEST(OpVersionTest, VersioningDivTest) { TEST(OpVersionTEst, VersioningFillTest) { OpSignature fake_op_sig = {.op = BuiltinOperator_FILL, .input_types = std::vector<TensorType>{ - TensorType_INT32, TensorType_BOOL}}; + TensorType_INT32, TensorType_INT8}}; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig = {.op = BuiltinOperator_FILL, + .input_types = std::vector<TensorType>{TensorType_INT64, + TensorType_INT16}}; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + fake_op_sig = {.op = BuiltinOperator_FILL, + .input_types = std::vector<TensorType>{TensorType_INT32, + TensorType_BOOL}}; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); fake_op_sig = {.op = BuiltinOperator_FILL, .input_types = std::vector<TensorType>{TensorType_INT32, diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index d6970762b5a..a075d69b55a 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_HARD_SWISH, 1}, "1.15.0"}, {{BuiltinOperator_FILL, 1}, "1.13.0"}, {{BuiltinOperator_FILL, 2}, "2.3.0"}, + {{BuiltinOperator_FILL, 3}, kPendingReleaseVersion}, {{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"}, {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"}, {{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion},