Fuse tensorflow_text.ngrams into a TFLite custom op
PiperOrigin-RevId: 323456482 Change-Id: Idfd446c371e8a4a4f82b6da730d02b0897d35a8a
This commit is contained in:
parent
c7e7f49228
commit
07e4db17ff
@ -270,6 +270,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"@flatbuffers",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s
|
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s
|
||||||
module {
|
|
||||||
|
|
||||||
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||||
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||||
@ -3194,4 +3193,246 @@ module {
|
|||||||
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
|
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
|
||||||
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
|
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
|
||||||
// CHECK: return %0 : tensor<?x!tf.string>
|
// CHECK: return %0 : tensor<?x!tf.string>
|
||||||
|
|
||||||
|
func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<?>], tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>} {
|
||||||
|
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
%1 = "tf.Const"() {value = dense<[0, -1]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
%2 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
%3 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
%4 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
%5 = "tf.StridedSlice"(%arg0, %3, %1, %4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x!tf.string>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x!tf.string>
|
||||||
|
%6 = "tf.StridedSlice"(%arg0, %2, %3, %4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 1 : i64, end_mask = 2 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x!tf.string>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x!tf.string>
|
||||||
|
%7 = "tf.Pack"(%5, %6) {axis = -1 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?x!tf.string>) -> tensor<?x2x!tf.string>
|
||||||
|
%8 = "tf.ReduceJoin"(%7, %0) {device = "", keep_dims = false, separator = " "} : (tensor<?x2x!tf.string>, tensor<i32>) -> tensor<?x!tf.string>
|
||||||
|
%9 = "tf.Identity"(%8) {device = ""} : (tensor<?x!tf.string>) -> tensor<?x!tf.string>
|
||||||
|
return %9 : tensor<?x!tf.string>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK: func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>]} {
|
||||||
|
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F72000120006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E383F040104FF152D0204141404082401"> : tensor<78xi8>} : (tensor<?x!tf.string>) -> tensor<?x!tf.string>
|
||||||
|
// CHECK: return %0 : tensor<?x!tf.string>
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
|
||||||
|
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
%3 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
|
||||||
|
%4 = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||||
|
%5 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||||
|
%6 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%7 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%8 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%9 = "tf.StridedSlice"(%arg1, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||||
|
%10 = "tf.Equal"(%9, %4) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
%11 = "tf.All"(%10, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
|
||||||
|
%12 = "tf.StridedSlice"(%arg1, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
|
||||||
|
%13 = "tf.StridedSlice"(%arg1, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
|
||||||
|
%14 = "tf.Sub"(%12, %13) {device = ""} : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
|
||||||
|
%15 = "tf.LessEqual"(%4, %14) {device = ""} : (tensor<i64>, tensor<2xi64>) -> tensor<2xi1>
|
||||||
|
%16 = "tf.All"(%15, %7) {device = "", keep_dims = false} : (tensor<2xi1>, tensor<1xi32>) -> tensor<i1>
|
||||||
|
%17 = "tf.StridedSlice"(%arg2, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||||
|
%18 = "tf.Equal"(%17, %4) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
%19 = "tf.All"(%18, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
|
||||||
|
%20 = "tf.IfRegion"(%19) ( {
|
||||||
|
%72 = "std.call"(%19, %17, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%19, %17, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%21 = "tf.Identity"(%20) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%22 = "tf.StridedSlice"(%arg2, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
|
||||||
|
%23 = "tf.StridedSlice"(%arg2, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
|
||||||
|
%24 = "tf.Sub"(%22, %23) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%25 = "tf.LessEqual"(%4, %24) {device = ""} : (tensor<i64>, tensor<?xi64>) -> tensor<?xi1>
|
||||||
|
%26 = "tf.All"(%25, %7) {device = "", keep_dims = false} : (tensor<?xi1>, tensor<1xi32>) -> tensor<i1>
|
||||||
|
%27 = "tf.IfRegion"(%26) ( {
|
||||||
|
%72 = "std.call"(%26, %24) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130} : (tensor<i1>, tensor<?xi64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%26, %24) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140} : (tensor<i1>, tensor<?xi64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%28 = "tf.Identity"(%27) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%29 = "tf.Identity"(%arg2) {_class = ["loc:@args_1"], device = ""} : (tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%30 = "tf.StridedSlice"(%29, %6, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||||
|
%31 = "tf.Shape"(%arg0) {device = ""} : (tensor<?x!tf.string>) -> tensor<1xi64>
|
||||||
|
%32 = "tf.StridedSlice"(%31, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||||
|
%33 = "tf.Equal"(%30, %32) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
%34 = "tf.All"(%33, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
|
||||||
|
%35 = "tf.IfRegion"(%34) ( {
|
||||||
|
%72 = "std.call"(%34, %30, %32) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%34, %30, %32) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%36 = "tf.Identity"(%35) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%37 = "tf.Identity"(%29) {_class = ["loc:@args_1"], device = ""} : (tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%38 = "tf.StridedSlice"(%37, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
|
||||||
|
%39 = "tf.StridedSlice"(%37, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<?xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi64>
|
||||||
|
%40 = "tf.Minimum"(%38, %39) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%41 = "tf.AddV2"(%39, %1) {device = ""} : (tensor<?xi64>, tensor<i64>) -> tensor<?xi64>
|
||||||
|
%42 = "tf.Maximum"(%41, %38) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%43:2 = "tf.RaggedRange"(%40, %42, %3) {T = i64, Tsplits = i64, device = ""} : (tensor<?xi64>, tensor<?xi64>, tensor<i64>) -> (tensor<?xi64>, tensor<?xi64>)
|
||||||
|
%44 = "tf.GatherV2"(%arg0, %43#1, %2) {batch_dims = 0 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?xi64>, tensor<i32>) -> tensor<?x!tf.string>
|
||||||
|
%45 = "tf.AddV2"(%38, %3) {device = ""} : (tensor<?xi64>, tensor<i64>) -> tensor<?xi64>
|
||||||
|
%46 = "tf.Minimum"(%45, %39) {device = ""} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%47:2 = "tf.RaggedRange"(%46, %39, %3) {T = i64, Tsplits = i64, device = ""} : (tensor<?xi64>, tensor<?xi64>, tensor<i64>) -> (tensor<?xi64>, tensor<?xi64>)
|
||||||
|
%48 = "tf.Equal"(%43#0, %47#0) {device = "", incompatible_shape_error = true} : (tensor<?xi64>, tensor<?xi64>) -> tensor<?xi1>
|
||||||
|
%49 = "tf.All"(%48, %7) {device = "", keep_dims = false} : (tensor<?xi1>, tensor<1xi32>) -> tensor<i1>
|
||||||
|
%50 = "tf.GatherV2"(%arg0, %47#1, %2) {batch_dims = 0 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?xi64>, tensor<i32>) -> tensor<?x!tf.string>
|
||||||
|
%51 = "tf.Shape"(%37) {device = ""} : (tensor<?xi64>) -> tensor<1xi64>
|
||||||
|
%52 = "tf.StridedSlice"(%51, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||||
|
%53 = "tf.Sub"(%52, %3) {device = ""} : (tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||||
|
%54 = "tf.IfRegion"(%11) ( {
|
||||||
|
%72 = "std.call"(%11, %9, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%11, %9, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%55 = "tf.Identity"(%54) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%56 = "tf.IfRegion"(%16) ( {
|
||||||
|
%72 = "std.call"(%16, %14) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260} : (tensor<i1>, tensor<2xi64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%16, %14) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270} : (tensor<i1>, tensor<2xi64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%57 = "tf.Identity"(%56) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%58 = "tf.Identity"(%arg1) {_class = ["loc:@args_0"], device = ""} : (tensor<3xi64>) -> tensor<3xi64>
|
||||||
|
%59 = "tf.StridedSlice"(%58, %6, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||||
|
%60 = "tf.Equal"(%59, %53) {device = "", incompatible_shape_error = true} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
%61 = "tf.All"(%60, %5) {device = "", keep_dims = false} : (tensor<i1>, tensor<0xi32>) -> tensor<i1>
|
||||||
|
%62 = "tf.IfRegion"(%61) ( {
|
||||||
|
%72 = "std.call"(%61, %59, %53) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%61, %59, %53) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660} : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%63 = "tf.IfRegion"(%49) ( {
|
||||||
|
%72 = "std.call"(%49, %43#0, %47#0) {callee = @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330} : (tensor<i1>, tensor<?xi64>, tensor<?xi64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
%72 = "std.call"(%49, %43#0, %47#0) {callee = @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340} : (tensor<i1>, tensor<?xi64>, tensor<?xi64>) -> tensor<i1>
|
||||||
|
"tf.Yield"(%72) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%64 = "tf.Identity"(%43#0) {device = ""} : (tensor<?xi64>) -> tensor<?xi64>
|
||||||
|
%65 = "tf.Identity"(%63) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%66 = "tf.Pack"(%44, %50) {axis = 1 : i64, device = ""} : (tensor<?x!tf.string>, tensor<?x!tf.string>) -> tensor<?x2x!tf.string>
|
||||||
|
%67 = "tf.ReduceJoin"(%66, %0) {device = "", keep_dims = false, separator = ""} : (tensor<?x2x!tf.string>, tensor<i32>) -> tensor<?x!tf.string>
|
||||||
|
%68 = "tf.Identity"(%67) {device = ""} : (tensor<?x!tf.string>) -> tensor<?x!tf.string>
|
||||||
|
%69 = "tf.Identity"(%62) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%70 = "tf.Identity"(%58) {_class = ["loc:@args_0"], device = ""} : (tensor<3xi64>) -> tensor<3xi64>
|
||||||
|
%71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64>
|
||||||
|
return %68, %71, %64 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
|
||||||
|
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %5 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<?xi64>) -> ()
|
||||||
|
%3 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %4 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
|
||||||
|
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %5 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
|
||||||
|
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %5 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<2xi64>) -> ()
|
||||||
|
%3 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %4 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<i64>, tensor<!tf.string>, tensor<i64>) -> ()
|
||||||
|
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %5 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
|
||||||
|
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %1 : tensor<i1>
|
||||||
|
}
|
||||||
|
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
%3 = "tf.Const"() {value = dense<"y (NGrams/SlidingWindow/RaggedGetItem_1/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||||
|
"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>, tensor<?xi64>, tensor<!tf.string>, tensor<?xi64>) -> ()
|
||||||
|
%4 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||||
|
return %5 : tensor<i1>
|
||||||
|
}
|
||||||
|
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||||
|
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
|
||||||
|
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
||||||
|
|
||||||
|
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/None.h"
|
#include "llvm/ADT/None.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
@ -28,6 +29,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||||
#include "mlir/IR/Location.h" // from @llvm-project
|
#include "mlir/IR/Location.h" // from @llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
@ -43,32 +45,35 @@ namespace TFL {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kNgrams[] = "tftext:Ngrams";
|
||||||
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
|
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
|
||||||
constexpr char kTFImplements[] = "tf._implements";
|
constexpr char kTFImplements[] = "tf._implements";
|
||||||
|
|
||||||
using mlir::TF::FuncAttr;
|
using mlir::TF::FuncAttr;
|
||||||
|
using mlir::TF::StringType;
|
||||||
|
|
||||||
inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) {
|
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
|
||||||
std::string content = "";
|
const std::string& content) {
|
||||||
ShapedType type = RankedTensorType::get(
|
ShapedType type = RankedTensorType::get(
|
||||||
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
||||||
return OpaqueElementsAttr::get(
|
return OpaqueElementsAttr::get(
|
||||||
builder->getContext()->getRegisteredDialect("tfl"), type, content);
|
builder->getContext()->getRegisteredDialect("tfl"), type,
|
||||||
|
StringRef(content.data(), content.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
|
inline TensorType GetInputType(FuncOp func, int idx) {
|
||||||
return func.getType()
|
return func.getType().getInput(idx).dyn_cast_or_null<TensorType>();
|
||||||
.getInput(idx)
|
|
||||||
.dyn_cast_or_null<mlir::RankedTensorType>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
|
inline TensorType GetResultType(FuncOp func, int idx) {
|
||||||
return func.getType()
|
return func.getType().getResult(idx).dyn_cast_or_null<TensorType>();
|
||||||
.getResult(idx)
|
|
||||||
.dyn_cast_or_null<mlir::RankedTensorType>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
inline bool RankEquals(const TensorType& type, int rank) {
|
||||||
|
return type && type.hasRank() && type.getRank() == rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult VerifyWhitespaceTokenizer(FuncOp func) {
|
||||||
// In the case of input tensor with 0 rank.
|
// In the case of input tensor with 0 rank.
|
||||||
// Whitespace tokenizer generates 1 output:
|
// Whitespace tokenizer generates 1 output:
|
||||||
// * String tensor for tokens.
|
// * String tensor for tokens.
|
||||||
@ -83,8 +88,8 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
|||||||
// * 1st output is the value of ragged tensor;
|
// * 1st output is the value of ragged tensor;
|
||||||
// * 2nd output is the inner offset;
|
// * 2nd output is the inner offset;
|
||||||
// * 3rd output is the outer offset.
|
// * 3rd output is the outer offset.
|
||||||
auto input_type = getInputType(func, 0);
|
auto input_type = GetInputType(func, 0);
|
||||||
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
|
if (!input_type || !input_type.getElementType().isa<StringType>() ||
|
||||||
!input_type.hasRank()) {
|
!input_type.hasRank()) {
|
||||||
return func.emitError() << "Input should be a string tensor";
|
return func.emitError() << "Input should be a string tensor";
|
||||||
}
|
}
|
||||||
@ -100,21 +105,21 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
|||||||
<< "output(s) when input has rank " << input_type.getRank();
|
<< "output(s) when input has rank " << input_type.getRank();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto value_type = getResultType(func, 0);
|
auto value_type = GetResultType(func, 0);
|
||||||
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
|
if (!RankEquals(value_type, 1) ||
|
||||||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
|
!value_type.getElementType().isa<StringType>()) {
|
||||||
return func.emitError() << "1st output should be string tensor";
|
return func.emitError() << "1st output should be string tensor";
|
||||||
}
|
}
|
||||||
if (func.getNumResults() > 1) {
|
if (func.getNumResults() > 1) {
|
||||||
auto offset_type = getResultType(func, 1);
|
auto offset_type = GetResultType(func, 1);
|
||||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
if (!RankEquals(offset_type, 1) ||
|
||||||
!offset_type.getElementType().isInteger(64)) {
|
!offset_type.getElementType().isInteger(64)) {
|
||||||
return func.emitError() << "2nd output should be int64 tensor";
|
return func.emitError() << "2nd output should be int64 tensor";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (func.getNumResults() > 2) {
|
if (func.getNumResults() > 2) {
|
||||||
auto offset_type = getResultType(func, 2);
|
auto offset_type = GetResultType(func, 2);
|
||||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
if (!RankEquals(offset_type, 1) ||
|
||||||
!offset_type.getElementType().isInteger(64)) {
|
!offset_type.getElementType().isInteger(64)) {
|
||||||
return func.emitError() << "3rd output should be int64 tensor";
|
return func.emitError() << "3rd output should be int64 tensor";
|
||||||
}
|
}
|
||||||
@ -123,28 +128,159 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func, llvm::StringRef api,
|
LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api,
|
||||||
FuncAttr attr) {
|
FuncAttr attr) {
|
||||||
func.eraseBody();
|
func.eraseBody();
|
||||||
func.addEntryBlock();
|
func.addEntryBlock();
|
||||||
func.setAttr(kTFImplements, attr);
|
func.setAttr(kTFImplements, attr);
|
||||||
Value text = func.getArgument(0);
|
|
||||||
OpBuilder builder(func.getBody());
|
OpBuilder builder(func.getBody());
|
||||||
|
std::string empty_option_buffer;
|
||||||
auto op = builder.create<mlir::TFL::CustomOp>(
|
auto op = builder.create<CustomOp>(
|
||||||
func.getLoc(), func.getType().getResults(), ValueRange(text), api,
|
func.getLoc(), func.getType().getResults(), func.getArguments(), api,
|
||||||
emptyCustomOption(&builder));
|
CustomOption(&builder, empty_option_buffer));
|
||||||
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
|
builder.create<ReturnOp>(func.getLoc(), op.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult VerifyNgrams(FuncOp func) {
|
||||||
|
// The inputs and outputs should be the same:
|
||||||
|
// * A string tensor for tokens/ragged tensor values.
|
||||||
|
// * Zero or more row_split tensors.
|
||||||
|
constexpr int kValues = 0;
|
||||||
|
constexpr int kRowSplits = 1;
|
||||||
|
|
||||||
|
if (func.getType().getInputs().size() != func.getType().getResults().size()) {
|
||||||
|
return func.emitError() << "Mismatched number of inputs and outputs.";
|
||||||
|
}
|
||||||
|
|
||||||
|
int row_splits = func.getType().getInputs().size() - kRowSplits;
|
||||||
|
if (row_splits == 0) {
|
||||||
|
auto input_values = GetInputType(func, kValues);
|
||||||
|
if (!input_values || !input_values.getElementType().isa<StringType>()) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "Input " << kValues << " should be a string tensor";
|
||||||
|
}
|
||||||
|
auto output_values = GetResultType(func, kValues);
|
||||||
|
if (!output_values || !output_values.getElementType().isa<StringType>()) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "Output " << kValues << " should be a string tensor";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_values.hasRank() && output_values.hasRank() &&
|
||||||
|
input_values.getRank() != output_values.getRank()) {
|
||||||
|
return func.emitError() << "Input " << kValues << " and output "
|
||||||
|
<< kValues << " should have the same rank";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto input_values = GetInputType(func, kValues);
|
||||||
|
if (!RankEquals(input_values, 1) ||
|
||||||
|
!input_values.getElementType().isa<StringType>()) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "Input " << kValues << " should be a 1D string tensor";
|
||||||
|
}
|
||||||
|
auto output_values = GetResultType(func, kValues);
|
||||||
|
if (!RankEquals(output_values, 1) ||
|
||||||
|
!output_values.getElementType().isa<StringType>()) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "Output " << kValues << " should be a 1D string tensor";
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < row_splits; ++i) {
|
||||||
|
const int row_index = i + kRowSplits;
|
||||||
|
auto input_row_splits = GetInputType(func, row_index);
|
||||||
|
if (!RankEquals(input_row_splits, 1) ||
|
||||||
|
!input_row_splits.getElementType().isInteger(64)) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "Input " << row_index << " should be a 1D int64 tensor";
|
||||||
|
}
|
||||||
|
auto output_row_splits = GetResultType(func, row_index);
|
||||||
|
if (!RankEquals(output_row_splits, 1) ||
|
||||||
|
!output_row_splits.getElementType().isInteger(64)) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "Output " << row_index << " should be a 1D int64 tensor";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs,
|
||||||
|
std::string& custom_option_buffer) {
|
||||||
|
flexbuffers::Builder fbb;
|
||||||
|
size_t start_map = fbb.StartMap();
|
||||||
|
|
||||||
|
auto width = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
|
||||||
|
if (!width) {
|
||||||
|
return func.emitError() << "'width' attribute is not set or not an integer";
|
||||||
|
}
|
||||||
|
fbb.Int("width", width.getInt());
|
||||||
|
|
||||||
|
auto string_separator =
|
||||||
|
attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
|
||||||
|
if (!string_separator) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "'string_separator' attribute is not set or not a string";
|
||||||
|
}
|
||||||
|
// StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
|
||||||
|
// strings expect NUL terminated strings.
|
||||||
|
std::string string_separator_str(string_separator.getValue().data(),
|
||||||
|
string_separator.getValue().size());
|
||||||
|
fbb.String("string_separator", string_separator_str);
|
||||||
|
|
||||||
|
auto axis = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
|
||||||
|
if (!axis) {
|
||||||
|
return func.emitError() << "'axis' attribute is not set or not an integer";
|
||||||
|
}
|
||||||
|
fbb.Int("axis", axis.getInt());
|
||||||
|
|
||||||
|
auto reduction_type =
|
||||||
|
attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
|
||||||
|
if (!reduction_type) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "'reduction_type' attribute is not set or not a string";
|
||||||
|
}
|
||||||
|
// StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
|
||||||
|
// strings expect NUL terminated strings.
|
||||||
|
std::string reduction_type_str(reduction_type.getValue().data(),
|
||||||
|
reduction_type.getValue().size());
|
||||||
|
fbb.String("reduction_type", reduction_type_str);
|
||||||
|
|
||||||
|
fbb.EndMap(start_map);
|
||||||
|
fbb.Finish();
|
||||||
|
custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
|
||||||
|
func.eraseBody();
|
||||||
|
func.addEntryBlock();
|
||||||
|
func.setAttr(kTFImplements, attr);
|
||||||
|
OpBuilder builder(func.getBody());
|
||||||
|
std::string custom_option_buffer;
|
||||||
|
if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(),
|
||||||
|
custom_option_buffer))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto op = builder.create<CustomOp>(
|
||||||
|
func.getLoc(), func.getType().getResults(), func.getArguments(), api,
|
||||||
|
CustomOption(&builder, custom_option_buffer));
|
||||||
|
builder.create<ReturnOp>(func.getLoc(), op.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api,
|
LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
|
||||||
FuncAttr attr) {
|
FuncAttr attr) {
|
||||||
if (api.str() == kWhitespaceTokenizer) {
|
if (api.str() == kWhitespaceTokenizer) {
|
||||||
if (succeeded(VerifyWhitespaceTokenizer(func))) {
|
if (succeeded(VerifyWhitespaceTokenizer(func))) {
|
||||||
return ConvertWhitespaceTokenizer(func, api, attr);
|
return ConvertWhitespaceTokenizer(func, api, attr);
|
||||||
}
|
}
|
||||||
|
} else if (api.str() == kNgrams) {
|
||||||
|
if (succeeded(VerifyNgrams(func))) {
|
||||||
|
return ConvertNgrams(func, api, attr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user