Fix shape inference to use the result of the SSA value inferred from folding

The shape inference folding can infer that `pow(%a, 1.0) -> %a`, even when
the constant value of %a is unknown. However in such cases the inference
would not proceed forward and let the result as-is even when %a is typed.

PiperOrigin-RevId: 355990740
Change-Id: If21d298f50b2410b9b78d5be973f261e8873e5e1
This commit is contained in:
A. Unique TensorFlower 2021-02-05 22:42:35 -08:00 committed by TensorFlower Gardener
parent 8527927a7b
commit 1fcaed87f6
2 changed files with 2 additions and 26 deletions

View File

@ -991,17 +991,4 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
%0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
return %0 : tensor<*x!tf.resource<tensor<f32>>>
}
// CHECK-LABEL: func @InferFromValueFolding
func @InferFromValueFolding(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<*xf32> {
%cst1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%mul = "tf.Mul"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// Folding will infer that: Pow(%mul, 1.0) -> %mul
// However we don't have the actual value for the mul, but we can use the
// mul type!
// CHECK: tf.Pow
// CHECK-SAME: -> tensor<f32>
%pow = "tf.Pow"(%mul, %cst1) : (tensor<f32>, tensor<f32>) -> tensor<*xf32>
return %pow : tensor<*xf32>
}
}

View File

@ -1415,18 +1415,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
auto fold_result = std::get<1>(result);
Attribute attr = nullptr;
if ((attr = fold_result.dyn_cast<Attribute>())) {
DCOMMENT("\t\t- Attr Result: " << attr);
RecordValue(ValuePort(std::get<0>(result)), attr);
} else {
auto value = fold_result.get<Value>();
if ((attr = ComputeOutputComponent(ValuePort(value)))) {
DCOMMENT("\t\tValue Result mapped to " << attr);
if ((attr = ComputeOutputComponent(ValuePort(value))))
RecordValue(ValuePort(std::get<0>(result)), attr);
} else {
DCOMMENT("\t\tValue result unmapped, use value type:" << value);
UpdateTypeAndInsertIncompatibleUseCasts(value.getType(),
std::get<0>(result));
}
}
if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
@ -1536,11 +1529,7 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
// Before attempting inference, just try to compute the folded
// value/shape.
if (succeeded(TryToFold(op)) &&
// Folding can "succeed" and yet not all types be refined. In such
// cases we still want to give a try at `InferShapeForSingleOperation`
none_of(op->getResultTypes(), CanBeRefined))
return;
if (succeeded(TryToFold(op))) return;
// Best-effort shape inference in attached functions. Do not return
// failure even if it doesn't get to fixed point.