Treat Case similar to If/While wrt stateless variant

On import dedup to tf.Case and on export expand to either Case or StatelessCase depending on variant. Kept it mechanical to the other two control flow ops here.

PiperOrigin-RevId: 325498204
Change-Id: Icf5f6f580510908d7dd7c043ac287b19862eaa02
This commit is contained in:
Jacques Pienaar 2020-08-07 13:32:59 -07:00 committed by TensorFlower Gardener
parent 7197362d5a
commit 247e9bd050
6 changed files with 361 additions and 74 deletions

View File

@ -1350,48 +1350,6 @@ then the output will be
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CaseOp : TF_Op<"Case", []> {
let summary = [{
An n-way switch statement which calls a single branch function.
}];
let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
case 0:
output = branches[0](input);
break;
case 1:
output = branches[1](input);
break;
...
case [[nbranches-1]]:
default:
output = branches[nbranches-1](input);
break;
}
```
}];
let arguments = (ins
I32Tensor:$branch_index,
Variadic<TF_Tensor>:$input,
Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let hasCanonicalizer = 1;
}
def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Cast x of type SrcT to y of DstT.";

View File

@ -68,6 +68,51 @@ class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
}];
}
def TF_CaseOp : TF_Op<"Case", []> {
let summary = [{
An n-way switch statement which calls a single branch function.
}];
let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
case 0:
output = branches[0](input);
break;
case 1:
output = branches[1](input);
break;
...
case [[nbranches-1]]:
default:
output = branches[nbranches-1](input);
break;
}
```
}];
let arguments = (ins
I32Tensor:$branch_index,
Variadic<TF_Tensor>:$input,
Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
// Used to map StatelessCase and Case to a common op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let hasCanonicalizer = 1;
}
// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
// its type encoding the tensor's shape and data type.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,

View File

@ -0,0 +1,261 @@
# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
}
node {
name: "Const_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "indexed_case"
op: "StatelessCase"
input: "Const_1"
input: "Const"
attr {
key: "Tin"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "_lower_using_switch_merge"
value {
b: true
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "branches"
value {
list {
func {
name: "indexed_case_branch0_4"
}
func {
name: "indexed_case_branch1_5"
}
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
}
}
}
}
node {
name: "indexed_case/Identity"
op: "Identity"
input: "indexed_case"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
library {
function {
signature {
name: "indexed_case_branch0_4"
input_arg {
name: "add_const"
type: DT_INT32
}
output_arg {
name: "add"
type: DT_INT32
}
}
node_def {
name: "add/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
experimental_debug_info {
original_node_names: "add/y"
}
}
node_def {
name: "add_0"
op: "AddV2"
input: "add_const"
input: "add/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "add"
}
}
ret {
key: "add"
value: "add_0:z:0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
function {
signature {
name: "indexed_case_branch1_5"
input_arg {
name: "add_const"
type: DT_INT32
}
output_arg {
name: "add"
type: DT_INT32
}
}
node_def {
name: "add/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
}
}
experimental_debug_info {
original_node_names: "add/y"
}
}
node_def {
name: "add_0"
op: "AddV2"
input: "add_const"
input: "add/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "add"
}
}
ret {
key: "add"
value: "add_0:z:0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
}
versions {
producer: 486
min_consumer: 12
}
# CHECK: tf.Case
# CHECK-SAME: is_stateless

View File

@ -0,0 +1,38 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 486 : i32}} {
func @main() {
tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
%outputs_2, %control_3 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = true, output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<i32>) -> tensor<*xi32> loc("stateless_case")
%outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
%outputs_6, %control_7 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<i32>) -> tensor<*xi32> loc("regular_case")
tf_executor.fetch
}
return
}
func @indexed_case_branch0_40(%arg0: tensor<i32>) -> tensor<*xi32> attributes {sym_visibility = "private"} {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
tf_executor.fetch %outputs_0 : tensor<*xi32>
}
return %0 : tensor<*xi32>
}
func @indexed_case_branch1_50(%arg0: tensor<i32>) -> tensor<*xi32> attributes {sym_visibility = "private"} {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<2> : tensor<i32>} : () -> tensor<i32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
tf_executor.fetch %outputs_0 : tensor<*xi32>
}
return %0 : tensor<*xi32>
}
}
// CHECK: name: "stateless_case"
// CHECK-NEXT: "StatelessCase"
// CHECK: name: "regular_case"
// CHECK-NEXT: "Case"

View File

@ -1934,22 +1934,18 @@ Status ImporterBase::ConvertNode(const Node& node) {
}
}
// Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
// the differentiating attribute.
if (node.IsIfNode()) {
result.name = mlir::OperationName(get_full_op_name("If"), context_);
mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf");
auto composite_control_flow_op = [&](const std::string& name) {
result.name = mlir::OperationName(get_full_op_name(name), context_);
bool stateless = absl::StartsWith(node_type_name, "Stateless");
mlir::BoolAttr val = builder_.getBoolAttr(stateless);
result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
}
};
// Map While and StatelessWhile op in TensorFlow to the common While op in
// MLIR and add the differentiating attribute.
if (node.IsWhileNode()) {
result.name = mlir::OperationName(get_full_op_name("While"), context_);
mlir::BoolAttr val =
builder_.getBoolAttr(node_type_name == "StatelessWhile");
result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
}
// Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common
// Case/If/While op in MLIR and add the differentiating attribute.
if (node.IsCaseNode()) composite_control_flow_op("Case");
if (node.IsIfNode()) composite_control_flow_op("If");
if (node.IsWhileNode()) composite_control_flow_op("While");
// Register the mapping between the TF node and the newly created operation.
node_values_[node.id()] =

View File

@ -227,25 +227,13 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
return Status::OK();
}
// Updates NodeDef constructed out of an MLIR If op to map it to either
// TensorFlow StatelessIf or If op depending on the additional attribute.
void UpdateCompositeIfOp(NodeDef* node_def) {
// Updates NodeDef constructed out of an MLIR Case/IfW/While op to map it to
// either TensorFlow StatelessX or X op depending on the additional attribute.
void UpdateCompositeOp(NodeDef* node_def) {
auto it = node_def->mutable_attr()->find("is_stateless");
if (it != node_def->attr().end()) {
if (it->second.b()) {
*node_def->mutable_op() = "StatelessIf";
}
node_def->mutable_attr()->erase(it);
}
}
// Updates NodeDef constructed out of an MLIR While op to map it to either
// TensorFlow StatelessWhile or While op depending on the additional attribute.
void UpdateCompositeWhileOp(NodeDef* node_def) {
auto it = node_def->mutable_attr()->find("is_stateless");
if (it != node_def->attr().end()) {
if (it->second.b()) {
*node_def->mutable_op() = "StatelessWhile";
*node_def->mutable_op() = "Stateless" + node_def->op();
}
node_def->mutable_attr()->erase(it);
}
@ -352,8 +340,9 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
TF_RETURN_IF_ERROR(ConvertLocation(
inst->getLoc(), node_def->mutable_experimental_debug_info()));
if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get());
if (node_def->op() == "Case") UpdateCompositeOp(node_def.get());
if (node_def->op() == "If") UpdateCompositeOp(node_def.get());
if (node_def->op() == "While") UpdateCompositeOp(node_def.get());
return node_def;
}