Treat Case like If/While wrt the stateless variant

On import, dedup Case and StatelessCase to tf.Case; on export, expand back to either Case or StatelessCase depending on the variant. Kept the handling mechanical and consistent with the other two control flow ops.

PiperOrigin-RevId: 325498204
Change-Id: Icf5f6f580510908d7dd7c043ac287b19862eaa02
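For orientation, a minimal hand-written sketch (not part of the change; the branch names and types are illustrative) of the common MLIR form both TensorFlow variants now import to, with only the is_stateless attribute telling them apart:

    // Case and StatelessCase both import as tf.Case; @branch0/@branch1 are
    // hypothetical branch functions with matching signatures.
    %out = "tf.Case"(%branch_index, %input)
             {branches = [@branch0, @branch1], is_stateless = true}
             : (tensor<i32>, tensor<i32>) -> tensor<i32>

On export the attribute drives the reverse expansion: is_stateless = true emits a StatelessCase node, false a plain Case node, mirroring what If and While already do.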
@@ -1350,48 +1350,6 @@ then the output will be
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_CaseOp : TF_Op<"Case", []> {
-  let summary = [{
-An n-way switch statement which calls a single branch function.
-  }];
-
-  let description = [{
-An n-way switch statement, implementing the following:
-```
-switch (branch_index) {
-  case 0:
-    output = branches[0](input);
-    break;
-  case 1:
-    output = branches[1](input);
-    break;
-  ...
-  case [[nbranches-1]]:
-  default:
-    output = branches[nbranches-1](input);
-    break;
-}
-```
-  }];
-
-  let arguments = (ins
-    I32Tensor:$branch_index,
-    Variadic<TF_Tensor>:$input,
-
-    Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
-  );
-
-  let results = (outs
-    Variadic<TF_Tensor>:$output
-  );
-
-  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
-  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
-
-  let hasCanonicalizer = 1;
-}
-
 def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
   let summary = "Cast x of type SrcT to y of DstT.";
 
@@ -68,6 +68,51 @@ class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
   }];
 }
 
+def TF_CaseOp : TF_Op<"Case", []> {
+  let summary = [{
+An n-way switch statement which calls a single branch function.
+  }];
+
+  let description = [{
+An n-way switch statement, implementing the following:
+```
+switch (branch_index) {
+  case 0:
+    output = branches[0](input);
+    break;
+  case 1:
+    output = branches[1](input);
+    break;
+  ...
+  case [[nbranches-1]]:
+  default:
+    output = branches[nbranches-1](input);
+    break;
+}
+```
+  }];
+
+  let arguments = (ins
+    I32Tensor:$branch_index,
+    Variadic<TF_Tensor>:$input,
+
+    Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
+    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
+
+    // Used to map StatelessCase and Case to a common op.
+    DefaultValuedAttr<BoolAttr, "false">:$is_stateless
+  );
+
+  let results = (outs
+    Variadic<TF_Tensor>:$output
+  );
+
+  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
+  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
+
+  let hasCanonicalizer = 1;
+}
+
 // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
 // its type encoding the tensor's shape and data type.
 def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
@@ -0,0 +1,261 @@
+# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s
+
+node {
+  name: "Const"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_INT32
+        tensor_shape {
+        }
+        int_val: 1
+      }
+    }
+  }
+}
+node {
+  name: "Const_1"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_INT32
+        tensor_shape {
+        }
+        int_val: 0
+      }
+    }
+  }
+}
+node {
+  name: "indexed_case"
+  op: "StatelessCase"
+  input: "Const_1"
+  input: "Const"
+  attr {
+    key: "Tin"
+    value {
+      list {
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    key: "Tout"
+    value {
+      list {
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    key: "_lower_using_switch_merge"
+    value {
+      b: true
+    }
+  }
+  attr {
+    key: "_read_only_resource_inputs"
+    value {
+      list {
+      }
+    }
+  }
+  attr {
+    key: "branches"
+    value {
+      list {
+        func {
+          name: "indexed_case_branch0_4"
+        }
+        func {
+          name: "indexed_case_branch1_5"
+        }
+      }
+    }
+  }
+  attr {
+    key: "output_shapes"
+    value {
+      list {
+        shape {
+        }
+      }
+    }
+  }
+}
+node {
+  name: "indexed_case/Identity"
+  op: "Identity"
+  input: "indexed_case"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+library {
+  function {
+    signature {
+      name: "indexed_case_branch0_4"
+      input_arg {
+        name: "add_const"
+        type: DT_INT32
+      }
+      output_arg {
+        name: "add"
+        type: DT_INT32
+      }
+    }
+    node_def {
+      name: "add/y"
+      op: "Const"
+      attr {
+        key: "dtype"
+        value {
+          type: DT_INT32
+        }
+      }
+      attr {
+        key: "value"
+        value {
+          tensor {
+            dtype: DT_INT32
+            tensor_shape {
+            }
+            int_val: 1
+          }
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "add/y"
+      }
+    }
+    node_def {
+      name: "add_0"
+      op: "AddV2"
+      input: "add_const"
+      input: "add/y:output:0"
+      attr {
+        key: "T"
+        value {
+          type: DT_INT32
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "add"
+      }
+    }
+    ret {
+      key: "add"
+      value: "add_0:z:0"
+    }
+    arg_attr {
+      key: 0
+      value {
+        attr {
+          key: "_output_shapes"
+          value {
+            list {
+              shape {
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  function {
+    signature {
+      name: "indexed_case_branch1_5"
+      input_arg {
+        name: "add_const"
+        type: DT_INT32
+      }
+      output_arg {
+        name: "add"
+        type: DT_INT32
+      }
+    }
+    node_def {
+      name: "add/y"
+      op: "Const"
+      attr {
+        key: "dtype"
+        value {
+          type: DT_INT32
+        }
+      }
+      attr {
+        key: "value"
+        value {
+          tensor {
+            dtype: DT_INT32
+            tensor_shape {
+            }
+            int_val: 2
+          }
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "add/y"
+      }
+    }
+    node_def {
+      name: "add_0"
+      op: "AddV2"
+      input: "add_const"
+      input: "add/y:output:0"
+      attr {
+        key: "T"
+        value {
+          type: DT_INT32
+        }
+      }
+      experimental_debug_info {
+        original_node_names: "add"
+      }
+    }
+    ret {
+      key: "add"
+      value: "add_0:z:0"
+    }
+    arg_attr {
+      key: 0
+      value {
+        attr {
+          key: "_output_shapes"
+          value {
+            list {
+              shape {
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+versions {
+  producer: 486
+  min_consumer: 12
+}
+
+# CHECK: tf.Case
+# CHECK-SAME: is_stateless
@@ -0,0 +1,38 @@
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 486 : i32}} {
+  func @main() {
+    tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_2, %control_3 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = true, output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<i32>) -> tensor<*xi32> loc("stateless_case")
+      %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
+      %outputs_6, %control_7 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<i32>) -> tensor<*xi32> loc("regular_case")
+      tf_executor.fetch
+    }
+    return
+  }
+
+  func @indexed_case_branch0_40(%arg0: tensor<i32>) -> tensor<*xi32> attributes {sym_visibility = "private"} {
+    %0 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+      tf_executor.fetch %outputs_0 : tensor<*xi32>
+    }
+    return %0 : tensor<*xi32>
+  }
+
+  func @indexed_case_branch1_50(%arg0: tensor<i32>) -> tensor<*xi32> attributes {sym_visibility = "private"} {
+    %0 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<2> : tensor<i32>} : () -> tensor<i32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+      tf_executor.fetch %outputs_0 : tensor<*xi32>
+    }
+    return %0 : tensor<*xi32>
+  }
+}
+
+// CHECK: name: "stateless_case"
+// CHECK-NEXT: "StatelessCase"
+// CHECK: name: "regular_case"
+// CHECK-NEXT: "Case"
@@ -1934,22 +1934,18 @@ Status ImporterBase::ConvertNode(const Node& node) {
     }
   }
 
-  // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
-  // the differentiating attribute.
-  if (node.IsIfNode()) {
-    result.name = mlir::OperationName(get_full_op_name("If"), context_);
-    mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf");
+  auto composite_control_flow_op = [&](const std::string& name) {
+    result.name = mlir::OperationName(get_full_op_name(name), context_);
+    bool stateless = absl::StartsWith(node_type_name, "Stateless");
+    mlir::BoolAttr val = builder_.getBoolAttr(stateless);
     result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
-  }
+  };
 
-  // Map While and StatelessWhile op in TensorFlow to the common While op in
-  // MLIR and add the differentiating attribute.
-  if (node.IsWhileNode()) {
-    result.name = mlir::OperationName(get_full_op_name("While"), context_);
-    mlir::BoolAttr val =
-        builder_.getBoolAttr(node_type_name == "StatelessWhile");
-    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
-  }
+  // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common
+  // Case/If/While op in MLIR and add the differentiating attribute.
+  if (node.IsCaseNode()) composite_control_flow_op("Case");
+  if (node.IsIfNode()) composite_control_flow_op("If");
+  if (node.IsWhileNode()) composite_control_flow_op("While");
 
   // Register the mapping between the TF node and the newly created operation.
   node_values_[node.id()] =
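As a reading aid for the shared lambda above, a sketch (hypothetical node, reusing the existing tf.While cond/body attributes) of the import-side effect: the node kind selects the common op name, and the "Stateless" prefix alone sets the attribute.

    // A NodeDef with op: "StatelessWhile" now imports as the common op:
    %r = "tf.While"(%x) {cond = @cond_fn, body = @body_fn, is_stateless = true}
           : (tensor<i32>) -> tensor<i32>
    // A NodeDef with op: "Case" imports the same way, with is_stateless = false.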
@@ -227,25 +227,13 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
   return Status::OK();
 }
 
-// Updates NodeDef constructed out of an MLIR If op to map it to either
-// TensorFlow StatelessIf or If op depending on the additional attribute.
-void UpdateCompositeIfOp(NodeDef* node_def) {
+// Updates NodeDef constructed out of an MLIR Case/If/While op to map it to
+// either TensorFlow StatelessX or X op depending on the additional attribute.
+void UpdateCompositeOp(NodeDef* node_def) {
   auto it = node_def->mutable_attr()->find("is_stateless");
   if (it != node_def->attr().end()) {
     if (it->second.b()) {
-      *node_def->mutable_op() = "StatelessIf";
-    }
-    node_def->mutable_attr()->erase(it);
-  }
-}
-
-// Updates NodeDef constructed out of an MLIR While op to map it to either
-// TensorFlow StatelessWhile or While op depending on the additional attribute.
-void UpdateCompositeWhileOp(NodeDef* node_def) {
-  auto it = node_def->mutable_attr()->find("is_stateless");
-  if (it != node_def->attr().end()) {
-    if (it->second.b()) {
-      *node_def->mutable_op() = "StatelessWhile";
+      *node_def->mutable_op() = "Stateless" + node_def->op();
     }
     node_def->mutable_attr()->erase(it);
   }
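A short sketch of the export behavior the generalized UpdateCompositeOp gives (illustrative branch names and types): the attribute only selects the emitted op name and is erased in both cases, so it never leaks into the GraphDef.

    // Exports as a NodeDef with op: "StatelessCase"; is_stateless is erased.
    %a = "tf.Case"(%i, %x) {branches = [@b0, @b1], is_stateless = true}
           : (tensor<i32>, tensor<i32>) -> tensor<i32>
    // Exports as op: "Case"; the attribute is likewise erased.
    %b = "tf.Case"(%i, %x) {branches = [@b0, @b1], is_stateless = false}
           : (tensor<i32>, tensor<i32>) -> tensor<i32>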
@@ -352,8 +340,9 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
   TF_RETURN_IF_ERROR(ConvertLocation(
       inst->getLoc(), node_def->mutable_experimental_debug_info()));
 
-  if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
-  if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get());
+  if (node_def->op() == "Case") UpdateCompositeOp(node_def.get());
+  if (node_def->op() == "If") UpdateCompositeOp(node_def.get());
+  if (node_def->op() == "While") UpdateCompositeOp(node_def.get());
 
   return node_def;
 }