Support conversion of tf.text ops to Flex ops

PiperOrigin-RevId: 322733983
Change-Id: Ie32a3912e7575ff84318de8e6aa2d087eaac8fbe
This commit is contained in:
Thai Nguyen 2020-07-23 00:07:13 -07:00 committed by TensorFlower Gardener
parent e543b6842a
commit 537f3ec52f
4 changed files with 54 additions and 1 deletions

View File

@ -239,6 +239,17 @@ cc_library(
"allowlisted_flex_ops.h",
"allowlisted_flex_ops_internal.h",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}),
)
tf_cc_test(

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <set>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops_internal.h"
namespace tflite {
@ -547,8 +548,36 @@ const std::set<std::string>& GetFlexAllowlist() {
// NOLINTNEXTLINE
}
// Allow the tf.text ops if they are registered in the global op registry.
bool IsAllowedTFTextOpForFlex(const std::string& op_name) {
static const std::set<std::string>* tftext_flex_ops =
new std::set<std::string>({
"CaseFoldUTF8",
"ConstrainedSequence",
"MaxSpanningTree",
"NormalizeUTF8",
"NormalizeUTF8WithOffsetsMap",
"RegexSplitWithOffsets",
"RougeL",
"SentenceFragments",
"SentencepieceOp",
"SentencepieceTokenizeOp",
"SentencepieceTokenizeWithOffsetsOp",
"SentencepieceDetokenizeOp",
"SentencepieceVocabSizeOp",
"SplitMergeTokenizeWithOffsets",
"UnicodeScriptTokenizeWithOffsets",
"WhitespaceTokenizeWithOffsets",
"WordpieceTokenizeWithOffsets",
});
if (tftext_flex_ops->count(op_name) == 0) return false;
return tensorflow::OpRegistry::Global()->LookUp(op_name) != nullptr;
}
bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) {
return GetFlexAllowlist().count(tensorflow_op_name) != 0;
if (GetFlexAllowlist().count(tensorflow_op_name) != 0) return true;
// Check if the op is an allowlisted tf.text op.
return IsAllowedTFTextOpForFlex(tensorflow_op_name);
}
} // namespace flex

View File

@ -24,6 +24,9 @@ namespace flex {
// Return the list of allowlisted flex ops.
const std::set<std::string>& GetFlexAllowlist();
// Return true if op_name is a tf.text op need to be supported by flex delegate.
bool IsAllowedTFTextOpForFlex(const std::string& op_name);
} // namespace flex
} // namespace tflite

View File

@ -52,6 +52,16 @@ TEST(AllowlistedFlexOpsTest, EveryOpHasKernel) {
<< "but its kernel is not found.";
}
}
TEST(TfTextUtilsTest, TestFlexOpAllowed) {
// Expect false since ConstrainedSequence kernel is not registered.
EXPECT_FALSE(IsAllowedTFTextOpForFlex("ConstrainedSequence"));
}
TEST(TfTextUtilsTest, TestFlexOpNotAllowed) {
EXPECT_FALSE(IsAllowedTFTextOpForFlex("ngrams"));
}
} // namespace flex
} // namespace tflite