Update tf_executor.fetch, tf_executor.yield, tf_executor.NextIteration.Source, tf_executor.NextIteration.Sink, and tf_executor.ControlTrigger to use declarative assembly format instead of a custom parser and printer (NFC).

- tf_executor.NextIteration.Sink printer has been updated to not print a space between the op name and `[` holding the token operand (e.g. `tf_executor.NextIteration.Sink[%token]` instead of `tf_executor.NextIteration.Sink [%token]`).
- tf_executor.Exit description example has been corrected to show the attribute dict is printed after the input types instead of before.
- TfeControlType and TfeTokenType are updated to be buildable types.

PiperOrigin-RevId: 328774020
Change-Id: Icaaab30a4d65265e3993ed3bd9b6d0579dc19b8a
This commit is contained in:
Andy Ly 2020-08-27 11:07:05 -07:00 committed by TensorFlower Gardener
parent bad835f9af
commit 27a2c42be8
6 changed files with 35 additions and 136 deletions

View File

@ -250,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
// tf_executor.fetch // tf_executor.fetch
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace {
void Print(FetchOp fetch, OpAsmPrinter &p) {
p << fetch.getOperationName();
if (fetch.getNumOperands() > 0) {
p << ' ';
p.printOperands(fetch.operand_begin(), fetch.operand_end());
p << " : ";
interleaveComma(fetch.getOperandTypes(), p);
}
p.printOptionalAttrDict(fetch.getAttrs());
}
ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(parser.parseOperandList(opInfo) ||
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
parser.resolveOperands(opInfo, types, loc, result.operands) ||
parser.parseOptionalAttrDict(result.attributes)
);
}
} // anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// tf_executor.island // tf_executor.island
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -411,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
// tf_executor.yield // tf_executor.yield
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace {
void Print(YieldOp yield, OpAsmPrinter &p) {
p << yield.getOperationName();
if (yield.getNumOperands() > 0) {
p << ' ';
p.printOperands(yield.operand_begin(), yield.operand_end());
p << " : ";
interleaveComma(yield.getOperandTypes(), p);
}
p.printOptionalAttrDict(yield.getAttrs());
}
ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_info;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(parser.parseOperandList(op_info) ||
(!op_info.empty() && parser.parseColonTypeList(types)) ||
parser.resolveOperands(op_info, types, loc, result.operands) ||
parser.parseOptionalAttrDict(result.attributes));
}
} // anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// tf_executor.Switch // tf_executor.Switch
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -848,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) {
return success(); return success();
} }
void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) {
p << next_iteration.getOperationName() << " : " << next_iteration.getType(0);
p.printOptionalAttrDict(next_iteration.getAttrs());
}
ParseResult ParseNextIterationSourceOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<Type, 1> types;
if (parser.parseColonTypeList(types)) return failure();
MLIRContext *context = parser.getBuilder().getContext();
Type token_type = TokenType::get(context);
Type control_type = ControlType::get(context);
result.addTypes({types.front(), token_type, control_type});
return parser.parseOptionalAttrDict(result.attributes);
}
} // anonymous namespace } // anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -891,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) {
return success(); return success();
} }
void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
p << next_iteration.getOperationName() << " [";
p.printOperand(next_iteration.getOperand(0));
p << "] ";
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
p << " : " << next_iteration.getOperand(1).getType();
p.printOptionalAttrDict(next_iteration.getAttrs());
}
ParseResult ParseNextIterationSinkOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_infos;
llvm::SMLoc loc = parser.getCurrentLocation();
// First type is always the token consumed from the NextIteration.source
Type token_type = TokenType::get(parser.getBuilder().getContext());
SmallVector<Type, 1> types = {token_type};
if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) ||
parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
return failure();
Type control_type = ControlType::get(parser.getBuilder().getContext());
types.append(op_infos.size() - 2, control_type);
if (parser.resolveOperands(op_infos, types, loc, result.operands))
return failure();
return parser.parseOptionalAttrDict(result.attributes);
}
} // anonymous namespace } // anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -959,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
// tf_executor.ControlTrigger // tf_executor.ControlTrigger
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace {
void Print(ControlTriggerOp trigger, OpAsmPrinter &p) {
p << trigger.getOperationName() << ' ';
p.printOperands(trigger.getOperands());
p.printOptionalAttrDict(trigger.getAttrs());
}
ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_infos;
SmallVector<Type, 1> types;
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseOperandList(op_infos)) return failure();
Type control_type = ControlType::get(parser.getBuilder().getContext());
types.append(op_infos.size(), control_type);
if (parser.resolveOperands(op_infos, types, loc, result.operands))
return failure();
// Single control as the only output
result.types.push_back(control_type);
return parser.parseOptionalAttrDict(result.attributes);
}
} // anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// tf_executor.LoopCond // tf_executor.LoopCond
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -47,10 +47,12 @@ def TfExecutor_Dialect : Dialect {
} }
// Control type. // Control type.
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">; def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">,
BuildableType<"$_builder.getType<ControlType>()">;
// Token type. // Token type.
def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">; def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">,
BuildableType<"$_builder.getType<TokenType>()">;
// TODO(hinsu): Define and use TensorType instead of AnyType for data operands // TODO(hinsu): Define and use TensorType instead of AnyType for data operands
// and results. For example, MergeOp output type. // and results. For example, MergeOp output type.
@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch",
}]> }]>
]; ];
let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";
let verifier = ?; let verifier = ?;
let printer = ?;
let parser = ?;
} }
def TfExecutor_IslandOp : TfExecutor_Op<"island", def TfExecutor_IslandOp : TfExecutor_Op<"island",
@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield",
}]> }]>
]; ];
let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";
let verifier = ?; let verifier = ?;
let printer = ?;
let parser = ?;
} }
def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source",
} }
}]; }];
let assemblyFormat = "`:` type($output) attr-dict";
let printer = ?;
let parser = ?;
} }
@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink",
result.attributes.append(attributes.begin(), attributes.end()); result.attributes.append(attributes.begin(), attributes.end());
}]> }]>
]; ];
let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict";
let printer = ?;
let parser = ?;
} }
def TfExecutor_ExitOp : TfExecutor_Op<"Exit", def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
.Attr("T: type") .Attr("T: type")
For example: For example:
%1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32> %1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"}
Note: Additional result corresponds to the control output. Note: Additional result corresponds to the control output.
}]; }];
@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger",
result.attributes.append(attributes.begin(), attributes.end()); result.attributes.append(attributes.begin(), attributes.end());
}]> }]>
]; ];
let assemblyFormat = "$controlInputs attr-dict";
let printer = ?;
let parser = ?;
} }
def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond",

View File

@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
// and certain tf_executor ops are added correctly. // and certain tf_executor ops are added correctly.
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print" // CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]] // CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]]
func @next_iteration_sink_control_input() { func @next_iteration_sink_control_input() {
tf_executor.graph { tf_executor.graph {
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32> %source:3 = tf_executor.NextIteration.Source : tensor<*xi32>

View File

@ -220,7 +220,7 @@ func @merge_islands_only() {
%11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor<i32> %11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor<i32>
%12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> %12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1 %13 = tf_executor.ControlTrigger %2, %12#1, %9#1
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32>
tf_executor.fetch tf_executor.fetch
} }
return return
@ -244,7 +244,7 @@ func @merge_islands_only() {
// CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]]) // CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]])
// CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32> // CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32>
// CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]] // CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]]
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]] // CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
// Test no merging took place as cycle would be formed otherwise. // Test no merging took place as cycle would be formed otherwise.

View File

@ -7,7 +7,7 @@
# CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source # CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source
# CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]] # CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]]
# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]] # CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]]
node { node {
name: "Const" name: "Const"

View File

@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%1:3 = tf_executor.NextIteration.Source : tensor<*xf32> %1:3 = tf_executor.NextIteration.Source : tensor<*xf32>
tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> // CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32>
tf_executor.fetch %1#0 : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32>
} }
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*
%1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"} // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"} // CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
tf_executor.fetch %1#0 : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32>
} }
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
%2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32> %2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32>
%3:3 = tf_executor.NextIteration.Source : tensor<*xf32> %3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32> tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32> // CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32> // CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
tf_executor.fetch %3#0 : tensor<*xf32> tf_executor.fetch %3#0 : tensor<*xf32>
} }
return %0 : tensor<*xf32> return %0 : tensor<*xf32>