Map TensorFlow StatelessIf and If ops to a common If op in MLIR

The TensorFlow StatelessIf and If ops differ only in the is_stateful property and are otherwise identical. Introduced an additional attribute in the MLIR op definition to differentiate them, and mapped to and from the common op during import to and export from MLIR, respectively.
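To illustrate the idea, here is a minimal sketch of the two mappings as pure functions (ToCommonIf and FromCommonIf are hypothetical helpers for illustration only, not the actual TensorFlow code; the real importer and exporter changes are in the hunks below):

    #include <string>
    #include <utility>

    // Import direction: both GraphDef op names collapse to the common MLIR
    // op; only the is_stateless attribute records which variant it was.
    std::pair<std::string, bool> ToCommonIf(const std::string& tf_op_name) {
      return {"tf.If", /*is_stateless=*/tf_op_name == "StatelessIf"};
    }

    // Export direction: the attribute value recovers the original op name.
    std::string FromCommonIf(bool is_stateless) {
      return is_stateless ? "StatelessIf" : "If";
    }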

Thanks Mehdi for the suggestion!

PiperOrigin-RevId: 259468359
Smit Hinsu 2019-07-22 21:51:39 -07:00 committed by TensorFlower Gardener
parent ade316deef
commit 1de23834be
8 changed files with 50 additions and 17 deletions


@@ -160,7 +160,7 @@ func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
   %0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
   %1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
   %2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
-  %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   return %3 : tensor<1xf32>
 }


@@ -103,7 +103,10 @@ else_branch: A function that takes 'inputs' and returns a list of
     SymbolRefAttr:$then_branch,
     SymbolRefAttr:$else_branch,
-    DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes
+    DefaultValuedAttr<StrArrayAttr, "{}">:$output_shapes,
+
+    // Used to map StatelessIf and If op defined in TensorFlow to a common op.
+    BoolAttr:$is_stateless
   );

   let results = (outs


@@ -7,7 +7,7 @@ func @testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
 func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
 ^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
   %1 = "tf.If"(%arg0, %arg1, %arg2) {
-    then_branch = @testIf1Then, else_branch = @testIf1Else
+    then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
   } : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

 // CHECK: %0 = extract_element %arg0[] : tensor<i1>
@@ -31,7 +31,7 @@ func @testIf3Else(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>
 func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) {
 ^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>):
   %1:3 = "tf.If"(%arg0, %arg1) {
-    then_branch = @testIf3Then, else_branch = @testIf3Else
+    then_branch = @testIf3Then, else_branch = @testIf3Else, is_stateless = false
   } : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)

 // CHECK: %0 = extract_element %arg0[] : tensor<i1>
@@ -57,7 +57,7 @@ func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32
 ^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
   %1 = "tf.If"(%arg0, %arg1, %arg2) {
-    then_branch = @testIf1Then, else_branch = @testIf1Else
+    then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
   } : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>

 // CHECK: %0 = extract_element %arg0[] : tensor<i1>
@@ -97,7 +97,7 @@ func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
   // expected-error @+1 {{only supports zero-D bool tensors now}}
   %1 = "tf.If"(%arg0, %arg1, %arg2) {
-    then_branch = @testIf1Then, else_branch = @testIf1Else
+    then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
   } : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

   return %1 : tensor<*xf32>


@@ -517,7 +517,7 @@ versions {
 # CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control)
-# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
+# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, is_stateless = false, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %15:2 = "_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)


@@ -142,7 +142,7 @@ versions {
 #CHECK: func @main() {
 #CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control)
 #CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control)
-#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
+#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, is_stateless = false, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 #CHECK-NEXT: return
 #CHECK-NEXT: }
 #CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {


@@ -486,7 +486,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32>
 func @testValidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
 ^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
   %1 = "tf.If"(%arg0, %arg1) {
-    then_branch = @testIfThen, else_branch = @testIfElse
+    then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
   } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>

   return %1 : tensor<2xf32>
@@ -503,7 +503,8 @@ func @testInvalidIfOp(tensor<i1>, f32) -> f32 {
   // expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
   %1 = "tf.If"(%arg0, %arg1) {
     then_branch = @testIfThen,
-    else_branch = @testIfElse
+    else_branch = @testIfElse,
+    is_stateless = false
   } : (tensor<i1>, f32) -> f32

   return %1 : f32
@@ -518,7 +519,7 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
 ^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
   // expected-error @+1 {{requires attribute 'then_branch'}}
   %1 = "tf.If"(%arg0, %arg1) {
-    else_branch = @testIfElse
+    else_branch = @testIfElse, is_stateless = false
   } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>

   return %1 : tensor<2xf32>
@@ -535,7 +536,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
   // expected-error @+1 {{branches should have 1 inputs}}
   %1 = "tf.If"(%arg0, %arg1) {
     then_branch = @testIfThen,
-    else_branch = @testIfElse
+    else_branch = @testIfElse,
+    is_stateless = false
   } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>

   return %1 : tensor<2xf32>
@@ -552,7 +554,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
   // expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}}
   %1 = "tf.If"(%arg0, %arg1) {
     then_branch = @testIfThen,
-    else_branch = @testIfElse
+    else_branch = @testIfElse,
+    is_stateless = false
   } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>

   return %1 : tensor<2xf32>
@@ -569,7 +572,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
   // expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}}
   %1 = "tf.If"(%arg0, %arg1) {
     then_branch = @testIfThen,
-    else_branch = @testIfElse
+    else_branch = @testIfElse,
+    is_stateless = false
   } : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>

   return %1 : tensor<2xf32>
@@ -586,7 +590,8 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
   // expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}}
   %1 = "tf.If"(%arg0, %arg1) {
     then_branch = @testIfThen,
-    else_branch = @testIfElse
+    else_branch = @testIfElse,
+    is_stateless = false
   } : (tensor<i1>, tensor<*xf32>) -> tensor<2xf32>

   return %1 : tensor<2xf32>


@@ -979,9 +979,12 @@ Status Importer::ConvertNode(const Node& node) {
     node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
   }

-  const char* kTfControlFlowFormPrefix = "_tf.";
-  std::string op_name = kTfControlFlowFormPrefix + node_type_name;
+  auto get_full_op_name = [&](const std::string& op_name) {
+    const char* kTfControlFlowFormPrefix = "_tf.";
+    return kTfControlFlowFormPrefix + op_name;
+  };
+  std::string op_name = get_full_op_name(node_type_name);

   if (back_edge_node_output_.contains(&node)) {
     op_name = op_name + ".sink";
   }
@@ -1082,6 +1085,14 @@ Status Importer::ConvertNode(const Node& node) {
   result.attributes.push_back(builder_->getNamedAttr(
       "device", builder_->getStringAttr(std::string(node_def.device()))));

+  // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
+  // the differentiating attribute.
+  if (node.IsIfNode()) {
+    result.name = mlir::OperationName(get_full_op_name("If"), context_);
+    mlir::BoolAttr val = builder_->getBoolAttr(node_type_name == "StatelessIf");
+    result.attributes.push_back(builder_->getNamedAttr("is_stateless", val));
+  }
+
   node_values_[node.id()] = builder_->createOperation(result);
   return Status::OK();
 }


@@ -160,6 +160,18 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
   return Status::OK();
 }

+// Updates NodeDef constructed out of an MLIR If op to map it to either
+// TensorFlow StatelessIf or If op depending on the additional attribute.
+void UpdateCompositeIfOp(NodeDef* node_def) {
+  auto it = node_def->mutable_attr()->find("is_stateless");
+  if (it != node_def->attr().end()) {
+    if (it->second.b()) {
+      *node_def->mutable_op() = "StatelessIf";
+    }
+    node_def->mutable_attr()->erase(it);
+  }
+}
+
 }  // anonymous namespace

 StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
@@ -194,6 +206,8 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
   TF_RETURN_IF_ERROR(ConvertLocation(
       inst->getLoc(), node_def->mutable_experimental_debug_info()));

+  if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
+
   return node_def;
 }
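As a hedged usage sketch of the export-side rewrite above (Demo and the chosen attribute value are illustrative; the proto accessors are the standard NodeDef API):

    #include "tensorflow/core/framework/node_def.pb.h"

    // Illustrative only: shows the effect of UpdateCompositeIfOp on a
    // NodeDef produced from the common MLIR op.
    void Demo() {
      tensorflow::NodeDef node_def;
      node_def.set_op("If");
      (*node_def.mutable_attr())["is_stateless"].set_b(true);

      // UpdateCompositeIfOp(&node_def) would now rename the op to
      // "StatelessIf" and erase "is_stateless", so the helper attribute
      // never leaks into the exported GraphDef.
    }

Since UpdateCompositeIfOp sits in an anonymous namespace, the sketch only annotates its effect rather than calling it directly.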