From a6fee27f745b89d3c1e2173215835f7e4ca95ac1 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 21 Nov 2019 22:30:35 -0800 Subject: [PATCH] [Shape Inference] Try to fold the operation before attempting to infer its shape. If the operation can be folded to a constant, it is trivial to infer the shape after wards. This also enables other cases that may rely on folding for the inference to be successful, e.g. tf.Shape+tf.Conv2DBackpropInput. PiperOrigin-RevId: 281904440 Change-Id: Id7374b6f20750799b97d2419aa0113d39c30ace3 --- .../tensorflow/tests/shape_inference.mlir | 22 +++++++++++++++---- .../tensorflow/transforms/shape_inference.cc | 21 ++++++++++++------ 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 0b02ac2b5bb..acf236f8e1f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -34,15 +34,29 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %1 : tensor<*xf32> } +// Tests the case where an inference opportunity relies on folding. 
+ +// CHECK-LABEL: func @simple_folding + func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> { +// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]] +// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[CONV]]) {{.*}} : (tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> +// CHECK: return %[[CAST]] : tensor<?x?x?x?xf32> + %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32> + %1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) { + padding = "VALID", strides = [1, 1, 1, 1] + } : (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> + return %1 : tensor<?x?x?x?xf32> + } + // Tests the case where an op's shape function returns non-fully-defined shapes. // CHECK-LABEL: func @op_non_fully_defined_shape_fn - func @op_non_fully_defined_shape_fn() -> tensor<?xi32> { - %0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - %1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + func @op_non_fully_defined_shape_fn(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<?xi32> { // CHECK: tf.BroadcastGradientArgs // CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>) - %2:2 = "tf.BroadcastGradientArgs"(%0, %1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>) + %2:2 = "tf.BroadcastGradientArgs"(%arg0, %arg1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>) return %2#0 : tensor<?xi32> } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 9a7178178ac..c44c81d1cef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -213,9 +214,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, int64_t max_iteration) { - Dialect* tf_dialect = region->getContext()->getRegisteredDialect( TensorFlowDialect::getDialectNamespace()); + MLIRContext* ctx = region->getContext(); + Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>(); + + // An operation folder that is used to attempt folding before inference. + OperationFolder folder(ctx); bool changed = true; + // TODO(aminim): we could have a more efficient traversal by guiding the // traversal with a worklist and reconsider only the nodes for which an // operand type was inferred. This would need to be careful if working on a @@ -225,15 +230,17 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LLVM_DEBUG(llvm::dbgs() << "Shape inference, iteration " << iteration << "\n"); region->walk([&](Operation* op) { - if (op->getDialect() == tf_dialect) + if (op->getDialect() != tf_dialect) return; + + // Before attempting inference, just try to fold the operation. 
+ if (failed(folder.tryToFold(op))) changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version); }); } if (changed) { - region->getParentOp()->emitWarning() - << "Shape inference did not reach stable state after " << max_iteration - << " iterations"; - return failure(); + return region->getParentOp()->emitWarning() + << "Shape inference did not reach stable state after " + << max_iteration << " iterations"; } return success(); }