Map TensorFlow StatelessIf and If op to a common If op in MLIR

TensorFlow StatelessIf and If ops only differ in the is_stateful property and are identical otherwise. Introduced an additional attribute in the MLIR op definition to differentiate them, and mapped to and from the common op while importing to and exporting from MLIR, respectively.

Thanks Mehdi for the suggestion!

PiperOrigin-RevId: 259468359
This commit is contained in:
Smit Hinsu 2019-07-22 21:51:39 -07:00 committed by TensorFlower Gardener
parent ade316deef
commit 1de23834be
8 changed files with 50 additions and 17 deletions

View File

@ -160,7 +160,7 @@ func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> %0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32> %1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
%2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> %2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
return %3 : tensor<1xf32> return %3 : tensor<1xf32>
} }

View File

@ -103,7 +103,10 @@ else_branch: A function that takes 'inputs' and returns a list of
SymbolRefAttr:$then_branch, SymbolRefAttr:$then_branch,
SymbolRefAttr:$else_branch, SymbolRefAttr:$else_branch,
DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes,
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
BoolAttr:$is_stateless
); );
let results = (outs let results = (outs

View File

@ -7,7 +7,7 @@ func @testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> { func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>): ^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
%1 = "tf.If"(%arg0, %arg1, %arg2) { %1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> } : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %0 = extract_element %arg0[] : tensor<i1> // CHECK: %0 = extract_element %arg0[] : tensor<i1>
@ -31,7 +31,7 @@ func @testIf3Else(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>
func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) { func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) {
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>): ^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>):
%1:3 = "tf.If"(%arg0, %arg1) { %1:3 = "tf.If"(%arg0, %arg1) {
then_branch = @testIf3Then, else_branch = @testIf3Else then_branch = @testIf3Then, else_branch = @testIf3Else, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) } : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
// CHECK: %0 = extract_element %arg0[] : tensor<i1> // CHECK: %0 = extract_element %arg0[] : tensor<i1>
@ -57,7 +57,7 @@ func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>): ^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
%1 = "tf.If"(%arg0, %arg1, %arg2) { %1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> } : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
// CHECK: %0 = extract_element %arg0[] : tensor<i1> // CHECK: %0 = extract_element %arg0[] : tensor<i1>
@ -97,7 +97,7 @@ func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
// expected-error @+1 {{only supports zero-D bool tensors now}} // expected-error @+1 {{only supports zero-D bool tensors now}}
%1 = "tf.If"(%arg0, %arg1, %arg2) { %1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> } : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32> return %1 : tensor<*xf32>

View File

@ -517,7 +517,7 @@ versions {
# CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) # CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) # CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control) # CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control)
# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control) # CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, is_stateless = false, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) # CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) # CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %15:2 = "_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) # CHECK-NEXT: %15:2 = "_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)

View File

@ -142,7 +142,7 @@ versions {
#CHECK: func @main() { #CHECK: func @main() {
#CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control) #CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control)
#CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control) #CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control)
#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) #CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, is_stateless = false, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
#CHECK-NEXT: return #CHECK-NEXT: return
#CHECK-NEXT: } #CHECK-NEXT: }
#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> { #CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {

View File

@ -486,7 +486,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32>
func @testValidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> { func @testValidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>): ^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, else_branch = @testIfElse then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32> return %1 : tensor<2xf32>
@ -503,7 +503,8 @@ func @testInvalidIfOp(tensor<i1>, f32) -> f32 {
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}} // expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, then_branch = @testIfThen,
else_branch = @testIfElse else_branch = @testIfElse,
is_stateless = false
} : (tensor<i1>, f32) -> f32 } : (tensor<i1>, f32) -> f32
return %1 : f32 return %1 : f32
@ -518,7 +519,7 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>): ^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
// expected-error @+1 {{requires attribute 'then_branch'}} // expected-error @+1 {{requires attribute 'then_branch'}}
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
else_branch = @testIfElse else_branch = @testIfElse, is_stateless = false
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32> return %1 : tensor<2xf32>
@ -535,7 +536,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{branches should have 1 inputs}} // expected-error @+1 {{branches should have 1 inputs}}
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, then_branch = @testIfThen,
else_branch = @testIfElse else_branch = @testIfElse,
is_stateless = false
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32> return %1 : tensor<2xf32>
@ -552,7 +554,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}} // expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}}
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, then_branch = @testIfThen,
else_branch = @testIfElse else_branch = @testIfElse,
is_stateless = false
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32> return %1 : tensor<2xf32>
@ -569,7 +572,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
// expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}} // expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}}
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, then_branch = @testIfThen,
else_branch = @testIfElse else_branch = @testIfElse,
is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32> } : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32> return %1 : tensor<2xf32>
@ -586,7 +590,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
// expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}} // expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}}
%1 = "tf.If"(%arg0, %arg1) { %1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, then_branch = @testIfThen,
else_branch = @testIfElse else_branch = @testIfElse,
is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32> } : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32> return %1 : tensor<2xf32>

View File

@ -979,9 +979,12 @@ Status Importer::ConvertNode(const Node& node) {
node_type_name = (*tf_name_to_mlir_name_)[node_type_name]; node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
} }
const char* kTfControlFlowFormPrefix = "_tf."; auto get_full_op_name = [&](const std::string& op_name) {
std::string op_name = kTfControlFlowFormPrefix + node_type_name; const char* kTfControlFlowFormPrefix = "_tf.";
return kTfControlFlowFormPrefix + op_name;
};
std::string op_name = get_full_op_name(node_type_name);
if (back_edge_node_output_.contains(&node)) { if (back_edge_node_output_.contains(&node)) {
op_name = op_name + ".sink"; op_name = op_name + ".sink";
} }
@ -1082,6 +1085,14 @@ Status Importer::ConvertNode(const Node& node) {
result.attributes.push_back(builder_->getNamedAttr( result.attributes.push_back(builder_->getNamedAttr(
"device", builder_->getStringAttr(std::string(node_def.device())))); "device", builder_->getStringAttr(std::string(node_def.device()))));
// Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
// the differentiating attribute.
if (node.IsIfNode()) {
result.name = mlir::OperationName(get_full_op_name("If"), context_);
mlir::BoolAttr val = builder_->getBoolAttr(node_type_name == "StatelessIf");
result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
}
node_values_[node.id()] = builder_->createOperation(result); node_values_[node.id()] = builder_->createOperation(result);
return Status::OK(); return Status::OK();
} }

View File

@ -160,6 +160,18 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
return Status::OK(); return Status::OK();
} }
// Updates NodeDef constructed out of an MLIR If op to map it to either
// TensorFlow StatelessIf or If op depending on the additional attribute.
void UpdateCompositeIfOp(NodeDef* node_def) {
auto it = node_def->mutable_attr()->find("is_stateless");
if (it != node_def->attr().end()) {
if (it->second.b()) {
*node_def->mutable_op() = "StatelessIf";
}
node_def->mutable_attr()->erase(it);
}
}
} // anonymous namespace } // anonymous namespace
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef( StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
@ -194,6 +206,8 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
TF_RETURN_IF_ERROR(ConvertLocation( TF_RETURN_IF_ERROR(ConvertLocation(
inst->getLoc(), node_def->mutable_experimental_debug_info())); inst->getLoc(), node_def->mutable_experimental_debug_info()));
if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
return node_def; return node_def;
} }