Add output_shapes
attribute to tf.WhileRegion.
This will match the representation for tf.While in regards to output shapes. Shape inference for tf.While/tf.WhileRegion should be handled differently when `output_shapes` attribute is not empty, as tf.While/tf.WhileRegion supports dynamic shapes. Custom builders added for tf.While are removed as both functional and region based ops have explicit `output_shapes` attributes. PiperOrigin-RevId: 341890776 Change-Id: I92bcbec86b997ad466b771f4de17ad9fde904842
This commit is contained in:
parent
c77657c395
commit
2889cec62c
@ -702,16 +702,6 @@ body: A function that takes a list of tensors and returns another
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "TypeRange":$output, "ValueRange":$input,
|
||||
"FlatSymbolRefAttr":$cond, "FlatSymbolRefAttr":$body,
|
||||
"IntegerAttr":$parallel_iterations,
|
||||
"BoolAttr":$is_stateless)>,
|
||||
OpBuilderDAG<(ins "TypeRange":$output, "ValueRange":$input,
|
||||
"StringRef":$cond, "StringRef":$body,
|
||||
"uint64_t":$parallel_iterations, "bool":$is_stateless)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Get the condition function.
|
||||
FuncOp cond_function() {
|
||||
@ -764,8 +754,9 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stateless,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
|
||||
);
|
||||
let results = (outs Variadic<AnyTensor>:$output);
|
||||
|
||||
|
@ -2605,39 +2605,6 @@ static LogicalResult Verify(WhileOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
ArrayAttr GetShapeArrayAttrFromTypes(mlir::MLIRContext *context,
|
||||
TypeRange types) {
|
||||
SmallVector<Attribute, 4> shapes;
|
||||
shapes.reserve(types.size());
|
||||
for (Type type : types)
|
||||
shapes.push_back(ShapeAttr::get(context, type.cast<ShapedType>()));
|
||||
return ArrayAttr::get(shapes, context);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void WhileOp::build(OpBuilder &builder, OperationState &result,
|
||||
TypeRange output, ValueRange input, FlatSymbolRefAttr cond,
|
||||
FlatSymbolRefAttr body, IntegerAttr parallel_iterations,
|
||||
BoolAttr is_stateless) {
|
||||
ArrayAttr output_shapes =
|
||||
GetShapeArrayAttrFromTypes(builder.getContext(), output);
|
||||
build(builder, result, output, input, cond, body, output_shapes,
|
||||
parallel_iterations, is_stateless);
|
||||
}
|
||||
|
||||
void WhileOp::build(OpBuilder &builder, OperationState &result,
|
||||
TypeRange output, ValueRange input, StringRef cond,
|
||||
StringRef body, uint64_t parallel_iterations,
|
||||
bool is_stateless) {
|
||||
ArrayAttr output_shapes =
|
||||
GetShapeArrayAttrFromTypes(builder.getContext(), output);
|
||||
build(builder, result, output, input, cond, body, output_shapes,
|
||||
parallel_iterations, is_stateless);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileRegionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -112,7 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
|
||||
LogicalResult ConvertWhileOp(WhileOp while_op) {
|
||||
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
|
||||
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
|
||||
while_op.is_stateless(), while_op.parallel_iterations());
|
||||
while_op.output_shapes(), while_op.parallel_iterations(),
|
||||
while_op.is_stateless());
|
||||
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
|
||||
|
||||
YieldOp cond_yield =
|
||||
|
@ -398,7 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
|
||||
OpBuilder builder(while_region);
|
||||
auto while_op = builder.create<WhileOp>(
|
||||
while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
|
||||
while_region.parallel_iterations(), while_region.is_stateless());
|
||||
while_region.output_shapes(), while_region.parallel_iterations(),
|
||||
while_region.is_stateless());
|
||||
CopyDeviceAndUnderscoredAttributes(while_region, while_op);
|
||||
|
||||
// Redirect old results to new results.
|
||||
|
@ -255,7 +255,8 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
|
||||
OpBuilder& builder) {
|
||||
auto host_side_while = builder.create<TF::WhileRegionOp>(
|
||||
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
|
||||
is_stateless, parallel_iterations);
|
||||
/*output_shapes=*/builder.getArrayAttr({}), parallel_iterations,
|
||||
is_stateless);
|
||||
|
||||
// Create empty else branch region.
|
||||
auto& body = host_side_while.body();
|
||||
|
Loading…
Reference in New Issue
Block a user