String/Bool input support on TFLite Fill op
PiperOrigin-RevId: 305841870 Change-Id: Ibfbff33f039cca9da95fd80a2a3e95d048c456c5
This commit is contained in:
parent
5bde6ce98a
commit
a6fe69e318
tensorflow/lite
kernels
testing
toco/tflite
tools/versioning
@ -94,6 +94,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
|
||||
DynamicBuffer buffer;
|
||||
const auto string_ref = GetString(value, 0);
|
||||
int n = 1;
|
||||
for (int i = 0; i < output->dims->size; ++i) {
|
||||
n *= output->dims->data[i];
|
||||
}
|
||||
for (int i = 0; i < n; ++i) {
|
||||
buffer.AddString(string_ref.str, string_ref.len);
|
||||
}
|
||||
buffer.WriteToTensor(output, /*new_shape=*/nullptr);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* value = GetInput(context, node, kValueTensor);
|
||||
|
||||
@ -117,11 +131,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_FILL(float);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
TF_LITE_FILL(bool);
|
||||
break;
|
||||
case kTfLiteString:
|
||||
FillString(value, output);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
"Fill only currently supports int32, int64, float32 for input 1,"
|
||||
"got %d.",
|
||||
"Fill only currently supports int32, int64, float32, bool, string "
|
||||
"for input 1, got %d.",
|
||||
value->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
@ -84,5 +84,27 @@ TEST(FillOpModel, FillOutputScalar) {
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(FillOpModel, FillBool) {
|
||||
FillOpModel m({TensorType_INT64, {3}}, {TensorType_BOOL});
|
||||
m.PopulateTensor<int64_t>(m.input1(), {2, 2, 2});
|
||||
m.PopulateTensor<bool>(m.input2(), {true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(
|
||||
m.ExtractVector<bool>(m.output()),
|
||||
ElementsAreArray({true, true, true, true, true, true, true, true}));
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 2, 2}));
|
||||
}
|
||||
|
||||
TEST(FillOpModel, FillString) {
|
||||
FillOpModel m({TensorType_INT64, {3}}, {TensorType_STRING});
|
||||
m.PopulateTensor<int64_t>(m.input1(), {2, 2, 2});
|
||||
m.PopulateTensor<std::string>(m.input2(), {"AB"});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(
|
||||
m.ExtractVector<std::string>(m.output()),
|
||||
ElementsAreArray({"AB", "AB", "AB", "AB", "AB", "AB", "AB", "AB"}));
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 2, 2}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -255,7 +255,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
|
||||
AddBuiltin(BuiltinOperator_FILL, Register_FILL());
|
||||
AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
|
||||
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
|
||||
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(),
|
||||
|
@ -31,7 +31,7 @@ def make_fill_tests(options):
|
||||
test_parameters = [{
|
||||
"dims_dtype": [tf.int32, tf.int64],
|
||||
"dims_shape": [[], [1], [3], [3, 3]],
|
||||
"value_dtype": [tf.int32, tf.int64, tf.float32],
|
||||
"value_dtype": [tf.int32, tf.int64, tf.float32, tf.bool, tf.string],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
@ -57,4 +57,4 @@ def make_fill_tests(options):
|
||||
test_parameters,
|
||||
build_graph,
|
||||
build_inputs,
|
||||
expected_tf_failures=12)
|
||||
expected_tf_failures=20)
|
||||
|
@ -133,6 +133,11 @@ def create_scalar_data(dtype, min_value=-100, max_value=100):
|
||||
value = (max_value - min_value) * np.random.random() + min_value
|
||||
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
|
||||
value = np.random.randint(min_value, max_value + 1)
|
||||
elif dtype == tf.bool:
|
||||
value = np.random.choice([True, False])
|
||||
elif dtype == np.string_:
|
||||
l = np.random.randint(1, 6)
|
||||
value = "".join(np.random.choice(list(string.ascii_uppercase), size=l))
|
||||
return np.array(value, dtype=dtype)
|
||||
|
||||
|
||||
|
@ -224,7 +224,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kZerosLike, 1}, "1.12.0"},
|
||||
{{OperatorType::kAbs, 1}, "1.13.0"},
|
||||
{{OperatorType::kHardSwish, 1}, "1.15.0"},
|
||||
{{OperatorType::kFill, 1}, "1.13.0"},
|
||||
{{OperatorType::kFill, 2}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kReverseV2, 1}, "1.14.0"},
|
||||
{{OperatorType::kReverseV2, 2}, "2.2.0"},
|
||||
{{OperatorType::kRank, 1}, "1.14.0"},
|
||||
|
@ -399,6 +399,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_FILL:
|
||||
if (op_sig.input_types.size() >= 2 &&
|
||||
(op_sig.input_types.at(1) == TensorType_BOOL ||
|
||||
op_sig.input_types.at(1) == TensorType_STRING)) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_ADD:
|
||||
case BuiltinOperator_CONCATENATION:
|
||||
case BuiltinOperator_PAD:
|
||||
|
@ -554,4 +554,18 @@ TEST(OpVersionTest, VersioningDivTest) {
|
||||
fake_op_sig.options.broadcast.num_dims = 4;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
TEST(OpVersionTEst, VersioningFillTest) {
|
||||
OpSignature fake_op_sig = {.op = BuiltinOperator_FILL,
|
||||
.input_types = std::vector<TensorType>{
|
||||
TensorType_INT32, TensorType_BOOL}};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
fake_op_sig = {.op = BuiltinOperator_FILL,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT32,
|
||||
TensorType_STRING}};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
fake_op_sig = {.op = BuiltinOperator_FILL,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT32,
|
||||
TensorType_INT32}};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user