Add support for target nodes of TensorFlow Graphs in TF MLIR importer.

Target nodes/control rets make it possible to mark nodes that must execute but may not have data output(s). For example, side-effecting ops, like resource writes, may rely on being set as a target node in order to be executed. This also unifies with the control rets support previously added for function graphs.
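
For illustration, a minimal sketch (following the generation snippet embedded in the new target-node test in this change) of a graph whose side-effecting node has no data-output consumers; the node names and the tf-mlir-translate flag in the comments come from that test:

```
import tensorflow as tf

# A graph whose only interesting effect is updating a variable. The
# "AssignAdd" node has a data output, but nothing consumes it, so a
# purely fetch-driven import with pruning would drop it along with its
# side effect.
with tf.compat.v1.Graph().as_default() as g:
  var = tf.Variable(2.0)         # emits "Variable", "Variable/Assign", ...
  var_add = var.assign_add(3.0)  # emits the "AssignAdd" node

# Importing with the node named as a target keeps it alive, e.g.:
#   tf-mlir-translate -graphdef-to-mlir graph.pbtxt \
#       -tf-control-output-arrays=AssignAdd
# "AssignAdd" is then listed in tf.entry_function's control_outputs and
# its control token is fetched by the tf_executor.graph terminator.
print(g.as_graph_def())
```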

PiperOrigin-RevId: 302683770
Change-Id: I33a66ae0ce08ac2230b575caae2cc752636d0d8b
Andy Ly 2020-03-24 09:39:51 -07:00 committed by TensorFlower Gardener
parent e1afcc5feb
commit 463bec0d92
15 changed files with 386 additions and 121 deletions


@@ -36,8 +36,11 @@ versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-NEXT: }
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
# CHECK-NEXT: }


@@ -443,15 +443,18 @@ node {
}
}
# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: attributes {tf.entry_function = {inputs = "input", outputs = "output"}
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: }
# MLIR-LABEL: func @main
# MLIR-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR-SAME: control_outputs = ""
# MLIR-SAME: inputs = "input"
# MLIR-SAME: outputs = "output"
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
# MLIR: }
# CHECK-LABEL: {
# CHECK: version: 3,


@@ -47,7 +47,10 @@ versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<4xi32>) -> tensor<4xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
# CHECK-NEXT: return %arg0 : tensor<4xi32>
# CHECK-NEXT: }
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>) -> tensor<4xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: return %[[ARG_0]] : tensor<4xi32>
# CHECK-NEXT: }


@@ -137,9 +137,12 @@ versions {
producer: 198
}
# CHECK: func @main([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
# CHECK: attributes {tf.entry_function = {inputs = "unranked", outputs = "unranked,static,static_10"}} {
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
# CHECK: }
# CHECK-LABEL: func @main
# CHECK-SAME: ([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "unranked"
# CHECK-SAME: outputs = "unranked,static,static_10"
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
# CHECK: }


@@ -7783,37 +7783,40 @@ node {
library {
}
# CHECK: func @main(%arg0: tensor<1x3x3xf32>) -> tensor<1x3xf32>
# CHECK: attributes {tf.entry_function = {inputs = "INPUT", outputs = "OUTPUT"}} {
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_23:%.*]] = constant unit
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32>
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x3x3xf32>) -> tensor<1x3xf32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "INPUT"
# CHECK-SAME: outputs = "OUTPUT"
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_23:%.*]] = constant unit
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32>


@@ -38,17 +38,26 @@ versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
# CHECK: fetch %[[add]]
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "Add"
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK: fetch %[[add]]
# SOME: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# SOME: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
# SOME: fetch %[[add]]
# SOME-LABEL: func @main
# SOME-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
# SOME-SAME: control_outputs = ""
# SOME-SAME: inputs = "input0,input1"
# SOME-SAME: outputs = "Add"
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# SOME: fetch %[[add]]
# NONE: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# NONE: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
# NONE: fetch %[[add]]
# NONE-LABEL: func @main
# NONE-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
# NONE-SAME: control_outputs = ""
# NONE-SAME: inputs = "input0,input1"
# NONE-SAME: outputs = "Add"
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# NONE: fetch %[[add]]


@@ -20,8 +20,11 @@ versions {
producer: 27
}
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK: attributes {tf.entry_function = {inputs = "arg", outputs = "arg"}} {
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: tf_executor.fetch %[[ARG_0]]
# CHECK: return %[[GRAPH]]
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "arg"
# CHECK-SAME: outputs = "arg"
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: tf_executor.fetch %[[ARG_0]]
# CHECK: return %[[GRAPH]]


@@ -14,8 +14,11 @@ versions {
producer: 27
}
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "input"}} {
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: tf_executor.fetch %[[ARG_0]]
# CHECK: return %[[GRAPH]]
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "input"
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: tf_executor.fetch %[[ARG_0]]
# CHECK: return %[[GRAPH]]


@@ -59,10 +59,13 @@ library {
versions {
}
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
# CHECK: tf.entry_function = {inputs = "input", outputs = "output_node"}
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])
# CHECK: tf_executor.fetch %[[OUTPUT]]
# CHECK: return %[[GRAPH]]
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "output_node"
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])
# CHECK: tf_executor.fetch %[[OUTPUT]]
# CHECK: return %[[GRAPH]]


@@ -2,12 +2,15 @@
# Verify that we match correctly the input / output when they are scalar.
# CHECK: func @main(%arg0: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "out:1,out"}} {
# CHECK-LABEL: func @main
# CHECK-SAME: (%{{[a-z0-9]+}}: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input"
# CHECK-SAME: outputs = "out:1,out"
# CHECK: tf.Relu
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
# CHECK: tf.Relu
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
node {
name: "input"


@@ -268,18 +268,24 @@ versions {
# function. Rets that happen to coincide with a feed should have their value be
# that of the feed.
#
# CHECK: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "z:1,z:2"
# CHECK-SAME: outputs = "z:2,z:1,a:0"
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
# Test when non zero index output tensors are feeds, remaining ops that are
# unreachable are pruned if pruning is enabled.
#
# PRUNE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# PRUNE-LABEL: func @main
# PRUNE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# PRUNE-SAME: control_outputs = ""
# PRUNE-SAME: inputs = "z:1,z:2"
# PRUNE-SAME: outputs = "z:2,z:1,a:0"
# PRUNE-NOT: "tf.Const"
# PRUNE-NOT: "tf.VariableV2"
# PRUNE-NOT: "tf.Assign"
@@ -292,9 +298,12 @@ versions {
# Test when non zero index output tensors are feeds, remaining ops that are
# unreachable are preserved if pruning is not enabled.
#
# PRESERVE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]
# PRESERVE-LABEL: func @main
# PRESERVE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
# PRESERVE-SAME: control_outputs = ""
# PRESERVE-SAME: inputs = "z:1,z:2"
# PRESERVE-SAME: outputs = "z:0,a:0"
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]


@@ -0,0 +1,203 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-control-output-arrays=AssignAdd -o - | FileCheck %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=AssignAdd -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail
# Generated in Python via
# ```
# import tensorflow as tf
#
# with tf.compat.v1.Graph().as_default() as g:
# var = tf.Variable(2.0)
# var_add = var.assign_add(3.0)
# ```
node {
name: "Variable/initial_value"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "Variable"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "Variable/Assign"
op: "Assign"
input: "Variable"
input: "Variable/initial_value"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Variable"
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "Variable/read"
op: "Identity"
input: "Variable"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Variable"
}
}
}
}
node {
name: "AssignAdd/value"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 3.0
}
}
}
}
node {
name: "AssignAdd"
op: "AssignAdd"
input: "Variable"
input: "AssignAdd/value"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Variable"
}
}
}
attr {
key: "use_locking"
value {
b: false
}
}
}
versions {
producer: 309
}
# Tests single target node with no pruning set. All nodes will remain in the
# graph and the target node is added to the graph fetch as a control.
#
# CHECK-LABEL: func @main
# CHECK-SAME: control_outputs = "AssignAdd"
# CHECK-SAME: inputs = ""
# CHECK-SAME: outputs = ""
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
# Tests single target node with pruning set. Nodes unreachable from the fetch,
# including side-effecting nodes, will be removed. The target node is added to
# the graph fetch as a control.
#
# PRUNE-LABEL: func @main
# PRUNE-SAME: control_outputs = "AssignAdd"
# PRUNE-SAME: inputs = ""
# PRUNE-SAME: outputs = ""
# PRUNE-NOT: "tf.Assign"
# PRUNE-NOT: "tf.Identity"
# PRUNE-DAG: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
# PRUNE-DAG: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
# PRUNE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"(%[[VAR]], %[[CONST]])
# PRUNE-NEXT: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
# Tests multiple target nodes with pruning set. Nodes unreachable from the
# fetch, including side-effecting nodes, will be removed. The target nodes are
# added to the graph fetch as controls.
#
# PRESERVE-LABEL: func @main
# PRESERVE-SAME: control_outputs = "Variable/Assign,AssignAdd"
# PRESERVE-SAME: inputs = ""
# PRESERVE-SAME: outputs = ""
# PRESERVE-NOT: "tf.Identity"
# PRESERVE: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_CTRL:.*]] = tf_executor.island wraps "tf.Assign"
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE-NEXT: tf_executor.fetch %[[ASSIGN_CTRL]], %[[ASSIGN_ADD_CTRL]]


@@ -570,6 +570,13 @@ StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
edge->dst_input()));
}
// TODO(lyandy): Preserve control dependencies properly by not forwarding
// control dependencies to data outputs and not removing single output nodes.
// When a data output is replaced as a feed, unless there is another non-feed
// data output or an explicit control output used by the same node, transitive
// control dependencies are not to be executed. For single output nodes,
// Placeholders can be converted to a NoOp if there are no uses, and
// PlaceholderWithDefault can be converted to an Identity.
for (const auto* edge : control_edges) {
graph_->AddControlEdge(placeholder_node, edge->dst());
graph_->RemoveControlEdge(edge);
@@ -616,6 +623,9 @@ Status ImporterBase::GetInputOutputNodes(
}
}
for (const auto& control_output : specs_.control_outputs)
TF_RETURN_IF_ERROR(add_node(control_output));
return Status::OK();
}
@@ -658,7 +668,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
DataType dtype = array_info.imported_dtype;
// Uses the existing output type if it isn't specified by the user.
if (dtype == DT_INVALID) {
dtype = node->output_type(0);
dtype = node->output_type(index);
}
TF_ASSIGN_OR_RETURN(
@@ -1787,10 +1797,10 @@ class GraphDefImporter : public ImporterBase {
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
resource_arg_unique_ids);
// Finds the function's control ret nodes based on supplied node names in
// `control_outputs`. If `control_outputs` are not unique or a control ret
// node is missing, an error will be returned.
Status GetControlRetsFromFunctionGraph(
// Finds the graph's target nodes/function's control ret nodes based on
// supplied node names in `control_outputs`. If `control_outputs` are not
// unique or a control ret node is missing, an error will be returned.
Status GetControlRetsFromGraph(
llvm::ArrayRef<std::string> control_outputs,
absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
@@ -1827,8 +1837,8 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
importer.GetArgsRetsAndTypesFromFunctionGraph(
context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
TF_RETURN_IF_ERROR(importer.GetControlRetsFromFunctionGraph(
specs.control_outputs, &control_ret_nodes));
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
&control_ret_nodes));
if (!arg_nodes.empty() || !ret_nodes.empty() ||
!control_ret_nodes.empty()) {
@@ -1858,10 +1868,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
specs, context, &arg_nodes, &ret_nodes));
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
&control_ret_nodes));
// TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
// decoding in a centralized place.
// Record the input and output mapping.
if (!specs.inputs.empty() || !specs.outputs.empty()) {
if (!specs.inputs.empty() || !specs.outputs.empty() ||
!specs.control_outputs.empty()) {
mlir::Builder b(context);
std::string s;
llvm::raw_string_ostream ss(s);
@@ -1873,9 +1887,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
s.clear();
mlir::interleave(specs.outputs, ss, ",");
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
s.clear();
mlir::interleave(specs.control_outputs, ss, ",");
auto control_outputs =
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
attrs.push_back(b.getNamedAttr("tf.entry_function",
b.getDictionaryAttr({inputs, outputs})));
attrs.push_back(b.getNamedAttr(
"tf.entry_function",
b.getDictionaryAttr({inputs, outputs, control_outputs})));
}
}
@@ -2064,7 +2083,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
return builder.getFunctionType(arg_types, ret_types);
}
Status GraphDefImporter::GetControlRetsFromFunctionGraph(
Status GraphDefImporter::GetControlRetsFromGraph(
llvm::ArrayRef<std::string> control_outputs,
absl::InlinedVector<Node*, 4>* control_ret_nodes) {
if (control_outputs.empty()) return Status::OK();


@@ -42,8 +42,7 @@ struct GraphImportConfig {
InputArrays inputs;
// name:index strings for the data outputs.
std::vector<string> outputs;
// name strings for the control outputs. This is currently only used when
// `graph_as_function` is set.
// name strings for the control outputs.
std::vector<string> control_outputs;
// Setting prune_unused_nodes to true, would prune unreachable nodes if
// output_arrays is specified.


@@ -50,8 +50,7 @@ opt<std::string> output_arrays(
// NOLINTNEXTLINE
opt<std::string> control_output_arrays(
"tf-control-output-arrays",
llvm::cl::desc("Control output node names, separated by ',', for main "
"graphs that are functions"),
llvm::cl::desc("Control output node names, separated by ','"),
llvm::cl::init(""));
// NOLINTNEXTLINE