Update Graph -> TF MLIR importer to always import tf.While output shapes.
output_shapes attribute is reverted back to be a derived attribute and output shapes are populated as result types on import. A new attribute, shape_invariant, is added to tf.While/tf.WhileRegion indicating whether shapes were originally present on the While op. This attribute will be used to update the shape inference for tf.While/tf.WhileRegion, where it is possible to have different result shapes and operand shapes cannot simply be propagated as result shapes. PiperOrigin-RevId: 343546219 Change-Id: Ib7e832f11edabd59d0f3b45cb0d31c6b6187e706
This commit is contained in:
parent
70802c2d5b
commit
7fa0d80d06
@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another
|
||||
|
||||
FlatSymbolRefAttr:$cond,
|
||||
FlatSymbolRefAttr:$body,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
BoolAttr:$is_stateless
|
||||
BoolAttr:$is_stateless,
|
||||
|
||||
// In TensorFlow, While has a special behavior where if `output_shapes`
|
||||
// attribute is not empty, those shapes are used in its shape function
|
||||
// as result shapes instead of propagating operand shapes as result shapes.
|
||||
// This allows for different result shapes from operand shapes. While these
|
||||
// shapes are imported and set as a part of the result type, there is no
|
||||
// indicator differentiating between having no output shapes compared to
|
||||
// having all unranked shapes. Thus this attribute is set to determine
|
||||
// which shape function behavior to use for this op, specifically
|
||||
// propagating operand shapes as result shapes when this attribute is not
|
||||
// set, or preserving result shapes as is when this attribute is set.
|
||||
UnitAttr:$shape_invariant
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
|
||||
let arguments = (ins
|
||||
Variadic<AnyTensor>:$input,
|
||||
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
BoolAttr:$is_stateless
|
||||
BoolAttr:$is_stateless,
|
||||
|
||||
// In TensorFlow, While has a special behavior where if `output_shapes`
|
||||
// attribute is not empty, those shapes are used in its shape function
|
||||
// as result shapes instead of propagating operand shapes as result shapes.
|
||||
// This allows for different result shapes from operand shapes. While these
|
||||
// shapes are imported and set as a part of the result type, there is no
|
||||
// indicator differentiating between having no output shapes compared to
|
||||
// having all unranked shapes. Thus this attribute is set to determine
|
||||
// which shape function behavior to use for this op, specifically
|
||||
// propagating operand shapes as result shapes when this attribute is not
|
||||
// set, or preserving result shapes as is when this attribute is set.
|
||||
UnitAttr:$shape_invariant
|
||||
);
|
||||
let results = (outs Variadic<AnyTensor>:$output);
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
|
||||
|
||||
# Verify that TensorFlow While and StatelessWhile ops are mapped to the
|
||||
# composite While op in MLIR with is_stateless attribute set accordingly to
|
||||
@ -6,6 +6,7 @@
|
||||
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")
|
||||
|
||||
node {
|
||||
name: "StatefulWhile"
|
||||
@ -73,6 +74,51 @@ node {
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "WhileWithOutputShapes"
|
||||
op: "While"
|
||||
input: "iter"
|
||||
input: "val"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "body"
|
||||
value {
|
||||
func {
|
||||
name: "body"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "cond"
|
||||
value {
|
||||
func {
|
||||
name: "cond"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main"
|
||||
op: "_Retval"
|
||||
@ -107,6 +153,23 @@ node {
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main2"
|
||||
op: "_Retval"
|
||||
input: "WhileWithOutputShapes:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "iter"
|
||||
op: "Placeholder"
|
||||
|
@ -1,12 +1,13 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
|
||||
%0:2 = tf_executor.graph {
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
|
||||
%0:3 = tf_executor.graph {
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
%outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
|
||||
return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
|
||||
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
|
||||
@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "While"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK-NOT: shape_invariant
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "StatelessWhile"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK-NOT: shape_invariant
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: name: "WhileWithOutputShapes"
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "While"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK-NOT: shape_invariant
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
// CHECK: list {
|
||||
// CHECK: shape {
|
||||
// CHECK: dim {
|
||||
// CHECK: size: 5
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
|
||||
LogicalResult ConvertWhileOp(WhileOp while_op) {
|
||||
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
|
||||
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
|
||||
while_op.output_shapes(), while_op.parallel_iterations(),
|
||||
while_op.is_stateless());
|
||||
while_op.parallel_iterations(), while_op.is_stateless(),
|
||||
while_op.shape_invariant());
|
||||
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
|
||||
|
||||
YieldOp cond_yield =
|
||||
|
@ -398,8 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
|
||||
OpBuilder builder(while_region);
|
||||
auto while_op = builder.create<WhileOp>(
|
||||
while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
|
||||
while_region.output_shapes(), while_region.parallel_iterations(),
|
||||
while_region.is_stateless());
|
||||
while_region.parallel_iterations(), while_region.is_stateless(),
|
||||
while_region.shape_invariant());
|
||||
CopyDeviceAndUnderscoredAttributes(while_region, while_op);
|
||||
|
||||
// Redirect old results to new results.
|
||||
|
@ -255,8 +255,7 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
|
||||
OpBuilder& builder) {
|
||||
auto host_side_while = builder.create<TF::WhileRegionOp>(
|
||||
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
|
||||
/*output_shapes=*/builder.getArrayAttr({}), parallel_iterations,
|
||||
is_stateless);
|
||||
parallel_iterations, is_stateless, /*shape_invariant=*/false);
|
||||
|
||||
// Create empty else branch region.
|
||||
auto& body = host_side_while.body();
|
||||
|
@ -155,6 +155,9 @@ StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
|
||||
if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
|
||||
attrs_to_ignore.insert("is_stateless");
|
||||
|
||||
if (llvm::isa<mlir::TF::WhileOp>(inst))
|
||||
attrs_to_ignore.insert("shape_invariant");
|
||||
|
||||
return attrs_to_ignore;
|
||||
}
|
||||
|
||||
|
@ -971,6 +971,16 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
|
||||
etype.getContext()));
|
||||
}
|
||||
|
||||
if (node.IsWhileNode()) {
|
||||
auto* output_shapes = node.attrs().Find("output_shapes");
|
||||
auto* element_types = node.attrs().Find("T");
|
||||
if (output_shapes && !output_shapes->list().shape().empty()) {
|
||||
const auto& output_shape = output_shapes->list().shape(idx);
|
||||
const auto& element_type = element_types->list().type(idx);
|
||||
return ConvertToMlirTensorType(output_shape, element_type, &builder);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a simple, more conservative unranked tensor type.
|
||||
auto default_type = [&]() -> StatusOr<mlir::Type> {
|
||||
mlir::Type element_type;
|
||||
@ -1907,7 +1917,13 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
||||
// Case/If/While op in MLIR and add the differentiating attribute.
|
||||
if (node.IsCaseNode()) composite_control_flow_op("Case");
|
||||
if (node.IsIfNode()) composite_control_flow_op("If");
|
||||
if (node.IsWhileNode()) composite_control_flow_op("While");
|
||||
if (node.IsWhileNode()) {
|
||||
composite_control_flow_op("While");
|
||||
auto* output_shapes = node.attrs().Find("output_shapes");
|
||||
if (output_shapes && !output_shapes->list().shape().empty())
|
||||
result.attributes.push_back(
|
||||
builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
|
||||
}
|
||||
|
||||
// Register the mapping between the TF node and the newly created operation.
|
||||
node_values_[node.id()] =
|
||||
|
Loading…
Reference in New Issue
Block a user