Fuse tensorflow_text.ngrams into a TFLite custom op

PiperOrigin-RevId: 323456482
Change-Id: Idfd446c371e8a4a4f82b6da730d02b0897d35a8a
This commit is contained in:
A. Unique TensorFlower 2020-07-27 15:42:20 -07:00 committed by TensorFlower Gardener
parent c7e7f49228
commit 07e4db17ff
3 changed files with 3602 additions and 3224 deletions

View File

@ -270,6 +270,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/core:framework",
"@flatbuffers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",

View File

@ -1,5 +1,4 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s
module {
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
@ -3194,4 +3193,246 @@ module {
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
// CHECK: return %0 : tensor<?x!tf.string>
// Composite tftext:Ngrams body (dense rank-1 input, width = 2, axis = -1,
// STRING_JOIN with separator " "), spelled out in plain TF ops. The
// tfl-fuse-tftext pass is expected to replace this entire body with one
// tfl.custom "tftext:Ngrams" op (see the CHECK lines that follow).
func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<?>], tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>} {
  // Reduction axis (-1) and the strided-slice begin/end/stride constants.
  %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Const"() {value = dense<[0, -1]> : tensor<2xi32>} : () -> tensor<2xi32>
  %2 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
  %3 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
  %4 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
  // Two shifted views of the input: %5 drops the last element, %6 drops the
  // first, producing the adjacent-token pairs of a width-2 sliding window.
  %5 = "tf.StridedSlice"(%arg0, %3, %1, %4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x!tf.string>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x!tf.string>
  %6 = "tf.StridedSlice"(%arg0, %2, %3, %4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 1 : i64, end_mask = 2 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x!tf.string>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x!tf.string>
  // Stack the pairs on a trailing axis and join each pair with " ".
  %7 = "tf.Pack"(%5, %6) {axis = -1 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?x!tf.string>) -> tensor<?x2x!tf.string>
  %8 = "tf.ReduceJoin"(%7, %0) {device = "", keep_dims = false, separator = " "} : (tensor<?x2x!tf.string>, tensor<i32>) -> tensor<?x!tf.string>
  %9 = "tf.Identity"(%8) {device = ""} : (tensor<?x!tf.string>) -> tensor<?x!tf.string>
  return %9 : tensor<?x!tf.string>
}
// CHECK: func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>]} {
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F72000120006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E383F040104FF152D0204141404082401"> : tensor<78xi8>} : (tensor<?x!tf.string>) -> tensor<?x!tf.string>
// CHECK: return %0 : tensor<?x!tf.string>
// CHECK: }
// Composite tftext:Ngrams body for a ragged-rank-2 input (flat string values
// plus two row-splits tensors), width = 2, empty separator. Most of the body
// is RaggedTensor row-splits validation (tf.IfRegion assert guards); the core
// ngram computation mirrors @ngrams: two shifted GatherV2 views joined by
// ReduceJoin. The fusion pass should collapse everything into a single
// three-operand tfl.custom "tftext:Ngrams" op (see the CHECK lines below).
func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// Scalar/vector constants reused throughout (indices, strides, axis values).
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%3 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%4 = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
%5 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
%6 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%7 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%8 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// Outer row-splits (%arg1) checks: first element == 0, splits non-decreasing.
%9 = "tf.StridedSlice"(%arg1, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
%10 = "tf.Equal"(%9, %4) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%11 = "tf.All"(%10, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
%12 = "tf.StridedSlice"(%arg1, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
%13 = "tf.StridedSlice"(%arg1, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
%14 = "tf.Sub"(%12, %13) {device = ""} : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
%15 = "tf.LessEqual"(%4, %14) {device = ""} : (tensor<i64>, tensor<2xi64>) -> tensor<2xi1>
%16 = "tf.All"(%15, %7) {device = "", keep_dims = false} : (tensor<2xi1>, tensor<1xi32>) -> tensor<i1>
// Inner row-splits (%arg2) check: first element == 0, guarded by an assert.
%17 = "tf.StridedSlice"(%arg2, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
%18 = "tf.Equal"(%17, %4) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%19 = "tf.All"(%18, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
%20 = "tf.IfRegion"(%19) ( {
%72 = "std.call"(%19, %17, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%19, %17, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
%21 = "tf.Identity"(%20) {device = ""} : (tensor<i1>) -> tensor<i1>
// Inner row-splits monotonicity check (successive differences >= 0).
%22 = "tf.StridedSlice"(%arg2, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
%23 = "tf.StridedSlice"(%arg2, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
%24 = "tf.Sub"(%22, %23) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
%25 = "tf.LessEqual"(%4, %24) {device = ""} : (tensor<i64>, tensor<?xi64>) -> tensor<?xi1>
%26 = "tf.All"(%25, %7) {device = "", keep_dims = false} : (tensor<?xi1>, tensor<1xi32>) -> tensor<i1>
%27 = "tf.IfRegion"(%26) ( {
%72 = "std.call"(%26, %24) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130} : (tensor<i1>, tensor<?xi64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%26, %24) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140} : (tensor<i1>, tensor<?xi64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
%28 = "tf.Identity"(%27) {device = ""} : (tensor<i1>) -> tensor<i1>
// Check that the last inner split equals the number of flat values in %arg0.
%29 = "tf.Identity"(%arg2) {_class = ["loc:@args_1"], device = ""} : (tensor<?xi64>) -> tensor<?xi64>
%30 = "tf.StridedSlice"(%29, %6, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
%31 = "tf.Shape"(%arg0) {device = ""} : (tensor<?x!tf.string>) -> tensor<1xi64>
%32 = "tf.StridedSlice"(%31, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
%33 = "tf.Equal"(%30, %32) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%34 = "tf.All"(%33, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
%35 = "tf.IfRegion"(%34) ( {
%72 = "std.call"(%34, %30, %32) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%34, %30, %32) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
%36 = "tf.Identity"(%35) {device = ""} : (tensor<i1>) -> tensor<i1>
// Sliding-window core: from the validated inner splits build two RaggedRange
// index sets (%43 and %47), gather the shifted value views, and check that
// the two ranges produced identical splits.
%37 = "tf.Identity"(%29) {_class = ["loc:@args_1"], device = ""} : (tensor<?xi64>) -> tensor<?xi64>
%38 = "tf.StridedSlice"(%37, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
%39 = "tf.StridedSlice"(%37, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
%40 = "tf.Minimum"(%38, %39) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
%41 = "tf.AddV2"(%39, %1) {device = ""} : (tensor<?xi64>, tensor<i64>) -> tensor<?xi64>
%42 = "tf.Maximum"(%41, %38) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
%43:2 = "tf.RaggedRange"(%40, %42, %3) {T = i64, Tsplits = i64, device = ""} : (tensor<?xi64>, tensor<?xi64>, tensor<i64>) -> (tensor<?xi64>, tensor<?xi64>)
%44 = "tf.GatherV2"(%arg0, %43#1, %2) {batch_dims = 0 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?xi64>, tensor<i32>) -> tensor<?x!tf.string>
%45 = "tf.AddV2"(%38, %3) {device = ""} : (tensor<?xi64>, tensor<i64>) -> tensor<?xi64>
%46 = "tf.Minimum"(%45, %39) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
%47:2 = "tf.RaggedRange"(%46, %39, %3) {T = i64, Tsplits = i64, device = ""} : (tensor<?xi64>, tensor<?xi64>, tensor<i64>) -> (tensor<?xi64>, tensor<?xi64>)
%48 = "tf.Equal"(%43#0, %47#0) {device = "", incompatible_shape_error = true} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi1>
%49 = "tf.All"(%48, %7) {device = "", keep_dims = false} : (tensor<?xi1>, tensor<1xi32>) -> tensor<i1>
%50 = "tf.GatherV2"(%arg0, %47#1, %2) {batch_dims = 0 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?xi64>, tensor<i32>) -> tensor<?x!tf.string>
// nrows of the inner ragged dimension = len(inner splits) - 1.
%51 = "tf.Shape"(%37) {device = ""} : (tensor<?xi64>) -> tensor<1xi64>
%52 = "tf.StridedSlice"(%51, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
%53 = "tf.Sub"(%52, %3) {device = ""} : (tensor<i64>, tensor<i64>) -> tensor<i64>
// Assert guards for the outer row-splits predicates computed earlier
// (%11: first split == 0, %16: splits non-decreasing).
%54 = "tf.IfRegion"(%11) ( {
%72 = "std.call"(%11, %9, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%11, %9, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
%55 = "tf.Identity"(%54) {device = ""} : (tensor<i1>) -> tensor<i1>
%56 = "tf.IfRegion"(%16) ( {
%72 = "std.call"(%16, %14) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260} : (tensor<i1>, tensor<2xi64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%16, %14) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270} : (tensor<i1>, tensor<2xi64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
%57 = "tf.Identity"(%56) {device = ""} : (tensor<i1>) -> tensor<i1>
// Check last outer split == inner nrows (%53), with its own assert guard.
%58 = "tf.Identity"(%arg1) {_class = ["loc:@args_0"], device = ""} : (tensor<3xi64>) -> tensor<3xi64>
%59 = "tf.StridedSlice"(%58, %6, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
%60 = "tf.Equal"(%59, %53) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%61 = "tf.All"(%60, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
%62 = "tf.IfRegion"(%61) ( {
%72 = "std.call"(%61, %59, %53) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%61, %59, %53) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
// Guard that the two RaggedRange results share identical splits (%49).
%63 = "tf.IfRegion"(%49) ( {
%72 = "std.call"(%49, %43#0, %47#0) {callee = @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330} : (tensor<i1>, tensor<?xi64>, tensor<?xi64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}, {
%72 = "std.call"(%49, %43#0, %47#0) {callee = @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340} : (tensor<i1>, tensor<?xi64>, tensor<?xi64>) -> tensor<i1>
"tf.Yield"(%72) : (tensor<i1>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
%64 = "tf.Identity"(%43#0) {device = ""} : (tensor<?xi64>) -> tensor<?xi64>
%65 = "tf.Identity"(%63) {device = ""} : (tensor<i1>) -> tensor<i1>
// Join the paired shifted views with empty separator to form the ngrams,
// then return values, outer splits, and the new inner splits.
%66 = "tf.Pack"(%44, %50) {axis = 1 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?x!tf.string>) -> tensor<?x2x!tf.string>
%67 = "tf.ReduceJoin"(%66, %0) {device = "", keep_dims = false, separator = ""} : (tensor<?x2x!tf.string>, tensor<i32>) -> tensor<?x!tf.string>
%68 = "tf.Identity"(%67) {device = ""} : (tensor<?x!tf.string>) -> tensor<?x!tf.string>
%69 = "tf.Identity"(%62) {device = ""} : (tensor<i1>) -> tensor<i1>
%70 = "tf.Identity"(%58) {_class = ["loc:@args_0"], device = ""} : (tensor<3xi64>) -> tensor<3xi64>
%71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64>
return %68, %71, %64 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
}
// True branch of the "first inner split == 0" assert guard: the condition
// held, so just pass the predicate through unchanged.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "first inner split == 0" guard: raise tf.Assert with a
// descriptive message, then return the (false) predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// True branch of the "inner splits are non-decreasing" guard: no-op
// passthrough of the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "inner splits are non-decreasing" guard: assert with
// the offending difference tensor, then return the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<?xi64>) -> ()
%3 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
// True branch of the "last inner split == number of values" guard: no-op
// passthrough of the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "last inner split == number of values" guard: assert
// with both compared scalars, then return the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// True branch of the "first outer split == 0" guard: no-op passthrough.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "first outer split == 0" guard: assert with both
// compared scalars, then return the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// True branch of the "outer splits are non-decreasing" guard: no-op
// passthrough.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "outer splits are non-decreasing" guard: assert with
// the offending difference tensor, then return the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<2xi64>) -> ()
%3 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
// True branch of the "last outer split == inner nrows" guard: no-op
// passthrough.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "last outer split == inner nrows" guard: assert with
// both compared scalars, then return the predicate.
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// True branch of the "sliding-window ragged splits match" guard: no-op
// passthrough.
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
// False branch of the "sliding-window ragged splits match" guard: assert with
// both splits tensors, then return the predicate.
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
%3 = "tf.Const"() {value = dense<"y (NGrams/SlidingWindow/RaggedGetItem_1/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<?xi64>, tensor<!tf.string>, tensor<?xi64>) -> ()
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -43,32 +45,35 @@ namespace TFL {
namespace {
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kTFImplements[] = "tf._implements";
using mlir::TF::FuncAttr;
using mlir::TF::StringType;
inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) {
std::string content = "";
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
const std::string& content) {
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
return OpaqueElementsAttr::get(
builder->getContext()->getRegisteredDialect("tfl"), type, content);
builder->getContext()->getRegisteredDialect("tfl"), type,
StringRef(content.data(), content.size()));
}
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
return func.getType()
.getInput(idx)
.dyn_cast_or_null<mlir::RankedTensorType>();
inline TensorType GetInputType(FuncOp func, int idx) {
return func.getType().getInput(idx).dyn_cast_or_null<TensorType>();
}
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
return func.getType()
.getResult(idx)
.dyn_cast_or_null<mlir::RankedTensorType>();
inline TensorType GetResultType(FuncOp func, int idx) {
return func.getType().getResult(idx).dyn_cast_or_null<TensorType>();
}
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
inline bool RankEquals(const TensorType& type, int rank) {
return type && type.hasRank() && type.getRank() == rank;
}
// Verifies the signature of a function annotated as a WhitespaceTokenizer:
// the single input must be a ranked string tensor, and the number of outputs
// must match the input rank (rank 0 -> 1 output, rank 1 -> 2 outputs,
// rank 2 -> 3 outputs).
// NOTE(review): the middle of this function was reconstructed across diff
// hunk gaps — confirm against the upstream revision.
LogicalResult VerifyWhitespaceTokenizer(FuncOp func) {
  // In the case of input tensor with 0 rank.
  // Whitespace tokenizer generates 1 output:
  // * String tensor for tokens.
  //
  // In the case of 1-D input tensor,
  // Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
  // * 1st output is the value of ragged tensor;
  // * 2nd output is the offset.
  //
  // In the case of batched input tensor,
  // Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
  // * 1st output is the value of ragged tensor;
  // * 2nd output is the inner offset;
  // * 3rd output is the outer offset.
  auto input_type = GetInputType(func, 0);
  if (!input_type || !input_type.getElementType().isa<StringType>() ||
      !input_type.hasRank()) {
    return func.emitError() << "Input should be a string tensor";
  }

  // Valid number of outputs, indexed by the input rank (0, 1, or 2).
  const std::vector<int> kValidNumOfOutput = {1, 2, 3};
  if (static_cast<size_t>(input_type.getRank()) >= kValidNumOfOutput.size()) {
    return func.emitError()
           << "Unrecognized input rank: " << input_type.getRank();
  }
  if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
    return func.emitError()
           << "Wrong number of outputs; should be "
           << kValidNumOfOutput[input_type.getRank()]
           << "output(s) when input has rank " << input_type.getRank();
  }

  auto value_type = GetResultType(func, 0);
  if (!RankEquals(value_type, 1) ||
      !value_type.getElementType().isa<StringType>()) {
    return func.emitError() << "1st output should be string tensor";
  }
  if (func.getNumResults() > 1) {
    auto offset_type = GetResultType(func, 1);
    if (!RankEquals(offset_type, 1) ||
        !offset_type.getElementType().isInteger(64)) {
      return func.emitError() << "2nd output should be int64 tensor";
    }
  }
  if (func.getNumResults() > 2) {
    auto offset_type = GetResultType(func, 2);
    if (!RankEquals(offset_type, 1) ||
        !offset_type.getElementType().isInteger(64)) {
      return func.emitError() << "3rd output should be int64 tensor";
    }
  }

  return success();
}
// Rewrites `func`'s body into a single tfl.custom op implementing the
// whitespace tokenizer. The tokenizer takes no options, so an empty custom
// option buffer is attached.
LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api,
                                         FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  // Re-attach the implements annotation so downstream tooling still sees it.
  func.setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string empty_option_buffer;
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, empty_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
// Verifies the signature of a function annotated as "tftext:Ngrams".
// Inputs and outputs come in matched pairs: index 0 is the token values
// (string tensor); indices 1..n, when present, are row_splits tensors
// (1D int64) describing a ragged tensor whose flat values are index 0.
LogicalResult VerifyNgrams(FuncOp func) {
  // The inputs and outputs should be the same:
  // * A string tensor for tokens/ragged tensor values.
  // * Zero or more row_split tensors.
  constexpr int kValues = 0;
  constexpr int kRowSplits = 1;

  if (func.getType().getInputs().size() !=
      func.getType().getResults().size()) {
    return func.emitError() << "Mismatched number of inputs and outputs.";
  }

  // Number of row_splits tensors = total inputs minus the values tensor.
  int row_splits = func.getType().getInputs().size() - kRowSplits;
  if (row_splits == 0) {
    // Dense case: only a values tensor. Any rank is accepted as long as the
    // element type is string and input/output ranks agree when both known.
    auto input_values = GetInputType(func, kValues);
    if (!input_values || !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!output_values || !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a string tensor";
    }
    if (input_values.hasRank() && output_values.hasRank() &&
        input_values.getRank() != output_values.getRank()) {
      return func.emitError() << "Input " << kValues << " and output "
                              << kValues << " should have the same rank";
    }
  } else {
    // Ragged case: values must be 1D strings, and each row_splits position
    // must be a 1D int64 tensor on both the input and output side.
    auto input_values = GetInputType(func, kValues);
    if (!RankEquals(input_values, 1) ||
        !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a 1D string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!RankEquals(output_values, 1) ||
        !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a 1D string tensor";
    }
    for (int i = 0; i < row_splits; ++i) {
      const int row_index = i + kRowSplits;
      auto input_row_splits = GetInputType(func, row_index);
      if (!RankEquals(input_row_splits, 1) ||
          !input_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Input " << row_index << " should be a 1D int64 tensor";
      }
      auto output_row_splits = GetResultType(func, row_index);
      if (!RankEquals(output_row_splits, 1) ||
          !output_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Output " << row_index << " should be a 1D int64 tensor";
      }
    }
  }

  return success();
}
// Serializes the ngram attributes (width, string_separator, axis,
// reduction_type) from `attrs` into a FlexBuffer map and stores the raw
// bytes in `custom_option_buffer`, which becomes the tfl.custom op's option
// buffer. Emits an error on `func` and fails if any attribute is missing or
// of the wrong kind.
LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs,
                                       std::string& custom_option_buffer) {
  flexbuffers::Builder fbb;
  const size_t map_start = fbb.StartMap();

  auto width_attr = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
  if (!width_attr) {
    return func.emitError() << "'width' attribute is not set or not an integer";
  }
  fbb.Int("width", width_attr.getInt());

  auto separator_attr =
      attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
  if (!separator_attr) {
    return func.emitError()
           << "'string_separator' attribute is not set or not a string";
  }
  // Copy into a std::string: StringAttr contents are not guaranteed to be
  // NUL terminated, but flexbuffers expects NUL-terminated strings.
  const StringRef separator = separator_attr.getValue();
  fbb.String("string_separator",
             std::string(separator.data(), separator.size()));

  auto axis_attr = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
  if (!axis_attr) {
    return func.emitError() << "'axis' attribute is not set or not an integer";
  }
  fbb.Int("axis", axis_attr.getInt());

  auto reduction_attr =
      attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
  if (!reduction_attr) {
    return func.emitError()
           << "'reduction_type' attribute is not set or not a string";
  }
  // Same NUL-termination caveat as above.
  const StringRef reduction = reduction_attr.getValue();
  fbb.String("reduction_type", std::string(reduction.data(), reduction.size()));

  fbb.EndMap(map_start);
  fbb.Finish();
  const auto& bytes = fbb.GetBuffer();
  custom_option_buffer.assign(bytes.begin(), bytes.end());
  return success();
}
// Rewrites `func`'s body into a single tfl.custom op implementing ngrams.
// The ngram attributes are serialized into the op's custom option buffer;
// fails (after the body has been cleared) if that serialization fails.
LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  // Re-attach the implements annotation so downstream tooling still sees it.
  func.setAttr(kTFImplements, attr);
  OpBuilder rewriter(func.getBody());
  std::string options;
  if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(), options))) {
    return failure();
  }
  auto custom_op = rewriter.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&rewriter, options));
  rewriter.create<ReturnOp>(func.getLoc(), custom_op.getResults());
  return success();
}
} // namespace
// Entry point: dispatches a tf._implements-annotated function to the
// matching TF.Text fusing routine. Returns failure when the API name is
// unknown or the function's signature fails verification, in which case the
// function is left unmodified.
LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
                               FuncAttr attr) {
  if (api.str() == kWhitespaceTokenizer) {
    if (succeeded(VerifyWhitespaceTokenizer(func))) {
      return ConvertWhitespaceTokenizer(func, api, attr);
    }
  } else if (api.str() == kNgrams) {
    if (succeeded(VerifyNgrams(func))) {
      return ConvertNgrams(func, api, attr);
    }
  }
  return failure();
}