Add output_shapes attribute to tf.WhileRegion.

This will match the representation for tf.While in regards to output shapes. Shape inference for tf.While/tf.WhileRegion should be handled differently when `output_shapes` attribute is not empty, as tf.While/tf.WhileRegion supports dynamic shapes.

Custom builders added for tf.While are removed as both functional and region based ops have explicit `output_shapes` attributes.

PiperOrigin-RevId: 341890776
Change-Id: I92bcbec86b997ad466b771f4de17ad9fde904842
This commit is contained in:
Andy Ly 2020-11-11 12:48:45 -08:00 committed by TensorFlower Gardener
parent c77657c395
commit 2889cec62c
5 changed files with 9 additions and 48 deletions

View File

@ -702,16 +702,6 @@ body: A function that takes a list of tensors and returns another
return Verify(*this);
}];
let builders = [
OpBuilderDAG<(ins "TypeRange":$output, "ValueRange":$input,
"FlatSymbolRefAttr":$cond, "FlatSymbolRefAttr":$body,
"IntegerAttr":$parallel_iterations,
"BoolAttr":$is_stateless)>,
OpBuilderDAG<(ins "TypeRange":$output, "ValueRange":$input,
"StringRef":$cond, "StringRef":$body,
"uint64_t":$parallel_iterations, "bool":$is_stateless)>
];
let extraClassDeclaration = [{
// Get the condition function.
FuncOp cond_function() {
@ -764,8 +754,9 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
);
let results = (outs Variadic<AnyTensor>:$output);

View File

@ -2605,39 +2605,6 @@ static LogicalResult Verify(WhileOp op) {
return success();
}
namespace {
ArrayAttr GetShapeArrayAttrFromTypes(mlir::MLIRContext *context,
TypeRange types) {
SmallVector<Attribute, 4> shapes;
shapes.reserve(types.size());
for (Type type : types)
shapes.push_back(ShapeAttr::get(context, type.cast<ShapedType>()));
return ArrayAttr::get(shapes, context);
}
} // namespace
void WhileOp::build(OpBuilder &builder, OperationState &result,
TypeRange output, ValueRange input, FlatSymbolRefAttr cond,
FlatSymbolRefAttr body, IntegerAttr parallel_iterations,
BoolAttr is_stateless) {
ArrayAttr output_shapes =
GetShapeArrayAttrFromTypes(builder.getContext(), output);
build(builder, result, output, input, cond, body, output_shapes,
parallel_iterations, is_stateless);
}
void WhileOp::build(OpBuilder &builder, OperationState &result,
TypeRange output, ValueRange input, StringRef cond,
StringRef body, uint64_t parallel_iterations,
bool is_stateless) {
ArrayAttr output_shapes =
GetShapeArrayAttrFromTypes(builder.getContext(), output);
build(builder, result, output, input, cond, body, output_shapes,
parallel_iterations, is_stateless);
}
//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//

View File

@ -112,7 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
LogicalResult ConvertWhileOp(WhileOp while_op) {
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
while_op.is_stateless(), while_op.parallel_iterations());
while_op.output_shapes(), while_op.parallel_iterations(),
while_op.is_stateless());
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
YieldOp cond_yield =

View File

@ -398,7 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
OpBuilder builder(while_region);
auto while_op = builder.create<WhileOp>(
while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
while_region.parallel_iterations(), while_region.is_stateless());
while_region.output_shapes(), while_region.parallel_iterations(),
while_region.is_stateless());
CopyDeviceAndUnderscoredAttributes(while_region, while_op);
// Redirect old results to new results.

View File

@ -255,7 +255,8 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
OpBuilder& builder) {
auto host_side_while = builder.create<TF::WhileRegionOp>(
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
is_stateless, parallel_iterations);
/*output_shapes=*/builder.getArrayAttr({}), parallel_iterations,
is_stateless);
// Create empty else branch region.
auto& body = host_side_while.body();