diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index ea9ae5d9477..eced738b0a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -250,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { // tf_executor.fetch //===----------------------------------------------------------------------===// -namespace { - -void Print(FetchOp fetch, OpAsmPrinter &p) { - p << fetch.getOperationName(); - if (fetch.getNumOperands() > 0) { - p << ' '; - p.printOperands(fetch.operand_begin(), fetch.operand_end()); - p << " : "; - interleaveComma(fetch.getOperandTypes(), p); - } - p.printOptionalAttrDict(fetch.getAttrs()); -} - -ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes) - - ); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.island //===----------------------------------------------------------------------===// @@ -411,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { // tf_executor.yield //===----------------------------------------------------------------------===// -namespace { - -void Print(YieldOp yield, OpAsmPrinter &p) { - p << yield.getOperationName(); - if (yield.getNumOperands() > 0) { - p << ' '; - p.printOperands(yield.operand_begin(), yield.operand_end()); - p << " : "; - interleaveComma(yield.getOperandTypes(), p); - } - p.printOptionalAttrDict(yield.getAttrs()); -} - -ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_info; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(op_info) || - (!op_info.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(op_info, types, loc, result.operands) || - parser.parseOptionalAttrDict(result.attributes)); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.Switch //===----------------------------------------------------------------------===// @@ -848,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) { return success(); } -void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " : " << next_iteration.getType(0); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSourceOp(OpAsmParser &parser, - OperationState &result) { - SmallVector types; - if (parser.parseColonTypeList(types)) return failure(); - - MLIRContext *context = parser.getBuilder().getContext(); - Type token_type = TokenType::get(context); - Type control_type = ControlType::get(context); - result.addTypes({types.front(), token_type, control_type}); - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -891,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) { return success(); } -void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) { - p << next_iteration.getOperationName() << " ["; - p.printOperand(next_iteration.getOperand(0)); - p << "] "; - p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); - p << " : " << next_iteration.getOperand(1).getType(); - p.printOptionalAttrDict(next_iteration.getAttrs()); -} - -ParseResult ParseNextIterationSinkOp(OpAsmParser &parser, - OperationState &result) { - SmallVector op_infos; - llvm::SMLoc loc = parser.getCurrentLocation(); - - // First type is always the token consumed from the NextIteration.source - Type token_type = TokenType::get(parser.getBuilder().getContext()); - SmallVector types = {token_type}; - - if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) - return failure(); - - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size() - 2, control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - return parser.parseOptionalAttrDict(result.attributes); -} - } // anonymous namespace //===----------------------------------------------------------------------===// @@ -959,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) { // tf_executor.ControlTrigger //===----------------------------------------------------------------------===// -namespace { - -void Print(ControlTriggerOp trigger, OpAsmPrinter &p) { - p << trigger.getOperationName() << ' '; - p.printOperands(trigger.getOperands()); - p.printOptionalAttrDict(trigger.getAttrs()); -} - -ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) { - SmallVector op_infos; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(op_infos)) return failure(); - Type control_type = ControlType::get(parser.getBuilder().getContext()); - types.append(op_infos.size(), control_type); - if (parser.resolveOperands(op_infos, types, loc, result.operands)) - return failure(); - - // Single control as the only output - result.types.push_back(control_type); - return parser.parseOptionalAttrDict(result.attributes); -} - -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_executor.LoopCond //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 3081018b8da..de2d2485628 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -47,10 +47,12 @@ def TfExecutor_Dialect : Dialect { } // Control type. -def TfeControlType : Type()">, "control">; +def TfeControlType : Type()">, "control">, + BuildableType<"$_builder.getType()">; // Token type. -def TfeTokenType : Type()">, "token">; +def TfeTokenType : Type()">, "token">, + BuildableType<"$_builder.getType()">; // TODO(hinsu): Define and use TensorType instead of AnyType for data operands // and results. For example, MergeOp output type. @@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_IslandOp : TfExecutor_Op<"island", @@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield", }]> ]; + let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict"; + let verifier = ?; + let printer = ?; + let parser = ?; } def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", @@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", } }]; + let assemblyFormat = "`:` type($output) attr-dict"; + + let printer = ?; + let parser = ?; } @@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_ExitOp : TfExecutor_Op<"Exit", @@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit", .Attr("T: type") For example: - %1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32> + %1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"} Note: Additional result corresponds to the control output. }]; @@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", result.attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let assemblyFormat = "$controlInputs attr-dict"; + + let printer = ?; + let parser = ?; } def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index 05d34eb0755..6654341ab42 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi // and certain tf_executor ops are added correctly. // CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" -// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]] +// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]] func @next_iteration_sink_control_input() { tf_executor.graph { %source:3 = tf_executor.NextIteration.Source : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir index bec48181b3b..726495f1fbc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir @@ -220,7 +220,7 @@ func @merge_islands_only() { %11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor %12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor) -> tensor<*xi32> %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 - tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> + tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32> tf_executor.fetch } return @@ -244,7 +244,7 @@ func @merge_islands_only() { // CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]]) // CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32> // CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]] -// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] +// CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] // Test no merging took place as cycle would be formed otherwise. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index e21fd901a9e..a6b1979ee26 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -7,7 +7,7 @@ # CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source # CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]] -# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]] +# CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]] node { name: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 1e537880620..23a8e904ad9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<* %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} tf_executor.fetch %1#0 : tensor<*xf32> } return %0 : tensor<*xf32> @@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<* %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> %2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32> %3:3 = tf_executor.NextIteration.Source : tensor<*xf32> - tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32> + tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> -// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> +// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> tf_executor.fetch %3#0 : tensor<*xf32> } return %0 : tensor<*xf32>