String/Bool input support on TFLite Fill op

PiperOrigin-RevId: 305841870
Change-Id: Ibfbff33f039cca9da95fd80a2a3e95d048c456c5
Hyeonjong Ryu authored on 2020-04-10 01:07:59 -07:00; committed by TensorFlower Gardener
parent 5bde6ce98a
commit a6fe69e318
8 changed files with 77 additions and 6 deletions
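
In practice, the change lets a converted model carry a FILL op whose value input is a bool or string scalar. A minimal end-to-end sketch of the intended usage (illustrative only, not part of this commit; it assumes a TensorFlow 2.x build that already contains these kernels, and the FillModel name is made up):

import tensorflow as tf

class FillModel(tf.Module):
  """Fills dynamically shaped outputs with a scalar bool and a scalar string."""

  @tf.function(input_signature=[tf.TensorSpec([3], tf.int64)])
  def __call__(self, dims):
    return tf.fill(dims, True), tf.fill(dims, "AB")

# Conversion should emit the FILL ops as op version 2 (see the versioning
# changes further down in this commit).
concrete_fn = FillModel().__call__.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_model = converter.convert()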


@@ -94,6 +94,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
+TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
+  DynamicBuffer buffer;
+  const auto string_ref = GetString(value, 0);
+  int n = 1;
+  for (int i = 0; i < output->dims->size; ++i) {
+    n *= output->dims->data[i];
+  }
+  for (int i = 0; i < n; ++i) {
+    buffer.AddString(string_ref.str, string_ref.len);
+  }
+  buffer.WriteToTensor(output, /*new_shape=*/nullptr);
+  return kTfLiteOk;
+}
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* value = GetInput(context, node, kValueTensor);
@@ -117,11 +131,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteFloat32:
       TF_LITE_FILL(float);
       break;
+    case kTfLiteBool:
+      TF_LITE_FILL(bool);
+      break;
+    case kTfLiteString:
+      FillString(value, output);
+      break;
     default:
       context->ReportError(
           context,
-          "Fill only currently supports int32, int64, float32 for input 1,"
-          "got %d.",
+          "Fill only currently supports int32, int64, float32, bool, string "
+          "for input 1, got %d.",
           value->type);
       return kTfLiteError;
   }
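
The string path above appends one copy of the scalar string per output element into a DynamicBuffer and then packs the buffer into the output tensor. A rough sketch of exercising both new paths from Python through the TFLite interpreter (this reuses the tflite_model from the converter sketch near the top and assumes a runtime that contains this kernel change; it is not part of the commit):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# The output shape depends on the runtime value of the dims input, which is
# why the kernel resizes (or marks dynamic) the output before filling it.
dims_index = interpreter.get_input_details()[0]["index"]
interpreter.set_tensor(dims_index, np.array([2, 2, 2], dtype=np.int64))
interpreter.invoke()

for detail in interpreter.get_output_details():
  print(detail["dtype"], interpreter.get_tensor(detail["index"]).shape)
# Expected: one (2, 2, 2) tensor of True values and one of b'AB' strings.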


@@ -84,5 +84,27 @@ TEST(FillOpModel, FillOutputScalar) {
   EXPECT_THAT(m.GetTensorShape(m.output()), IsEmpty());
 }
+TEST(FillOpModel, FillBool) {
+  FillOpModel m({TensorType_INT64, {3}}, {TensorType_BOOL});
+  m.PopulateTensor<int64_t>(m.input1(), {2, 2, 2});
+  m.PopulateTensor<bool>(m.input2(), {true});
+  m.Invoke();
+  EXPECT_THAT(
+      m.ExtractVector<bool>(m.output()),
+      ElementsAreArray({true, true, true, true, true, true, true, true}));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 2, 2}));
+}
+TEST(FillOpModel, FillString) {
+  FillOpModel m({TensorType_INT64, {3}}, {TensorType_STRING});
+  m.PopulateTensor<int64_t>(m.input1(), {2, 2, 2});
+  m.PopulateTensor<std::string>(m.input2(), {"AB"});
+  m.Invoke();
+  EXPECT_THAT(
+      m.ExtractVector<std::string>(m.output()),
+      ElementsAreArray({"AB", "AB", "AB", "AB", "AB", "AB", "AB", "AB"}));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 2, 2}));
+}
 }  // namespace
 }  // namespace tflite


@@ -255,7 +255,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* min_version = */ 1,
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
-  AddBuiltin(BuiltinOperator_FILL, Register_FILL());
+  AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
+             /* min_version = */ 1,
+             /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
   AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
   AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(),


@@ -31,7 +31,7 @@ def make_fill_tests(options):
   test_parameters = [{
       "dims_dtype": [tf.int32, tf.int64],
       "dims_shape": [[], [1], [3], [3, 3]],
-      "value_dtype": [tf.int32, tf.int64, tf.float32],
+      "value_dtype": [tf.int32, tf.int64, tf.float32, tf.bool, tf.string],
   }]
   def build_graph(parameters):
@@ -57,4 +57,4 @@ def make_fill_tests(options):
       test_parameters,
       build_graph,
       build_inputs,
-      expected_tf_failures=12)
+      expected_tf_failures=20)
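
The expected_tf_failures bump from 12 to 20 is arithmetic rather than a behavioral change: tf.fill rejects dims that are not 1-D, so the [] and [3, 3] entries in dims_shape fail on the TensorFlow side for every dtype combination (this is my reading of the parameter grid above, not something stated in the commit):

failing_dims_shapes = 2      # [] and [3, 3] are not 1-D
dims_dtypes = 2              # tf.int32, tf.int64
value_dtypes_before, value_dtypes_after = 3, 5
assert failing_dims_shapes * dims_dtypes * value_dtypes_before == 12
assert failing_dims_shapes * dims_dtypes * value_dtypes_after == 20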


@@ -133,6 +133,11 @@ def create_scalar_data(dtype, min_value=-100, max_value=100):
     value = (max_value - min_value) * np.random.random() + min_value
   elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
     value = np.random.randint(min_value, max_value + 1)
+  elif dtype == tf.bool:
+    value = np.random.choice([True, False])
+  elif dtype == np.string_:
+    l = np.random.randint(1, 6)
+    value = "".join(np.random.choice(list(string.ascii_uppercase), size=l))
   return np.array(value, dtype=dtype)


@@ -224,7 +224,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
           {{OperatorType::kZerosLike, 1}, "1.12.0"},
           {{OperatorType::kAbs, 1}, "1.13.0"},
           {{OperatorType::kHardSwish, 1}, "1.15.0"},
-          {{OperatorType::kFill, 1}, "1.13.0"},
+          {{OperatorType::kFill, 2}, kPendingReleaseOpVersion},
           {{OperatorType::kReverseV2, 1}, "1.14.0"},
           {{OperatorType::kReverseV2, 2}, "2.2.0"},
           {{OperatorType::kRank, 1}, "1.14.0"},


@@ -399,6 +399,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
+    case BuiltinOperator_FILL:
+      if (op_sig.input_types.size() >= 2 &&
+          (op_sig.input_types.at(1) == TensorType_BOOL ||
+           op_sig.input_types.at(1) == TensorType_STRING)) {
+        return 2;
+      }
+      return 1;
     case BuiltinOperator_ADD:
     case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_PAD:


@@ -554,4 +554,18 @@ TEST(OpVersionTest, VersioningDivTest) {
   fake_op_sig.options.broadcast.num_dims = 4;
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
 }
+TEST(OpVersionTest, VersioningFillTest) {
+  OpSignature fake_op_sig = {.op = BuiltinOperator_FILL,
+                             .input_types = std::vector<TensorType>{
+                                 TensorType_INT32, TensorType_BOOL}};
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+  fake_op_sig = {.op = BuiltinOperator_FILL,
+                 .input_types = std::vector<TensorType>{TensorType_INT32,
+                                                        TensorType_STRING}};
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+  fake_op_sig = {.op = BuiltinOperator_FILL,
+                 .input_types = std::vector<TensorType>{TensorType_INT32,
+                                                        TensorType_INT32}};
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+}
 }  // namespace tflite