diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 0b02ac2b5bb..acf236f8e1f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -34,15 +34,29 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     return %1 : tensor<*xf32>
   }
 
+// Tests the case where an inference opportunity relies on folding.
+
+// CHECK-LABEL: func @simple_folding
+  func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
+// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[CONV]]) {{.*}} : (tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: return %[[CAST]] : tensor<?x?x?x?xf32>
+    %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
+    %1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) {
+      padding = "VALID", strides = [1, 1, 1, 1]
+    } : (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32>
+    return %1 : tensor<?x?x?x?xf32>
+  }
+
 // Tests the case where an op's shape function returns non-fully-defined shapes.
 
 // CHECK-LABEL: func @op_non_fully_defined_shape_fn
-  func @op_non_fully_defined_shape_fn() -> tensor<?xi32> {
-    %0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
-    %1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  func @op_non_fully_defined_shape_fn(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<?xi32> {
 // CHECK: tf.BroadcastGradientArgs
 // CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
-    %2:2 = "tf.BroadcastGradientArgs"(%0, %1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
+    %2:2 = "tf.BroadcastGradientArgs"(%arg0, %arg1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
     return %2#0 : tensor<?xi32>
   }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 9a7178178ac..c44c81d1cef 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "mlir/Transforms/FoldUtils.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -213,9 +214,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
 
 LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
                                       int64_t max_iteration) {
-  Dialect* tf_dialect = region->getContext()->getRegisteredDialect(
-      TensorFlowDialect::getDialectNamespace());
+  MLIRContext* ctx = region->getContext();
+  Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
+
+  // An operation folder that is used to attempt folding before inference.
+  OperationFolder folder(ctx);
   bool changed = true;
+
   // TODO(aminim): we could have a more efficient traversal by guiding the
   // traversal with a worklist and reconsider only the nodes for which an
   // operand type was inferred. This would need to be careful if working on a
@@ -225,15 +230,17 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
     LLVM_DEBUG(llvm::dbgs() << "Shape inference, iteration " << iteration
                             << "\n");
     region->walk([&](Operation* op) {
-      if (op->getDialect() == tf_dialect)
+      if (op->getDialect() != tf_dialect) return;
+
+      // Before attempting inference, just try to fold the operation.
+      if (failed(folder.tryToFold(op)))
         changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
     });
   }
   if (changed) {
-    region->getParentOp()->emitWarning()
-        << "Shape inference did not reach stable state after " << max_iteration
-        << " iterations";
-    return failure();
+    return region->getParentOp()->emitWarning()
+           << "Shape inference did not reach stable state after "
+           << max_iteration << " iterations";
   }
   return success();
 }
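
Note on the interaction the new test exercises: "tf.Conv2DBackpropInput" can only
produce a fully defined result type once its first operand (the requested output
sizes) is a known constant, and that constant only materializes after the
"tf.Shape" feeding it is folded. A minimal before/after sketch in the test's own
dialect, assuming the same operand shapes as @simple_folding above:

  // Before folding: the output-sizes operand is an opaque tf.Shape result,
  // so the shape function cannot refine the result past tensor<?x?x?x?xf32>.
  %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>

  // After OperationFolder::tryToFold replaces %0 with a constant, the shape
  // function can read the sizes and inference refines the result type.
  %cst = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
  %1 = "tf.Conv2DBackpropInput"(%cst, %arg1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]}
    : (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>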