Map TensorFlow StatelessIf and If op to a common If op in MLIR
TensorFlow StatelessIf and If op only differs in the is_stateful property and are identical otherwise. Introduced an additional attribute in the MLIR op definition to differentiate them and mapped to and from the common op while importing and export to MLIR, respectively. Thanks Mehdi for the suggestion! PiperOrigin-RevId: 259468359
This commit is contained in:
parent
ade316deef
commit
1de23834be
@ -160,7 +160,7 @@ func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
|||||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
|
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
|
%2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
|
||||||
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
return %3 : tensor<1xf32>
|
return %3 : tensor<1xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -103,7 +103,10 @@ else_branch: A function that takes 'inputs' and returns a list of
|
|||||||
|
|
||||||
SymbolRefAttr:$then_branch,
|
SymbolRefAttr:$then_branch,
|
||||||
SymbolRefAttr:$else_branch,
|
SymbolRefAttr:$else_branch,
|
||||||
DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes
|
DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes,
|
||||||
|
|
||||||
|
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
|
||||||
|
BoolAttr:$is_stateless
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
|
|||||||
@ -7,7 +7,7 @@ func @testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
|||||||
func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
|
func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
|
||||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
|
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
|
||||||
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
||||||
then_branch = @testIf1Then, else_branch = @testIf1Else
|
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||||
} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
||||||
@ -31,7 +31,7 @@ func @testIf3Else(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>
|
|||||||
func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) {
|
func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) {
|
||||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>):
|
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>):
|
||||||
%1:3 = "tf.If"(%arg0, %arg1) {
|
%1:3 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIf3Then, else_branch = @testIf3Else
|
then_branch = @testIf3Then, else_branch = @testIf3Else, is_stateless = false
|
||||||
} : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
|
} : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
|
||||||
|
|
||||||
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
||||||
@ -57,7 +57,7 @@ func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32
|
|||||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
|
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
|
||||||
|
|
||||||
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
||||||
then_branch = @testIf1Then, else_branch = @testIf1Else
|
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||||
} : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
|
} : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
|
||||||
|
|
||||||
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
||||||
@ -97,7 +97,7 @@ func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
|
|||||||
|
|
||||||
// expected-error @+1 {{only supports zero-D bool tensors now}}
|
// expected-error @+1 {{only supports zero-D bool tensors now}}
|
||||||
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
||||||
then_branch = @testIf1Then, else_branch = @testIf1Else
|
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||||
} : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
} : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
return %1 : tensor<*xf32>
|
return %1 : tensor<*xf32>
|
||||||
|
|||||||
@ -517,7 +517,7 @@ versions {
|
|||||||
# CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
# CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||||
# CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
# CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||||
# CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control)
|
# CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control)
|
||||||
# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, is_stateless = false, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||||
# CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
# CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||||
# CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
# CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||||
# CHECK-NEXT: %15:2 = "_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
# CHECK-NEXT: %15:2 = "_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||||
|
|||||||
@ -142,7 +142,7 @@ versions {
|
|||||||
#CHECK: func @main() {
|
#CHECK: func @main() {
|
||||||
#CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control)
|
#CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control)
|
||||||
#CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control)
|
#CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control)
|
||||||
#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, is_stateless = false, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||||
#CHECK-NEXT: return
|
#CHECK-NEXT: return
|
||||||
#CHECK-NEXT: }
|
#CHECK-NEXT: }
|
||||||
#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
|
#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
|
||||||
|
|||||||
@ -486,7 +486,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32>
|
|||||||
func @testValidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
func @testValidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIfThen, else_branch = @testIfElse
|
then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
|
||||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
@ -503,7 +503,8 @@ func @testInvalidIfOp(tensor<i1>, f32) -> f32 {
|
|||||||
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
|
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIfThen,
|
then_branch = @testIfThen,
|
||||||
else_branch = @testIfElse
|
else_branch = @testIfElse,
|
||||||
|
is_stateless = false
|
||||||
} : (tensor<i1>, f32) -> f32
|
} : (tensor<i1>, f32) -> f32
|
||||||
|
|
||||||
return %1 : f32
|
return %1 : f32
|
||||||
@ -518,7 +519,7 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
||||||
// expected-error @+1 {{requires attribute 'then_branch'}}
|
// expected-error @+1 {{requires attribute 'then_branch'}}
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
else_branch = @testIfElse
|
else_branch = @testIfElse, is_stateless = false
|
||||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
@ -535,7 +536,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
// expected-error @+1 {{branches should have 1 inputs}}
|
// expected-error @+1 {{branches should have 1 inputs}}
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIfThen,
|
then_branch = @testIfThen,
|
||||||
else_branch = @testIfElse
|
else_branch = @testIfElse,
|
||||||
|
is_stateless = false
|
||||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
@ -552,7 +554,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
// expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}}
|
// expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}}
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIfThen,
|
then_branch = @testIfThen,
|
||||||
else_branch = @testIfElse
|
else_branch = @testIfElse,
|
||||||
|
is_stateless = false
|
||||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
@ -569,7 +572,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
|||||||
// expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}}
|
// expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}}
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIfThen,
|
then_branch = @testIfThen,
|
||||||
else_branch = @testIfElse
|
else_branch = @testIfElse,
|
||||||
|
is_stateless = false
|
||||||
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
|
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
@ -586,7 +590,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
|||||||
// expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}}
|
// expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}}
|
||||||
%1 = "tf.If"(%arg0, %arg1) {
|
%1 = "tf.If"(%arg0, %arg1) {
|
||||||
then_branch = @testIfThen,
|
then_branch = @testIfThen,
|
||||||
else_branch = @testIfElse
|
else_branch = @testIfElse,
|
||||||
|
is_stateless = false
|
||||||
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
|
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
|
|||||||
@ -979,9 +979,12 @@ Status Importer::ConvertNode(const Node& node) {
|
|||||||
node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
|
node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* kTfControlFlowFormPrefix = "_tf.";
|
auto get_full_op_name = [&](const std::string& op_name) {
|
||||||
std::string op_name = kTfControlFlowFormPrefix + node_type_name;
|
const char* kTfControlFlowFormPrefix = "_tf.";
|
||||||
|
return kTfControlFlowFormPrefix + op_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string op_name = get_full_op_name(node_type_name);
|
||||||
if (back_edge_node_output_.contains(&node)) {
|
if (back_edge_node_output_.contains(&node)) {
|
||||||
op_name = op_name + ".sink";
|
op_name = op_name + ".sink";
|
||||||
}
|
}
|
||||||
@ -1082,6 +1085,14 @@ Status Importer::ConvertNode(const Node& node) {
|
|||||||
result.attributes.push_back(builder_->getNamedAttr(
|
result.attributes.push_back(builder_->getNamedAttr(
|
||||||
"device", builder_->getStringAttr(std::string(node_def.device()))));
|
"device", builder_->getStringAttr(std::string(node_def.device()))));
|
||||||
|
|
||||||
|
// Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
|
||||||
|
// the differentiating attribute.
|
||||||
|
if (node.IsIfNode()) {
|
||||||
|
result.name = mlir::OperationName(get_full_op_name("If"), context_);
|
||||||
|
mlir::BoolAttr val = builder_->getBoolAttr(node_type_name == "StatelessIf");
|
||||||
|
result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
|
||||||
|
}
|
||||||
|
|
||||||
node_values_[node.id()] = builder_->createOperation(result);
|
node_values_[node.id()] = builder_->createOperation(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -160,6 +160,18 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Updates NodeDef constructed out of an MLIR If op to map it to either
|
||||||
|
// TensorFlow StatelessIf or If op depending on the additional attribute.
|
||||||
|
void UpdateCompositeIfOp(NodeDef* node_def) {
|
||||||
|
auto it = node_def->mutable_attr()->find("is_stateless");
|
||||||
|
if (it != node_def->attr().end()) {
|
||||||
|
if (it->second.b()) {
|
||||||
|
*node_def->mutable_op() = "StatelessIf";
|
||||||
|
}
|
||||||
|
node_def->mutable_attr()->erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
|
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
|
||||||
@ -194,6 +206,8 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
|
|||||||
TF_RETURN_IF_ERROR(ConvertLocation(
|
TF_RETURN_IF_ERROR(ConvertLocation(
|
||||||
inst->getLoc(), node_def->mutable_experimental_debug_info()));
|
inst->getLoc(), node_def->mutable_experimental_debug_info()));
|
||||||
|
|
||||||
|
if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
|
||||||
|
|
||||||
return node_def;
|
return node_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user