Add support for target nodes of TensorFlow Graphs in TF MLIR importer.
It is possible to have target nodes/control rets for nodes that must execute but may not have data output(s). For example, side effecting ops, like resource writes, may rely on being set as a target node to be executed. This also unifies control rets support added for function graphs. PiperOrigin-RevId: 302683770 Change-Id: I33a66ae0ce08ac2230b575caae2cc752636d0d8b
This commit is contained in:
parent
e1afcc5feb
commit
463bec0d92
@ -36,8 +36,11 @@ versions {
|
|||||||
producer: 27
|
producer: 27
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
|
||||||
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
# CHECK-SAME: control_outputs = ""
|
||||||
# CHECK-NEXT: return %0 : tensor<*xi32>
|
# CHECK-SAME inputs = "input0,input1"
|
||||||
|
# CHECK-SAME: outputs = "output"
|
||||||
|
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||||
|
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
|
||||||
# CHECK-NEXT: }
|
# CHECK-NEXT: }
|
||||||
|
@ -443,12 +443,15 @@ node {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# MLIR-LABEL: func @main(%arg0: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
# MLIR-LABEL: func @main
|
||||||
# MLIR: attributes {tf.entry_function = {inputs = "input", outputs = "output"}
|
# MLIR-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x1x1x256x!quant.uniform<i8:f32, 0.21632751372549019:27>>) -> tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
||||||
|
# MLIR-SAME: control_outputs = ""
|
||||||
|
# MLIR-SAME: inputs = "input"
|
||||||
|
# MLIR-SAME: outputs = "output"
|
||||||
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
|
# MLIR: %[[shape:.*]] = constant dense<[1, -1, 31]> : tensor<3xi32>
|
||||||
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
|
# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform<i32:f32:0
|
||||||
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
|
# MLIR: %[[weight:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x1x1x256x!quant.uniform<i8<-127:127>:f32:0, {0.12581039038230116,
|
||||||
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
|
# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
|
||||||
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
|
# MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform<i8:f32, 0.09363494573854933:22>>, tensor<3xi32>)
|
||||||
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
# MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform<i8:f32, 0.09363494573854933:22>>
|
||||||
# MLIR: }
|
# MLIR: }
|
||||||
|
@ -47,7 +47,10 @@ versions {
|
|||||||
producer: 27
|
producer: 27
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%arg0: tensor<4xi32>) -> tensor<4xi32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>) -> tensor<4xi32>
|
||||||
# CHECK-NEXT: return %arg0 : tensor<4xi32>
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "input"
|
||||||
|
# CHECK-SAME: outputs = "output"
|
||||||
|
# CHECK-NEXT: return %[[ARG_0]] : tensor<4xi32>
|
||||||
# CHECK-NEXT: }
|
# CHECK-NEXT: }
|
||||||
|
@ -137,8 +137,11 @@ versions {
|
|||||||
producer: 198
|
producer: 198
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "unranked", outputs = "unranked,static,static_10"}} {
|
# CHECK-SAME: ([[VAL_0:%.*]]: tensor<1x8x8x2xi32>) -> (tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>)
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "unranked"
|
||||||
|
# CHECK-SAME: outputs = "unranked,static,static_10"
|
||||||
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
|
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
|
||||||
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
|
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
|
||||||
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
|
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
|
||||||
|
@ -7783,8 +7783,11 @@ node {
|
|||||||
library {
|
library {
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%arg0: tensor<1x3x3xf32>) -> tensor<1x3xf32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "INPUT", outputs = "OUTPUT"}} {
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x3x3xf32>) -> tensor<1x3xf32>
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "INPUT"
|
||||||
|
# CHECK-SAME: outputs = "OUTPUT"
|
||||||
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
|
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
|
||||||
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
|
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
|
||||||
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
|
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
|
||||||
@ -7808,7 +7811,7 @@ library {
|
|||||||
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
|
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
|
||||||
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||||
# CHECK: [[VAL_23:%.*]] = constant unit
|
# CHECK: [[VAL_23:%.*]] = constant unit
|
||||||
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||||
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
|
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
|
||||||
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||||
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||||
|
@ -38,17 +38,26 @@ versions {
|
|||||||
producer: 27
|
producer: 27
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
|
||||||
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "input0,input1"
|
||||||
|
# CHECK-SAME: outputs = "Add"
|
||||||
|
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||||
# CHECK: fetch %[[add]]
|
# CHECK: fetch %[[add]]
|
||||||
|
|
||||||
# SOME: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
# SOME-LABEL: func @main
|
||||||
# SOME: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
|
# SOME-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
|
||||||
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
|
# SOME-SAME: control_outputs = ""
|
||||||
|
# SOME-SAME: inputs = "input0,input1"
|
||||||
|
# SOME-SAME: outputs = "Add"
|
||||||
|
# SOME: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||||
# SOME: fetch %[[add]]
|
# SOME: fetch %[[add]]
|
||||||
|
|
||||||
# NONE: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
# NONE-LABEL: func @main
|
||||||
# NONE: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
|
# NONE-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32>
|
||||||
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
|
# NONE-SAME: control_outputs = ""
|
||||||
|
# NONE-SAME: inputs = "input0,input1"
|
||||||
|
# NONE-SAME: outputs = "Add"
|
||||||
|
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||||
# NONE: fetch %[[add]]
|
# NONE: fetch %[[add]]
|
||||||
|
@ -20,8 +20,11 @@ versions {
|
|||||||
producer: 27
|
producer: 27
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "arg", outputs = "arg"}} {
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "arg"
|
||||||
|
# CHECK-SAME: outputs = "arg"
|
||||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||||
# CHECK: tf_executor.fetch %[[ARG_0]]
|
# CHECK: tf_executor.fetch %[[ARG_0]]
|
||||||
# CHECK: return %[[GRAPH]]
|
# CHECK: return %[[GRAPH]]
|
||||||
|
@ -14,8 +14,11 @@ versions {
|
|||||||
producer: 27
|
producer: 27
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "input"}} {
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<8xi32>) -> tensor<8xi32>
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "input"
|
||||||
|
# CHECK-SAME: outputs = "input"
|
||||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||||
# CHECK: tf_executor.fetch %[[ARG_0]]
|
# CHECK: tf_executor.fetch %[[ARG_0]]
|
||||||
# CHECK: return %[[GRAPH]]
|
# CHECK: return %[[GRAPH]]
|
||||||
|
@ -59,8 +59,11 @@ library {
|
|||||||
versions {
|
versions {
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: tf.entry_function = {inputs = "input", outputs = "output_node"}
|
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32>
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "input"
|
||||||
|
# CHECK-SAME: outputs = "output_node"
|
||||||
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph
|
||||||
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
|
# CHECK: %[[CONST:.*]], %[[CONST_control:.*]] = tf_executor.island wraps "tf.Const"()
|
||||||
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])
|
# CHECK: %[[OUTPUT:.*]], %[[OUTPUT_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[CONST]])
|
||||||
|
@ -2,12 +2,15 @@
|
|||||||
|
|
||||||
# Verify that we match correctly the input / output when they are scalar.
|
# Verify that we match correctly the input / output when they are scalar.
|
||||||
|
|
||||||
# CHECK: func @main(%arg0: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "input", outputs = "out:1,out"}} {
|
# CHECK-SAME: (%{{[a-z0-9]+}}: tensor<f32> {tf.device = "/device:CPU:0"}) -> (tensor<f32>, tensor<f32>)
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "input"
|
||||||
|
# CHECK-SAME: outputs = "out:1,out"
|
||||||
|
|
||||||
# CHECK: tf.Relu
|
# CHECK: tf.Relu
|
||||||
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
|
# CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN"
|
||||||
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
|
# CHECK: etch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
|
||||||
|
|
||||||
node {
|
node {
|
||||||
name: "input"
|
name: "input"
|
||||||
|
@ -268,8 +268,11 @@ versions {
|
|||||||
# function. Rets that happen to coincide with a feed should have its value be
|
# function. Rets that happen to coincide with a feed should have its value be
|
||||||
# of the feed.
|
# of the feed.
|
||||||
#
|
#
|
||||||
# CHECK: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
# CHECK-LABEL: func @main
|
||||||
# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
|
# CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
||||||
|
# CHECK-SAME: control_outputs = ""
|
||||||
|
# CHECK-SAME: inputs = "z:1,z:2"
|
||||||
|
# CHECK-SAME: outputs = "z:2,z:1,a:0"
|
||||||
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||||
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
||||||
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||||
@ -278,8 +281,11 @@ versions {
|
|||||||
# Test when non zero index output tensors are feeds, remaining ops that are
|
# Test when non zero index output tensors are feeds, remaining ops that are
|
||||||
# unreachable are pruned if pruning is enabled.
|
# unreachable are pruned if pruning is enabled.
|
||||||
#
|
#
|
||||||
# PRUNE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
# PRUNE-LABEL: func @main
|
||||||
# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
|
# PRUNE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
|
||||||
|
# PRUNE-SAME: control_outputs = ""
|
||||||
|
# PRUNE-SAME: inputs = "z:1,z:2"
|
||||||
|
# PRUNE-SAME: outputs = "z:2,z:1,a:0"
|
||||||
# PRUNE-NOT: "tf.Const"
|
# PRUNE-NOT: "tf.Const"
|
||||||
# PRUNE-NOT: "tf.VariableV2"
|
# PRUNE-NOT: "tf.VariableV2"
|
||||||
# PRUNE-NOT: "tf.Assign"
|
# PRUNE-NOT: "tf.Assign"
|
||||||
@ -292,8 +298,11 @@ versions {
|
|||||||
# Test when non zero index output tensors are feeds, remaining ops that are
|
# Test when non zero index output tensors are feeds, remaining ops that are
|
||||||
# unreachable are preserved if pruning is not enabled.
|
# unreachable are preserved if pruning is not enabled.
|
||||||
#
|
#
|
||||||
# PRESERVE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
# PRESERVE-LABEL: func @main
|
||||||
# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
|
# PRESERVE-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||||
|
# PRESERVE-SAME: control_outputs = ""
|
||||||
|
# PRESERVE-SAME: inputs = "z:1,z:2"
|
||||||
|
# PRESERVE-SAME: outputs = "z:0,a:0"
|
||||||
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||||
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
|
||||||
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||||
|
@ -0,0 +1,203 @@
|
|||||||
|
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-control-output-arrays=AssignAdd -o - | FileCheck %s --dump-input=fail
|
||||||
|
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=AssignAdd -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
|
||||||
|
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail
|
||||||
|
|
||||||
|
# Generated in Python via
|
||||||
|
# ```
|
||||||
|
# import tensorflow as tf
|
||||||
|
#
|
||||||
|
# with tf.compat.v1.Graph().as_default() as g:
|
||||||
|
# var = tf.Variable(2.0)
|
||||||
|
# var_add = var.assign_add(3.0)
|
||||||
|
# ```
|
||||||
|
|
||||||
|
node {
|
||||||
|
name: "Variable/initial_value"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
float_val: 2.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Variable"
|
||||||
|
op: "VariableV2"
|
||||||
|
attr {
|
||||||
|
key: "container"
|
||||||
|
value {
|
||||||
|
s: ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shared_name"
|
||||||
|
value {
|
||||||
|
s: ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Variable/Assign"
|
||||||
|
op: "Assign"
|
||||||
|
input: "Variable"
|
||||||
|
input: "Variable/initial_value"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "_class"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
s: "loc:@Variable"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "use_locking"
|
||||||
|
value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "validate_shape"
|
||||||
|
value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Variable/read"
|
||||||
|
op: "Identity"
|
||||||
|
input: "Variable"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "_class"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
s: "loc:@Variable"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "AssignAdd/value"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
float_val: 3.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "AssignAdd"
|
||||||
|
op: "AssignAdd"
|
||||||
|
input: "Variable"
|
||||||
|
input: "AssignAdd/value"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "_class"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
s: "loc:@Variable"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "use_locking"
|
||||||
|
value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions {
|
||||||
|
producer: 309
|
||||||
|
}
|
||||||
|
|
||||||
|
# Tests single target node with no pruning set. All nodes will remain in the
|
||||||
|
# graph and the target node is added to the graph fetch as a control.
|
||||||
|
#
|
||||||
|
# CHECK-LABEL: func @main
|
||||||
|
# CHECK-SAME: control_outputs = "AssignAdd"
|
||||||
|
# CHECK-SAME: inputs = ""
|
||||||
|
# CHECK-SAME: outputs = ""
|
||||||
|
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||||
|
# CHECK: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
|
||||||
|
|
||||||
|
# Tests single target node with pruning set. Unreachable nodes from the fetch,
|
||||||
|
# including side effecting nodes, will be removed. The target node is added to
|
||||||
|
# the graph fetch as a control.
|
||||||
|
#
|
||||||
|
# PRUNE-LABEL: func @main
|
||||||
|
# PRUNE-SAME: control_outputs = "AssignAdd"
|
||||||
|
# PRUNE-SAME: inputs = ""
|
||||||
|
# PRUNE-SAME: outputs = ""
|
||||||
|
# PRUNE-NOT: "tf.Assign"
|
||||||
|
# PRUNE-NOT: "tf.Identity"
|
||||||
|
# PRUNE-DAG: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
|
||||||
|
# PRUNE-DAG: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
|
||||||
|
# PRUNE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"(%[[VAR]], %[[CONST]])
|
||||||
|
# PRUNE-NEXT: tf_executor.fetch %[[ASSIGN_ADD_CTRL]]
|
||||||
|
|
||||||
|
# Tests multiple target nodes with pruning set. Unreachable nodes from the
|
||||||
|
# fetch, including side effecting nodes, will be removed. The target nodes are
|
||||||
|
# added to the graph fetch as controls.
|
||||||
|
#
|
||||||
|
# PRESERVE-LABEL: func @main
|
||||||
|
# PRESERVE-SAME: control_outputs = "Variable/Assign,AssignAdd"
|
||||||
|
# PRESERVE-SAME: inputs = ""
|
||||||
|
# PRESERVE-SAME: outputs = ""
|
||||||
|
# PRESERVE-NOT: "tf.Identity"
|
||||||
|
# PRESERVE: %[[VAR:.*]], %{{.*}} = tf_executor.island wraps "tf.VariableV2"
|
||||||
|
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_CTRL:.*]] = tf_executor.island wraps "tf.Assign"
|
||||||
|
# PRESERVE-DAG: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
|
||||||
|
# PRESERVE-NEXT: tf_executor.fetch %[[ASSIGN_CTRL]], %[[ASSIGN_ADD_CTRL]]
|
@ -570,6 +570,13 @@ StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
|
|||||||
edge->dst_input()));
|
edge->dst_input()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(lyandy): Preserve control dependencies properly by not forwarding
|
||||||
|
// control dependencies to data outputs and not removing single output nodes.
|
||||||
|
// When a data output is replaced as a feed, unless there is another non feed
|
||||||
|
// data output or an explicit control output used by the same node, transitive
|
||||||
|
// control dependencies are not to be executed. For single output nodes,
|
||||||
|
// Placeholders can be converted to a NoOp if there are no uses, and
|
||||||
|
// PlaceholderWithDefault can be converted to an Identity.
|
||||||
for (const auto* edge : control_edges) {
|
for (const auto* edge : control_edges) {
|
||||||
graph_->AddControlEdge(placeholder_node, edge->dst());
|
graph_->AddControlEdge(placeholder_node, edge->dst());
|
||||||
graph_->RemoveControlEdge(edge);
|
graph_->RemoveControlEdge(edge);
|
||||||
@ -616,6 +623,9 @@ Status ImporterBase::GetInputOutputNodes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const auto& control_output : specs_.control_outputs)
|
||||||
|
TF_RETURN_IF_ERROR(add_node(control_output));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -658,7 +668,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
|||||||
DataType dtype = array_info.imported_dtype;
|
DataType dtype = array_info.imported_dtype;
|
||||||
// Uses the existing output type if it isn't specified by the user.
|
// Uses the existing output type if it isn't specified by the user.
|
||||||
if (dtype == DT_INVALID) {
|
if (dtype == DT_INVALID) {
|
||||||
dtype = node->output_type(0);
|
dtype = node->output_type(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
@ -1787,10 +1797,10 @@ class GraphDefImporter : public ImporterBase {
|
|||||||
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
|
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
|
||||||
resource_arg_unique_ids);
|
resource_arg_unique_ids);
|
||||||
|
|
||||||
// Finds the function's control ret nodes based on supplied node names in
|
// Finds the graph's target nodes/function's control ret nodes based on
|
||||||
// `control_outputs`. If `control_outputs` are not unique or a control ret
|
// supplied node names in `control_outputs`. If `control_outputs` are not
|
||||||
// node is missing, an error will be returned.
|
// unique or a control ret node is missing, an error will be returned.
|
||||||
Status GetControlRetsFromFunctionGraph(
|
Status GetControlRetsFromGraph(
|
||||||
llvm::ArrayRef<std::string> control_outputs,
|
llvm::ArrayRef<std::string> control_outputs,
|
||||||
absl::InlinedVector<Node*, 4>* control_ret_nodes);
|
absl::InlinedVector<Node*, 4>* control_ret_nodes);
|
||||||
};
|
};
|
||||||
@ -1827,8 +1837,8 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
importer.GetArgsRetsAndTypesFromFunctionGraph(
|
importer.GetArgsRetsAndTypesFromFunctionGraph(
|
||||||
context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
|
context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(importer.GetControlRetsFromFunctionGraph(
|
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
|
||||||
specs.control_outputs, &control_ret_nodes));
|
&control_ret_nodes));
|
||||||
|
|
||||||
if (!arg_nodes.empty() || !ret_nodes.empty() ||
|
if (!arg_nodes.empty() || !ret_nodes.empty() ||
|
||||||
!control_ret_nodes.empty()) {
|
!control_ret_nodes.empty()) {
|
||||||
@ -1858,10 +1868,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
|
TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
|
||||||
specs, context, &arg_nodes, &ret_nodes));
|
specs, context, &arg_nodes, &ret_nodes));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
|
||||||
|
&control_ret_nodes));
|
||||||
|
|
||||||
// TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
|
// TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
|
||||||
// decoding in a centralized place.
|
// decoding in a centralized place.
|
||||||
// Record the input and output mapping.
|
// Record the input and output mapping.
|
||||||
if (!specs.inputs.empty() || !specs.outputs.empty()) {
|
if (!specs.inputs.empty() || !specs.outputs.empty() ||
|
||||||
|
!specs.control_outputs.empty()) {
|
||||||
mlir::Builder b(context);
|
mlir::Builder b(context);
|
||||||
std::string s;
|
std::string s;
|
||||||
llvm::raw_string_ostream ss(s);
|
llvm::raw_string_ostream ss(s);
|
||||||
@ -1873,9 +1887,14 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
s.clear();
|
s.clear();
|
||||||
mlir::interleave(specs.outputs, ss, ",");
|
mlir::interleave(specs.outputs, ss, ",");
|
||||||
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
||||||
|
s.clear();
|
||||||
|
mlir::interleave(specs.control_outputs, ss, ",");
|
||||||
|
auto control_outputs =
|
||||||
|
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
|
||||||
|
|
||||||
attrs.push_back(b.getNamedAttr("tf.entry_function",
|
attrs.push_back(b.getNamedAttr(
|
||||||
b.getDictionaryAttr({inputs, outputs})));
|
"tf.entry_function",
|
||||||
|
b.getDictionaryAttr({inputs, outputs, control_outputs})));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2064,7 +2083,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
|
|||||||
return builder.getFunctionType(arg_types, ret_types);
|
return builder.getFunctionType(arg_types, ret_types);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphDefImporter::GetControlRetsFromFunctionGraph(
|
Status GraphDefImporter::GetControlRetsFromGraph(
|
||||||
llvm::ArrayRef<std::string> control_outputs,
|
llvm::ArrayRef<std::string> control_outputs,
|
||||||
absl::InlinedVector<Node*, 4>* control_ret_nodes) {
|
absl::InlinedVector<Node*, 4>* control_ret_nodes) {
|
||||||
if (control_outputs.empty()) return Status::OK();
|
if (control_outputs.empty()) return Status::OK();
|
||||||
|
@ -42,8 +42,7 @@ struct GraphImportConfig {
|
|||||||
InputArrays inputs;
|
InputArrays inputs;
|
||||||
// name:index strings for the data outputs.
|
// name:index strings for the data outputs.
|
||||||
std::vector<string> outputs;
|
std::vector<string> outputs;
|
||||||
// name strings for the control outputs. This is currently only used when
|
// name strings for the control outputs.
|
||||||
// `graph_as_function` is set.
|
|
||||||
std::vector<string> control_outputs;
|
std::vector<string> control_outputs;
|
||||||
// Setting prune_unused_nodes to true, would prune unreachable nodes if
|
// Setting prune_unused_nodes to true, would prune unreachable nodes if
|
||||||
// output_arrays is specified.
|
// output_arrays is specified.
|
||||||
|
@ -50,8 +50,7 @@ opt<std::string> output_arrays(
|
|||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
opt<std::string> control_output_arrays(
|
opt<std::string> control_output_arrays(
|
||||||
"tf-control-output-arrays",
|
"tf-control-output-arrays",
|
||||||
llvm::cl::desc("Control output node names, separated by ',', for main "
|
llvm::cl::desc("Control output node names, separated by ','"),
|
||||||
"graphs that are functions"),
|
|
||||||
llvm::cl::init(""));
|
llvm::cl::init(""));
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
|
Loading…
x
Reference in New Issue
Block a user