diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 081903d13cf..bf8d7015b46 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1350,48 +1350,6 @@ then the output will be TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CaseOp : TF_Op<"Case", []> { - let summary = [{ -An n-way switch statement which calls a single branch function. - }]; - - let description = [{ -An n-way switch statement, implementing the following: - ``` - switch (branch_index) { - case 0: - output = branches[0](input); - break; - case 1: - output = branches[1](input); - break; - ... - case [[nbranches-1]]: - default: - output = branches[nbranches-1](input); - break; - } - ``` - }]; - - let arguments = (ins - I32Tensor:$branch_index, - Variadic:$input, - - Confined]>:$branches, - DefaultValuedAttr:$output_shapes - ); - - let results = (outs - Variadic:$output - ); - - TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; - TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; - - let hasCanonicalizer = 1; -} - def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Cast x of type SrcT to y of DstT."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 376b7933b47..5269bb82239 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -68,6 +68,51 @@ class TF_TensorListInitOp : TF_Op { }]; } +def TF_CaseOp : TF_Op<"Case", []> { + let summary = [{ +An n-way switch statement which calls a single branch function. + }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + Variadic:$input, + + Confined]>:$branches, + DefaultValuedAttr:$output_shapes, + + // Used to map StatelessCase and Case to a common op. + DefaultValuedAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let hasCanonicalizer = 1; +} + // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with // its type encoding the tensor's shape and data type. def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt new file mode 100644 index 00000000000..1372ad71283 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt @@ -0,0 +1,261 @@ +# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s + +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "indexed_case" + op: "StatelessCase" + input: "Const_1" + input: "Const" + attr { + key: "Tin" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "_lower_using_switch_merge" + value { + b: true + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "branches" + value { + list { + func { + name: "indexed_case_branch0_4" + } + func { + name: "indexed_case_branch1_5" + } + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "indexed_case/Identity" + op: "Identity" + input: "indexed_case" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +library { + function { + signature { + name: "indexed_case_branch0_4" + input_arg { + name: "add_const" + type: DT_INT32 + } + output_arg { + name: "add" + type: DT_INT32 + } + } + node_def { + name: "add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + experimental_debug_info { + original_node_names: "add/y" + } + } + node_def { + name: "add_0" + op: "AddV2" + input: "add_const" + input: "add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "add" + } + } + ret { + key: "add" + value: "add_0:z:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + } + } + function { + signature { + name: "indexed_case_branch1_5" + input_arg { + name: "add_const" + type: DT_INT32 + } + output_arg { + name: "add" + type: DT_INT32 + } + } + node_def { + name: "add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + experimental_debug_info { + original_node_names: "add/y" + } + } + node_def { + name: "add_0" + op: "AddV2" + input: "add_const" + input: "add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "add" + } + } + ret { + key: "add" + value: "add_0:z:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + } + } +} +versions { + producer: 486 + min_consumer: 12 +} + +# CHECK: tf.Case +# CHECK-SAME: is_stateless diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir new file mode 100644 index 00000000000..2f2ee6f1286 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 486 : i32}} { + func @main() { + tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %outputs_2, %control_3 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = true, output_shapes = [#tf.shape<>]} : (tensor, tensor) -> tensor<*xi32> loc("stateless_case") + %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_6, %control_7 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor, tensor) -> tensor<*xi32> loc("regular_case") + tf_executor.fetch + } + return + } + + func @indexed_case_branch0_40(%arg0: tensor) -> tensor<*xi32> attributes {sym_visibility = "private"} { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> + tf_executor.fetch %outputs_0 : tensor<*xi32> + } + return %0 : tensor<*xi32> + } + + func @indexed_case_branch1_50(%arg0: tensor) -> tensor<*xi32> attributes {sym_visibility = "private"} { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> + tf_executor.fetch %outputs_0 : tensor<*xi32> + } + return %0 : tensor<*xi32> + } +} + +// CHECK: name: "stateless_case" +// CHECK-NEXT: "StatelessCase" +// CHECK: name: "regular_case" +// CHECK-NEXT: "Case" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index ef0087c4310..94ddf76736e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -1934,22 +1934,18 @@ Status ImporterBase::ConvertNode(const Node& node) { } } - // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add - // the differentiating attribute. - if (node.IsIfNode()) { - result.name = mlir::OperationName(get_full_op_name("If"), context_); - mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf"); + auto composite_control_flow_op = [&](const std::string& name) { + result.name = mlir::OperationName(get_full_op_name(name), context_); + bool stateless = absl::StartsWith(node_type_name, "Stateless"); + mlir::BoolAttr val = builder_.getBoolAttr(stateless); result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); - } + }; - // Map While and StatelessWhile op in TensorFlow to the common While op in - // MLIR and add the differentiating attribute. - if (node.IsWhileNode()) { - result.name = mlir::OperationName(get_full_op_name("While"), context_); - mlir::BoolAttr val = - builder_.getBoolAttr(node_type_name == "StatelessWhile"); - result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); - } + // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common + // Case/If/While op in MLIR and add the differentiating attribute. + if (node.IsCaseNode()) composite_control_flow_op("Case"); + if (node.IsIfNode()) composite_control_flow_op("If"); + if (node.IsWhileNode()) composite_control_flow_op("While"); // Register the mapping between the TF node and the newly created operation. node_values_[node.id()] = diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 0364b935b92..ad9ddb277d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -227,25 +227,13 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { return Status::OK(); } -// Updates NodeDef constructed out of an MLIR If op to map it to either -// TensorFlow StatelessIf or If op depending on the additional attribute. -void UpdateCompositeIfOp(NodeDef* node_def) { +// Updates NodeDef constructed out of an MLIR Case/IfW/While op to map it to +// either TensorFlow StatelessX or X op depending on the additional attribute. +void UpdateCompositeOp(NodeDef* node_def) { auto it = node_def->mutable_attr()->find("is_stateless"); if (it != node_def->attr().end()) { if (it->second.b()) { - *node_def->mutable_op() = "StatelessIf"; - } - node_def->mutable_attr()->erase(it); - } -} - -// Updates NodeDef constructed out of an MLIR While op to map it to either -// TensorFlow StatelessWhile or While op depending on the additional attribute. -void UpdateCompositeWhileOp(NodeDef* node_def) { - auto it = node_def->mutable_attr()->find("is_stateless"); - if (it != node_def->attr().end()) { - if (it->second.b()) { - *node_def->mutable_op() = "StatelessWhile"; + *node_def->mutable_op() = "Stateless" + node_def->op(); } node_def->mutable_attr()->erase(it); } @@ -352,8 +340,9 @@ StatusOr> GetOperationNodeDef( TF_RETURN_IF_ERROR(ConvertLocation( inst->getLoc(), node_def->mutable_experimental_debug_info())); - if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get()); - if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get()); + if (node_def->op() == "Case") UpdateCompositeOp(node_def.get()); + if (node_def->op() == "If") UpdateCompositeOp(node_def.get()); + if (node_def->op() == "While") UpdateCompositeOp(node_def.get()); return node_def; }