diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 61a55c3534d..89273354d45 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -702,16 +702,6 @@ body: A function that takes a list of tensors and returns another return Verify(*this); }]; - let builders = [ - OpBuilderDAG<(ins "TypeRange":$output, "ValueRange":$input, - "FlatSymbolRefAttr":$cond, "FlatSymbolRefAttr":$body, - "IntegerAttr":$parallel_iterations, - "BoolAttr":$is_stateless)>, - OpBuilderDAG<(ins "TypeRange":$output, "ValueRange":$input, - "StringRef":$cond, "StringRef":$body, - "uint64_t":$parallel_iterations, "bool":$is_stateless)> - ]; - let extraClassDeclaration = [{ // Get the condition function. FuncOp cond_function() { @@ -764,8 +754,9 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion", // Used to map StatelessWhile and While op defined in TensorFlow to a common // op. - DefaultValuedAttr:$is_stateless, - DefaultValuedAttr:$parallel_iterations + DefaultValuedAttr:$output_shapes, + DefaultValuedAttr:$parallel_iterations, + DefaultValuedAttr:$is_stateless ); let results = (outs Variadic:$output); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 5ead88b2903..fdff883f7a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -2605,39 +2605,6 @@ static LogicalResult Verify(WhileOp op) { return success(); } -namespace { - -ArrayAttr GetShapeArrayAttrFromTypes(mlir::MLIRContext *context, - TypeRange types) { - SmallVector shapes; - shapes.reserve(types.size()); - for (Type type : types) - shapes.push_back(ShapeAttr::get(context, type.cast())); - return ArrayAttr::get(shapes, context); -} - -} // namespace - -void WhileOp::build(OpBuilder &builder, OperationState &result, - TypeRange output, ValueRange input, FlatSymbolRefAttr cond, - FlatSymbolRefAttr body, IntegerAttr parallel_iterations, - BoolAttr is_stateless) { - ArrayAttr output_shapes = - GetShapeArrayAttrFromTypes(builder.getContext(), output); - build(builder, result, output, input, cond, body, output_shapes, - parallel_iterations, is_stateless); -} - -void WhileOp::build(OpBuilder &builder, OperationState &result, - TypeRange output, ValueRange input, StringRef cond, - StringRef body, uint64_t parallel_iterations, - bool is_stateless) { - ArrayAttr output_shapes = - GetShapeArrayAttrFromTypes(builder.getContext(), output); - build(builder, result, output, input, cond, body, output_shapes, - parallel_iterations, is_stateless); -} - //===----------------------------------------------------------------------===// // WhileRegionOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 87733bbbf3f..a92d3f367cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -112,7 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) { LogicalResult ConvertWhileOp(WhileOp while_op) { auto while_region = OpBuilder(while_op).create( while_op.getLoc(), while_op.getResultTypes(), while_op.input(), - while_op.is_stateless(), while_op.parallel_iterations()); + while_op.output_shapes(), while_op.parallel_iterations(), + while_op.is_stateless()); CopyDeviceAndUnderscoredAttributes(while_op, while_region); YieldOp cond_yield = diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index 9a6f8696285..66e736db869 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -398,7 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( OpBuilder builder(while_region); auto while_op = builder.create( while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, - while_region.parallel_iterations(), while_region.is_stateless()); + while_region.output_shapes(), while_region.parallel_iterations(), + while_region.is_stateless()); CopyDeviceAndUnderscoredAttributes(while_region, while_op); // Redirect old results to new results. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index f1dc3f21087..7953dfe1832 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -255,7 +255,8 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, OpBuilder& builder) { auto host_side_while = builder.create( loc, /*output=*/ArrayRef{}, /*input=*/ArrayRef{}, - is_stateless, parallel_iterations); + /*output_shapes=*/builder.getArrayAttr({}), parallel_iterations, + is_stateless); // Create empty else branch region. auto& body = host_side_while.body();