diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 757df9db790..94f626c7b29 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color
+// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail
 
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
 // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
@@ -71,6 +71,15 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> {
     return %1 : tensor
   }
 
+// Tests where tf.Const's value needs to be refined.
+
+  func @const_refine() -> tensor<*xi32> {
+    %0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<*xi32>
+    // CHECK: "tf.Const"
+    // CHECK-SAME: -> tensor<2xi32>
+    return %0 : tensor<*xi32>
+  }
+
 // Tests the case where an op's shape function returns non-fully-defined shapes.
// CHECK-LABEL: func @op_non_fully_defined_shape_fn
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 53e15b0b609..c2e21d2c59f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -274,6 +274,15 @@ bool InferShapeForCall(Operation* op) {
   return changed;
 }
 
+bool RefineTfConst(TF::ConstOp const_op) {
+  Type old_type = const_op.getType();
+  if (const_op.valueAttr().getType() == old_type) return false;
+  const_op.getResult().setType(const_op.valueAttr().getType());
+  AddCastBackForUnsupportedNonTFUses(const_op, const_op.getResult(),
+                                     const_op.getDialect(), old_type);
+  return true;
+}
+
 }  // namespace
 
 bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
@@ -622,6 +631,13 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
       return;
     }
 
+    if (auto tf_const = dyn_cast<TF::ConstOp>(op)) {
+      changed |= RefineTfConst(tf_const);
+      // TODO(jpienaar): Debug why we can't just return here. We end up with
+      // additional constant due to the propagation of constant into attached
+      // function if we return already.
+    }
+
     // Before attempting inference, just try to fold the operation.
     if (succeeded(folder.tryToFold(op))) return;