Support conversion of tf.text ops to Flex ops
PiperOrigin-RevId: 322733983 Change-Id: Ie32a3912e7575ff84318de8e6aa2d087eaac8fbe
This commit is contained in:
parent
e543b6842a
commit
537f3ec52f
@ -239,6 +239,17 @@ cc_library(
|
||||
"allowlisted_flex_ops.h",
|
||||
"allowlisted_flex_ops_internal.h",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -547,8 +548,36 @@ const std::set<std::string>& GetFlexAllowlist() {
|
||||
// NOLINTNEXTLINE
|
||||
}
|
||||
|
||||
// Allow the tf.text ops if they are registered in the global op registry.
|
||||
bool IsAllowedTFTextOpForFlex(const std::string& op_name) {
|
||||
static const std::set<std::string>* tftext_flex_ops =
|
||||
new std::set<std::string>({
|
||||
"CaseFoldUTF8",
|
||||
"ConstrainedSequence",
|
||||
"MaxSpanningTree",
|
||||
"NormalizeUTF8",
|
||||
"NormalizeUTF8WithOffsetsMap",
|
||||
"RegexSplitWithOffsets",
|
||||
"RougeL",
|
||||
"SentenceFragments",
|
||||
"SentencepieceOp",
|
||||
"SentencepieceTokenizeOp",
|
||||
"SentencepieceTokenizeWithOffsetsOp",
|
||||
"SentencepieceDetokenizeOp",
|
||||
"SentencepieceVocabSizeOp",
|
||||
"SplitMergeTokenizeWithOffsets",
|
||||
"UnicodeScriptTokenizeWithOffsets",
|
||||
"WhitespaceTokenizeWithOffsets",
|
||||
"WordpieceTokenizeWithOffsets",
|
||||
});
|
||||
if (tftext_flex_ops->count(op_name) == 0) return false;
|
||||
return tensorflow::OpRegistry::Global()->LookUp(op_name) != nullptr;
|
||||
}
|
||||
|
||||
bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) {
|
||||
return GetFlexAllowlist().count(tensorflow_op_name) != 0;
|
||||
if (GetFlexAllowlist().count(tensorflow_op_name) != 0) return true;
|
||||
// Check if the op is an allowlisted tf.text op.
|
||||
return IsAllowedTFTextOpForFlex(tensorflow_op_name);
|
||||
}
|
||||
|
||||
} // namespace flex
|
||||
|
@ -24,6 +24,9 @@ namespace flex {
|
||||
// Return the list of allowlisted flex ops.
|
||||
const std::set<std::string>& GetFlexAllowlist();
|
||||
|
||||
// Return true if op_name is a tf.text op need to be supported by flex delegate.
|
||||
bool IsAllowedTFTextOpForFlex(const std::string& op_name);
|
||||
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -52,6 +52,16 @@ TEST(AllowlistedFlexOpsTest, EveryOpHasKernel) {
|
||||
<< "but its kernel is not found.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TfTextUtilsTest, TestFlexOpAllowed) {
|
||||
// Expect false since ConstrainedSequence kernel is not registered.
|
||||
EXPECT_FALSE(IsAllowedTFTextOpForFlex("ConstrainedSequence"));
|
||||
}
|
||||
|
||||
TEST(TfTextUtilsTest, TestFlexOpNotAllowed) {
|
||||
EXPECT_FALSE(IsAllowedTFTextOpForFlex("ngrams"));
|
||||
}
|
||||
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user