Add support for target nodes of TensorFlow Graphs in TF MLIR importer.
It is possible to have target nodes/control rets for nodes that must execute but may not have data output(s). For example, side effecting ops, like resource writes, may rely on being set as a target node to be executed. This also unifies control rets support added for function graphs. PiperOrigin-RevId: 302683770 Change-Id: I33a66ae0ce08ac2230b575caae2cc752636d0d8b
This commit is contained in:
parent
e1afcc5feb
commit
463bec0d92
@ -36,8 +36,11 @@ versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
|
||||
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: return %0 : tensor<*xi32>
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME inputs = "input0,input1"
|
||||
# CHECK-SAME: outputs = "output"
|
||||
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
|
||||
# CHECK-NEXT: }
|
||||
|
@ -443,15 +443,18 @@ node {
|
||||
}
|
||||
}
|
||||
|
||||
# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
||||
# MLIR: attributes {tf.entry_function = {inputs = "input", outputs = "output"}
|
||||
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
|
||||
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
|
||||
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
|
||||
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
|
||||
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
|
||||
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
||||
# MLIR: }
|
||||
# MLIR-LABEL: func @main
|
||||
# MLIR-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
||||
# MLIR-SAME: control_outputs = ""
|
||||
# MLIR-SAME: inputs = "input"
|
||||
# MLIR-SAME: outputs = "output"
|
||||
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
|
||||
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
|
||||
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
|
||||
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
|
||||
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
|
||||
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
||||
# MLIR: }
|
||||
|
||||
# CHECK-LABEL: {
|
||||
# CHECK: version: 3,
|
||||
|
@ -47,7 +47,10 @@ versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<4xi32>) -> tensor<4xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
|
||||
# CHECK-NEXT: return %arg0 : tensor<4xi32>
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>) -> tensor<4xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "input"
|
||||
# CHECK-SAME: outputs = "output"
|
||||
# CHECK-NEXT: return %[[ARG_0]] : tensor<4xi32>
|
||||
# CHECK-NEXT: }
|
||||
|
@ -137,9 +137,12 @@ versions {
|
||||
producer: 198
|
||||
}
|
||||
|
||||
# CHECK: func @main([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "unranked", outputs = "unranked,static,static_10"}} {
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
|
||||
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
|
||||
# CHECK: }
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: ([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "unranked"
|
||||
# CHECK-SAME: outputs = "unranked,static,static_10"
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
|
||||
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
|
||||
# CHECK: }
|
||||
|
@ -7783,37 +7783,40 @@ node {
|
||||
library {
|
||||
}
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<1x3x3xf32>) -> tensor<1x3xf32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "INPUT", outputs = "OUTPUT"}} {
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_23:%.*]] = constant unit
|
||||
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32>
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x3x3xf32>) -> tensor<1x3xf32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "INPUT"
|
||||
# CHECK-SAME: outputs = "OUTPUT"
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_23:%.*]] = constant unit
|
||||
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32>
|
||||
|
@ -38,17 +38,26 @@ versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
|
||||
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
|
||||
# CHECK: fetch %[[add]]
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "input0,input1"
|
||||
# CHECK-SAME: outputs = "Add"
|
||||
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# CHECK: fetch %[[add]]
|
||||
|
||||
# SOME: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
||||
# SOME: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
|
||||
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
|
||||
# SOME: fetch %[[add]]
|
||||
# SOME-LABEL: func @main
|
||||
# SOME-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
|
||||
# SOME-SAME: control_outputs = ""
|
||||
# SOME-SAME: inputs = "input0,input1"
|
||||
# SOME-SAME: outputs = "Add"
|
||||
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# SOME: fetch %[[add]]
|
||||
|
||||
# NONE: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
||||
# NONE: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
|
||||
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
|
||||
# NONE: fetch %[[add]]
|
||||
# NONE-LABEL: func @main
|
||||
# NONE-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
|
||||
# NONE-SAME: control_outputs = ""
|
||||
# NONE-SAME: inputs = "input0,input1"
|
||||
# NONE-SAME: outputs = "Add"
|
||||
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# NONE: fetch %[[add]]
|
||||
|
@ -20,8 +20,11 @@ versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "arg", outputs = "arg"}} {
|
||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||
# CHECK: tf_executor.fetch %[[ARG_0]]
|
||||
# CHECK: return %[[GRAPH]]
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "arg"
|
||||
# CHECK-SAME: outputs = "arg"
|
||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||
# CHECK: tf_executor.fetch %[[ARG_0]]
|
||||
# CHECK: return %[[GRAPH]]
|
||||
|
@ -14,8 +14,11 @@ versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "input"}} {
|
||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||
# CHECK: tf_executor.fetch %[[ARG_0]]
|
||||
# CHECK: return %[[GRAPH]]
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "input"
|
||||
# CHECK-SAME: outputs = "input"
|
||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||
# CHECK: tf_executor.fetch %[[ARG_0]]
|
||||
# CHECK: return %[[GRAPH]]
|
||||
|
@ -59,10 +59,13 @@ library {
|
||||
versions {
|
||||
}
|
||||
|
||||
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
|
||||
# CHECK: tf.entry_function = {inputs = "input", outputs = "output_node"}
|
||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
|
||||
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])
|
||||
# CHECK: tf_executor.fetch %[[OUTPUT]]
|
||||
# CHECK: return %[[GRAPH]]
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "input"
|
||||
# CHECK-SAME: outputs = "output_node"
|
||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
|
||||
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])
|
||||
# CHECK: tf_executor.fetch %[[OUTPUT]]
|
||||
# CHECK: return %[[GRAPH]]
|
||||
|
@ -2,12 +2,15 @@
|
||||
|
||||
# Verify that we match correctly the input / output when they are scalar.
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "out:1,out"}} {
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%{{[a-z0-9]+}}: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "input"
|
||||
# CHECK-SAME: outputs = "out:1,out"
|
||||
|
||||
# CHECK: tf.Relu
|
||||
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
|
||||
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
|
||||
# CHECK: tf.Relu
|
||||
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
|
||||
# CHECK: etch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -268,18 +268,24 @@ versions {
|
||||
# function. Rets that happen to coincide with a feed should have its value be
|
||||
# of the feed.
|
||||
#
|
||||
# CHECK: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
|
||||
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
||||
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "z:1,z:2"
|
||||
# CHECK-SAME: outputs = "z:2,z:1,a:0"
|
||||
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
||||
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
|
||||
|
||||
# Test when non zero index output tensors are feeds, remaining ops that are
|
||||
# unreachable are pruned if pruning is enabled.
|
||||
#
|
||||
# PRUNE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
|
||||
# PRUNE-LABEL: func @main
|
||||
# PRUNE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
# PRUNE-SAME: control_outputs = ""
|
||||
# PRUNE-SAME: inputs = "z:1,z:2"
|
||||
# PRUNE-SAME: outputs = "z:2,z:1,a:0"
|
||||
# PRUNE-NOT: "tf.Const"
|
||||
# PRUNE-NOT: "tf.VariableV2"
|
||||
# PRUNE-NOT: "tf.Assign"
|
||||
@ -292,9 +298,12 @@ versions {
|
||||
# Test when non zero index output tensors are feeds, remaining ops that are
|
||||
# unreachable are preserved if pruning is not enabled.
|
||||
#
|
||||
# PRESERVE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||
# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
|
||||
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
||||
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]
|
||||
# PRESERVE-LABEL: func @main
|
||||
# PRESERVE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||
# PRESERVE-SAME: control_outputs = ""
|
||||
# PRESERVE-SAME: inputs = "z:1,z:2"
|
||||
# PRESERVE-SAME: outputs = "z:0,a:0"
|
||||
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
||||
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]
|
||||
|
@ -0,0 +1,203 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-control-output-arrays=AssignAdd -o - | FileCheck %s --dump-input=fail
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=AssignAdd -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail
|
||||
|
||||
# Generated in Python via
|
||||
# ```
|
||||
# import tensorflow as tf
|
||||
#
|
||||
# with tf.compat.v1.Graph().as_default() as g:
|
||||
# var = tf.Variable(2.0)
|
||||
# var_add = var.assign_add(3.0)
|
||||
# ```
|
||||
|
||||
node {
|
||||
name: "Variable/initial_value"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Variable"
|
||||
op: "VariableV2"
|
||||
attr {
|
||||
key: "container"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shared_name"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Variable/Assign"
|
||||
op: "Assign"
|
||||
input: "Variable"
|
||||
input: "Variable/initial_value"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Variable"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_locking"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "validate_shape"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Variable/read"
|
||||
op: "Identity"
|
||||
input: "Variable"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Variable"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "AssignAdd/value"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 3.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "AssignAdd"
|
||||
op: "AssignAdd"
|
||||
input: "Variable"
|
||||
input: "AssignAdd/value"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Variable"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_locking"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 309
|
||||
}
|
||||
|
||||
# Tests single target node with no pruning set. All nodes will remain in the
|
||||
# graph and the target node is added to the graph fetch as a control.
|
||||
#
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: control_outputs = "AssignAdd"
|
||||
# CHECK-SAME: inputs = ""
|
||||
# CHECK-SAME: outputs = ""
|
||||
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||
# CHECK: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
|
||||
|
||||
# Tests single target node with pruning set. Unreachable nodes from the fetch,
|
||||
# including side effecting nodes, will be removed. The target node is added to
|
||||
# the graph fetch as a control.
|
||||
#
|
||||
# PRUNE-LABEL: func @main
|
||||
# PRUNE-SAME: control_outputs = "AssignAdd"
|
||||
# PRUNE-SAME: inputs = ""
|
||||
# PRUNE-SAME: outputs = ""
|
||||
# PRUNE-NOT: "tf.Assign"
|
||||
# PRUNE-NOT: "tf.Identity"
|
||||
# PRUNE-DAG: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
|
||||
# PRUNE-DAG: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
|
||||
# PRUNE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"(%[[VAR]], %[[CONST]])
|
||||
# PRUNE-NEXT: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
|
||||
|
||||
# Tests multiple target nodes with pruning set. Unreachable nodes from the
|
||||
# fetch, including side effecting nodes, will be removed. The target nodes are
|
||||
# added to the graph fetch as controls.
|
||||
#
|
||||
# PRESERVE-LABEL: func @main
|
||||
# PRESERVE-SAME: control_outputs = "Variable/Assign,AssignAdd"
|
||||
# PRESERVE-SAME: inputs = ""
|
||||
# PRESERVE-SAME: outputs = ""
|
||||
# PRESERVE-NOT: "tf.Identity"
|
||||
# PRESERVE: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
|
||||
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_CTRL:.*]] = tf_executor.island wraps "tf.Assign"
|
||||
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||
# PRESERVE-NEXT: tf_executor.fetch %[[ASSIGN_CTRL]], %[[ASSIGN_ADD_CTRL]]
|
@ -570,6 +570,13 @@ StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
|
||||
edge->dst_input()));
|
||||
}
|
||||
|
||||
// TODO(lyandy): Preserve control dependencies properly by not forwarding
|
||||
// control dependencies to data outputs and not removing single output nodes.
|
||||
// When a data output is replaced as a feed, unless there is another non feed
|
||||
// data output or an explicit control output used by the same node, transitive
|
||||
// control dependencies are not to be executed. For single output nodes,
|
||||
// Placeholders can be converted to a NoOp if there are no uses, and
|
||||
// PlaceholderWithDefault can be converted to an Identity.
|
||||
for (const auto* edge : control_edges) {
|
||||
graph_->AddControlEdge(placeholder_node, edge->dst());
|
||||
graph_->RemoveControlEdge(edge);
|
||||
@ -616,6 +623,9 @@ Status ImporterBase::GetInputOutputNodes(
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& control_output : specs_.control_outputs)
|
||||
TF_RETURN_IF_ERROR(add_node(control_output));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -658,7 +668,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
DataType dtype = array_info.imported_dtype;
|
||||
// Uses the existing output type if it isn't specified by the user.
|
||||
if (dtype == DT_INVALID) {
|
||||
dtype = node->output_type(0);
|
||||
dtype = node->output_type(index);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -1787,10 +1797,10 @@ class GraphDefImporter : public ImporterBase {
|
||||
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
|
||||
resource_arg_unique_ids);
|
||||
|
||||
// Finds the function's control ret nodes based on supplied node names in
|
||||
// `control_outputs`. If `control_outputs` are not unique or a control ret
|
||||
// node is missing, an error will be returned.
|
||||
Status GetControlRetsFromFunctionGraph(
|
||||
// Finds the graph's target nodes/function's control ret nodes based on
|
||||
// supplied node names in `control_outputs`. If `control_outputs` are not
|
||||
// unique or a control ret node is missing, an error will be returned.
|
||||
Status GetControlRetsFromGraph(
|
||||
llvm::ArrayRef<std::string> control_outputs,
|
||||
absl::InlinedVector<Node*, 4>* control_ret_nodes);
|
||||
};
|
||||
@ -1827,8 +1837,8 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
importer.GetArgsRetsAndTypesFromFunctionGraph(
|
||||
context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
|
||||
|
||||
TF_RETURN_IF_ERROR(importer.GetControlRetsFromFunctionGraph(
|
||||
specs.control_outputs, &control_ret_nodes));
|
||||
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
|
||||
&control_ret_nodes));
|
||||
|
||||
if (!arg_nodes.empty() || !ret_nodes.empty() ||
|
||||
!control_ret_nodes.empty()) {
|
||||
@ -1858,10 +1868,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
|
||||
specs, context, &arg_nodes, &ret_nodes));
|
||||
|
||||
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
|
||||
&control_ret_nodes));
|
||||
|
||||
// TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
|
||||
// decoding in a centralized place.
|
||||
// Record the input and output mapping.
|
||||
if (!specs.inputs.empty() || !specs.outputs.empty()) {
|
||||
if (!specs.inputs.empty() || !specs.outputs.empty() ||
|
||||
!specs.control_outputs.empty()) {
|
||||
mlir::Builder b(context);
|
||||
std::string s;
|
||||
llvm::raw_string_ostream ss(s);
|
||||
@ -1873,9 +1887,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
s.clear();
|
||||
mlir::interleave(specs.outputs, ss, ",");
|
||||
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
||||
s.clear();
|
||||
mlir::interleave(specs.control_outputs, ss, ",");
|
||||
auto control_outputs =
|
||||
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
|
||||
|
||||
attrs.push_back(b.getNamedAttr("tf.entry_function",
|
||||
b.getDictionaryAttr({inputs, outputs})));
|
||||
attrs.push_back(b.getNamedAttr(
|
||||
"tf.entry_function",
|
||||
b.getDictionaryAttr({inputs, outputs, control_outputs})));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2064,7 +2083,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
|
||||
return builder.getFunctionType(arg_types, ret_types);
|
||||
}
|
||||
|
||||
Status GraphDefImporter::GetControlRetsFromFunctionGraph(
|
||||
Status GraphDefImporter::GetControlRetsFromGraph(
|
||||
llvm::ArrayRef<std::string> control_outputs,
|
||||
absl::InlinedVector<Node*, 4>* control_ret_nodes) {
|
||||
if (control_outputs.empty()) return Status::OK();
|
||||
|
@ -42,8 +42,7 @@ struct GraphImportConfig {
|
||||
InputArrays inputs;
|
||||
// name:index strings for the data outputs.
|
||||
std::vector<string> outputs;
|
||||
// name strings for the control outputs. This is currently only used when
|
||||
// `graph_as_function` is set.
|
||||
// name strings for the control outputs.
|
||||
std::vector<string> control_outputs;
|
||||
// Setting prune_unused_nodes to true, would prune unreachable nodes if
|
||||
// output_arrays is specified.
|
||||
|
@ -50,8 +50,7 @@ opt<std::string> output_arrays(
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> control_output_arrays(
|
||||
"tf-control-output-arrays",
|
||||
llvm::cl::desc("Control output node names, separated by ',', for main "
|
||||
"graphs that are functions"),
|
||||
llvm::cl::desc("Control output node names, separated by ','"),
|
||||
llvm::cl::init(""));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
|
Loading…
x
Reference in New Issue
Block a user