Update Graph -> TF MLIR importer to always import tf.While output shapes.

The output_shapes attribute is reverted back to being a derived attribute, and output shapes are populated as result types on import. A new attribute, shape_invariant, is added to tf.While/tf.WhileRegion, indicating whether output shapes were originally present on the While op. This attribute will be used to update shape inference for tf.While/tf.WhileRegion: when it is set, results may have different shapes from the operands, so operand shapes cannot simply be propagated as result shapes.
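For illustration, a minimal MLIR sketch (adapted from the tests in this change) of the two behaviors after import. The first op was imported without output shapes, so its result shapes mirror its operand shapes; the second was imported from a While carrying a non-empty output_shapes attribute (a scalar shape and an unknown-rank shape), so it is marked shape_invariant and keeps an unranked result type that differs from its tensor<5xf32> operand:

  // No output_shapes on the original node: operand shapes become result shapes.
  %0:2 = "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false}
      : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>)

  // Non-empty output_shapes on the original node: shape_invariant is set and the
  // imported result types are preserved as is.
  %1:2 = "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant}
      : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<*xf32>)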

PiperOrigin-RevId: 343546219
Change-Id: Ib7e832f11edabd59d0f3b45cb0d31c6b6187e706
Andy Ly 2020-11-20 12:44:49 -08:00 committed by TensorFlower Gardener
parent 70802c2d5b
commit 7fa0d80d06
8 changed files with 142 additions and 18 deletions


@@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another

     FlatSymbolRefAttr:$cond,
     FlatSymbolRefAttr:$body,
-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
     DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,

     // Used to map StatelessWhile and While op defined in TensorFlow to a common
     // op.
-    BoolAttr:$is_stateless
+    BoolAttr:$is_stateless,
+
+    // In TensorFlow, While has a special behavior where if `output_shapes`
+    // attribute is not empty, those shapes are used in its shape function
+    // as result shapes instead of propagating operand shapes as result shapes.
+    // This allows for different result shapes from operand shapes. While these
+    // shapes are imported and set as a part of the result type, there is no
+    // indicator differentiating between having no output shapes compared to
+    // having all unranked shapes. Thus this attribute is set to determine
+    // which shape function behavior to use for this op, specifically
+    // propagating operand shapes as result shapes when this attribute is not
+    // set, or preserving result shapes as is when this attribute is set.
+    UnitAttr:$shape_invariant
   );

   let results = (outs
@@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
   );

   TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

   let verifier = [{
     return Verify(*this);
@@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
   let arguments = (ins
     Variadic<AnyTensor>:$input,

-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
     DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,

     // Used to map StatelessWhile and While op defined in TensorFlow to a common
     // op.
-    BoolAttr:$is_stateless
+    BoolAttr:$is_stateless,
+
+    // In TensorFlow, While has a special behavior where if `output_shapes`
+    // attribute is not empty, those shapes are used in its shape function
+    // as result shapes instead of propagating operand shapes as result shapes.
+    // This allows for different result shapes from operand shapes. While these
+    // shapes are imported and set as a part of the result type, there is no
+    // indicator differentiating between having no output shapes compared to
+    // having all unranked shapes. Thus this attribute is set to determine
+    // which shape function behavior to use for this op, specifically
+    // propagating operand shapes as result shapes when this attribute is not
+    // set, or preserving result shapes as is when this attribute is set.
+    UnitAttr:$shape_invariant
   );

   let results = (outs Variadic<AnyTensor>:$output);


@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s

 # Verify that TensorFlow While and StatelessWhile ops are mapped to the
 # composite While op in MLIR with is_stateless attribute set accordingly to
@@ -6,6 +6,7 @@
 # CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
 # CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")

 node {
   name: "StatefulWhile"
@@ -73,6 +74,51 @@ node {
   experimental_debug_info {
   }
 }
+node {
+  name: "WhileWithOutputShapes"
+  op: "While"
+  input: "iter"
+  input: "val"
+  attr {
+    key: "T"
+    value {
+      list {
+        type: DT_INT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "body"
+    value {
+      func {
+        name: "body"
+      }
+    }
+  }
+  attr {
+    key: "cond"
+    value {
+      func {
+        name: "cond"
+      }
+    }
+  }
+  attr {
+    key: "output_shapes"
+    value {
+      list {
+        shape {
+        }
+        shape {
+          unknown_rank: true
+        }
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
 node {
   name: "main"
   op: "_Retval"
@@ -107,6 +153,23 @@ node {
     }
   }
 }
+node {
+  name: "main2"
+  op: "_Retval"
+  input: "WhileWithOutputShapes:1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 2
+    }
+  }
+}
 node {
   name: "iter"
   op: "Placeholder"


@@ -1,12 +1,13 @@
 // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s

-func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  %0:2 = tf_executor.graph {
-    %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
-    %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
-    tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
+func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
+  %0:3 = tf_executor.graph {
+    %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
+    %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
+    %outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
+    tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
   }
-  return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
+  return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
 }

 func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
@@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK-NOT: name:
 // CHECK: op: "While"
 // CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
 // CHECK: attr {
 // CHECK: key: "output_shapes"
 // CHECK: value {
@@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK-NOT: name:
 // CHECK: op: "StatelessWhile"
 // CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
 // CHECK: attr {
 // CHECK: key: "output_shapes"
 // CHECK: value {
@@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK: }
 // CHECK: }

+// CHECK: name: "WhileWithOutputShapes"
+// CHECK-NOT: name:
+// CHECK: op: "While"
+// CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
+// CHECK: attr {
+// CHECK: key: "output_shapes"
+// CHECK: value {
+// CHECK: list {
+// CHECK: shape {
+// CHECK: dim {
+// CHECK: size: 5
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }


@@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
 LogicalResult ConvertWhileOp(WhileOp while_op) {
   auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
       while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
-      while_op.output_shapes(), while_op.parallel_iterations(),
-      while_op.is_stateless());
+      while_op.parallel_iterations(), while_op.is_stateless(),
+      while_op.shape_invariant());
   CopyDeviceAndUnderscoredAttributes(while_op, while_region);

   YieldOp cond_yield =


@@ -398,8 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
   OpBuilder builder(while_region);
   auto while_op = builder.create<WhileOp>(
       while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
-      while_region.output_shapes(), while_region.parallel_iterations(),
-      while_region.is_stateless());
+      while_region.parallel_iterations(), while_region.is_stateless(),
+      while_region.shape_invariant());
   CopyDeviceAndUnderscoredAttributes(while_region, while_op);

   // Redirect old results to new results.


@@ -255,8 +255,7 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
                                   OpBuilder& builder) {
   auto host_side_while = builder.create<TF::WhileRegionOp>(
       loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
-      /*output_shapes=*/builder.getArrayAttr({}), parallel_iterations,
-      is_stateless);
+      parallel_iterations, is_stateless, /*shape_invariant=*/false);

   // Create empty else branch region.
   auto& body = host_side_while.body();


@@ -155,6 +155,9 @@ StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
   if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
     attrs_to_ignore.insert("is_stateless");

+  if (llvm::isa<mlir::TF::WhileOp>(inst))
+    attrs_to_ignore.insert("shape_invariant");
+
   return attrs_to_ignore;
 }


@@ -971,6 +971,16 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
         etype.getContext()));
   }

+  if (node.IsWhileNode()) {
+    auto* output_shapes = node.attrs().Find("output_shapes");
+    auto* element_types = node.attrs().Find("T");
+    if (output_shapes && !output_shapes->list().shape().empty()) {
+      const auto& output_shape = output_shapes->list().shape(idx);
+      const auto& element_type = element_types->list().type(idx);
+      return ConvertToMlirTensorType(output_shape, element_type, &builder);
+    }
+  }
+
   // Returns a simple, more conservative unranked tensor type.
   auto default_type = [&]() -> StatusOr<mlir::Type> {
     mlir::Type element_type;
@@ -1907,7 +1917,13 @@ Status ImporterBase::ConvertNode(const Node& node) {
   // Case/If/While op in MLIR and add the differentiating attribute.
   if (node.IsCaseNode()) composite_control_flow_op("Case");
   if (node.IsIfNode()) composite_control_flow_op("If");
-  if (node.IsWhileNode()) composite_control_flow_op("While");
+  if (node.IsWhileNode()) {
+    composite_control_flow_op("While");
+    auto* output_shapes = node.attrs().Find("output_shapes");
+    if (output_shapes && !output_shapes->list().shape().empty())
+      result.attributes.push_back(
+          builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
+  }

   // Register the mapping between the TF node and the newly created operation.
   node_values_[node.id()] =