Refine tf.const in TF shape inference.

PiperOrigin-RevId: 307726788
Change-Id: I7bb1ede57d9c27b191078f7533fad5975f1e713d
This commit is contained in:
Jacques Pienaar 2020-04-21 19:07:51 -07:00 committed by TensorFlower Gardener
parent a3f1c2a668
commit ab0cbb3cc0
2 changed files with 26 additions and 1 deletions

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
@ -71,6 +71,15 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
return %1 : tensor<?x?x?x?xf32>
}
// Tests where tf.Const's value needs to be refined.
func @const_refine() -> tensor<*xi32> {
%0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<*xi32>
// CHECK: "tf.Const"
// CHECK-SAME: -> tensor<2xi32>
return %0 : tensor<*xi32>
}
// Tests the case where an op's shape function returns non-fully-defined shapes.
// CHECK-LABEL: func @op_non_fully_defined_shape_fn

View File

@ -274,6 +274,15 @@ bool InferShapeForCall(Operation* op) {
return changed;
}
bool RefineTfConst(TF::ConstOp const_op) {
Type old_type = const_op.getType();
if (const_op.valueAttr().getType() == old_type) return false;
const_op.getResult().setType(const_op.valueAttr().getType());
AddCastBackForUnsupportedNonTFUses(const_op, const_op.getResult(),
const_op.getDialect(), old_type);
return true;
}
} // namespace
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
@ -622,6 +631,13 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
return;
}
if (auto tf_const = dyn_cast<TF::ConstOp>(op)) {
changed |= RefineTfConst(tf_const);
// TODO(jpienaar): Debug why we can't just return here. We end up with
// additional constant due to the propagation of constant into attached
// function if we return already.
}
// Before attempting inference, just try to fold the operation.
if (succeeded(folder.tryToFold(op))) return;