Fix shape inference to use the result of the SSA value inferred from folding
The shape inference folding can infer that `pow(%a, 1.0) -> %a`, even when the constant value of %a is unknown. However in such cases the inference would not proceed forward and let the result as-is even when %a is typed. PiperOrigin-RevId: 355990740 Change-Id: If21d298f50b2410b9b78d5be973f261e8873e5e1
This commit is contained in:
parent
8527927a7b
commit
1fcaed87f6
@ -991,17 +991,4 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
return %0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @InferFromValueFolding
|
||||
func @InferFromValueFolding(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<*xf32> {
|
||||
%cst1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%mul = "tf.Mul"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// Folding will infer that: Pow(%mul, 1.0) -> %mul
|
||||
// However we don't have the actual value for the mul, but we can use the
|
||||
// mul type!
|
||||
// CHECK: tf.Pow
|
||||
// CHECK-SAME: -> tensor<f32>
|
||||
%pow = "tf.Pow"(%mul, %cst1) : (tensor<f32>, tensor<f32>) -> tensor<*xf32>
|
||||
return %pow : tensor<*xf32>
|
||||
}
|
||||
}
|
||||
|
@ -1415,18 +1415,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
|
||||
auto fold_result = std::get<1>(result);
|
||||
Attribute attr = nullptr;
|
||||
if ((attr = fold_result.dyn_cast<Attribute>())) {
|
||||
DCOMMENT("\t\t- Attr Result: " << attr);
|
||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
||||
} else {
|
||||
auto value = fold_result.get<Value>();
|
||||
if ((attr = ComputeOutputComponent(ValuePort(value)))) {
|
||||
DCOMMENT("\t\tValue Result mapped to " << attr);
|
||||
if ((attr = ComputeOutputComponent(ValuePort(value))))
|
||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
||||
} else {
|
||||
DCOMMENT("\t\tValue result unmapped, use value type:" << value);
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(value.getType(),
|
||||
std::get<0>(result));
|
||||
}
|
||||
}
|
||||
|
||||
if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
|
||||
@ -1536,11 +1529,7 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||
|
||||
// Before attempting inference, just try to compute the folded
|
||||
// value/shape.
|
||||
if (succeeded(TryToFold(op)) &&
|
||||
// Folding can "succeed" and yet not all types be refined. In such
|
||||
// cases we still want to give a try at `InferShapeForSingleOperation`
|
||||
none_of(op->getResultTypes(), CanBeRefined))
|
||||
return;
|
||||
if (succeeded(TryToFold(op))) return;
|
||||
|
||||
// Best-effort shape inference in attached functions. Do not return
|
||||
// failure even if it doesn't get to fixed point.
|
||||
|
Loading…
x
Reference in New Issue
Block a user