Refine tf.const in TF shape inference.
PiperOrigin-RevId: 307726788 Change-Id: I7bb1ede57d9c27b191078f7533fad5975f1e713d
This commit is contained in:
parent
a3f1c2a668
commit
ab0cbb3cc0
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color
|
||||
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
|
||||
@ -71,6 +71,15 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
return %1 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// Tests where tf.Const's value needs to be refined.
|
||||
|
||||
func @const_refine() -> tensor<*xi32> {
|
||||
%0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<*xi32>
|
||||
// CHECK: "tf.Const"
|
||||
// CHECK-SAME: -> tensor<2xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// Tests the case where an op's shape function returns non-fully-defined shapes.
|
||||
|
||||
// CHECK-LABEL: func @op_non_fully_defined_shape_fn
|
||||
|
@ -274,6 +274,15 @@ bool InferShapeForCall(Operation* op) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RefineTfConst(TF::ConstOp const_op) {
|
||||
Type old_type = const_op.getType();
|
||||
if (const_op.valueAttr().getType() == old_type) return false;
|
||||
const_op.getResult().setType(const_op.valueAttr().getType());
|
||||
AddCastBackForUnsupportedNonTFUses(const_op, const_op.getResult(),
|
||||
const_op.getDialect(), old_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
@ -622,6 +631,13 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto tf_const = dyn_cast<TF::ConstOp>(op)) {
|
||||
changed |= RefineTfConst(tf_const);
|
||||
// TODO(jpienaar): Debug why we can't just return here. We end up with
|
||||
// additional constant due to the propagation of constant into attached
|
||||
// function if we return already.
|
||||
}
|
||||
|
||||
// Before attempting inference, just try to fold the operation.
|
||||
if (succeeded(folder.tryToFold(op))) return;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user