Map TensorFlow StatelessIf and If op to a common If op in MLIR
TensorFlow StatelessIf and If op only differs in the is_stateful property and are identical otherwise. Introduced an additional attribute in the MLIR op definition to differentiate them and mapped to and from the common op while importing and export to MLIR, respectively. Thanks Mehdi for the suggestion! PiperOrigin-RevId: 259468359
This commit is contained in:
parent
ade316deef
commit
1de23834be
@ -160,7 +160,7 @@ func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
|
||||
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
return %3 : tensor<1xf32>
|
||||
}
|
||||
|
||||
|
@ -103,7 +103,10 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
|
||||
SymbolRefAttr:$then_branch,
|
||||
SymbolRefAttr:$else_branch,
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes,
|
||||
|
||||
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
|
||||
BoolAttr:$is_stateless
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
@ -7,7 +7,7 @@ func @testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
|
||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
|
||||
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||
} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
||||
@ -31,7 +31,7 @@ func @testIf3Else(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>
|
||||
func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) {
|
||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>):
|
||||
%1:3 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIf3Then, else_branch = @testIf3Else
|
||||
then_branch = @testIf3Then, else_branch = @testIf3Else, is_stateless = false
|
||||
} : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
|
||||
|
||||
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
||||
@ -57,7 +57,7 @@ func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32
|
||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
|
||||
|
||||
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||
} : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
|
||||
|
||||
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
|
||||
@ -97,7 +97,7 @@ func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
|
||||
|
||||
// expected-error @+1 {{only supports zero-D bool tensors now}}
|
||||
%1 = "tf.If"(%arg0, %arg1, %arg2) {
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||
} : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
return %1 : tensor<*xf32>
|
||||
|
@ -517,7 +517,7 @@ versions {
|
||||
# CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
# CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
# CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control)
|
||||
# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||
# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, is_stateless = false, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||
# CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
# CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
# CHECK-NEXT: %15:2 = "_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
|
@ -142,7 +142,7 @@ versions {
|
||||
#CHECK: func @main() {
|
||||
#CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control)
|
||||
#CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control)
|
||||
#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, is_stateless = false, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
#CHECK-NEXT: return
|
||||
#CHECK-NEXT: }
|
||||
#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
|
||||
|
@ -486,7 +486,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32>
|
||||
func @testValidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen, else_branch = @testIfElse
|
||||
then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
|
||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
@ -503,7 +503,8 @@ func @testInvalidIfOp(tensor<i1>, f32) -> f32 {
|
||||
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen,
|
||||
else_branch = @testIfElse
|
||||
else_branch = @testIfElse,
|
||||
is_stateless = false
|
||||
} : (tensor<i1>, f32) -> f32
|
||||
|
||||
return %1 : f32
|
||||
@ -518,7 +519,7 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
||||
// expected-error @+1 {{requires attribute 'then_branch'}}
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
else_branch = @testIfElse
|
||||
else_branch = @testIfElse, is_stateless = false
|
||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
@ -535,7 +536,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{branches should have 1 inputs}}
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen,
|
||||
else_branch = @testIfElse
|
||||
else_branch = @testIfElse,
|
||||
is_stateless = false
|
||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
@ -552,7 +554,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}}
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen,
|
||||
else_branch = @testIfElse
|
||||
else_branch = @testIfElse,
|
||||
is_stateless = false
|
||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
@ -569,7 +572,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}}
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen,
|
||||
else_branch = @testIfElse
|
||||
else_branch = @testIfElse,
|
||||
is_stateless = false
|
||||
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
@ -586,7 +590,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}}
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen,
|
||||
else_branch = @testIfElse
|
||||
else_branch = @testIfElse,
|
||||
is_stateless = false
|
||||
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
|
@ -979,9 +979,12 @@ Status Importer::ConvertNode(const Node& node) {
|
||||
node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
|
||||
}
|
||||
|
||||
const char* kTfControlFlowFormPrefix = "_tf.";
|
||||
std::string op_name = kTfControlFlowFormPrefix + node_type_name;
|
||||
auto get_full_op_name = [&](const std::string& op_name) {
|
||||
const char* kTfControlFlowFormPrefix = "_tf.";
|
||||
return kTfControlFlowFormPrefix + op_name;
|
||||
};
|
||||
|
||||
std::string op_name = get_full_op_name(node_type_name);
|
||||
if (back_edge_node_output_.contains(&node)) {
|
||||
op_name = op_name + ".sink";
|
||||
}
|
||||
@ -1082,6 +1085,14 @@ Status Importer::ConvertNode(const Node& node) {
|
||||
result.attributes.push_back(builder_->getNamedAttr(
|
||||
"device", builder_->getStringAttr(std::string(node_def.device()))));
|
||||
|
||||
// Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
|
||||
// the differentiating attribute.
|
||||
if (node.IsIfNode()) {
|
||||
result.name = mlir::OperationName(get_full_op_name("If"), context_);
|
||||
mlir::BoolAttr val = builder_->getBoolAttr(node_type_name == "StatelessIf");
|
||||
result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
|
||||
}
|
||||
|
||||
node_values_[node.id()] = builder_->createOperation(result);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -160,6 +160,18 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Updates NodeDef constructed out of an MLIR If op to map it to either
|
||||
// TensorFlow StatelessIf or If op depending on the additional attribute.
|
||||
void UpdateCompositeIfOp(NodeDef* node_def) {
|
||||
auto it = node_def->mutable_attr()->find("is_stateless");
|
||||
if (it != node_def->attr().end()) {
|
||||
if (it->second.b()) {
|
||||
*node_def->mutable_op() = "StatelessIf";
|
||||
}
|
||||
node_def->mutable_attr()->erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
|
||||
@ -194,6 +206,8 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
|
||||
TF_RETURN_IF_ERROR(ConvertLocation(
|
||||
inst->getLoc(), node_def->mutable_experimental_debug_info()));
|
||||
|
||||
if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
|
||||
|
||||
return node_def;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user