[Shape Inference] Try to fold the operation before attempting to infer its shape.

If the operation can be folded to a constant, it is trivial to infer the shape afterwards. This also enables other cases that may rely on folding for the inference to be successful, e.g. tf.Shape + tf.Conv2DBackpropInput.

PiperOrigin-RevId: 281904440
Change-Id: Id7374b6f20750799b97d2419aa0113d39c30ace3
This commit is contained in:
River Riddle 2019-11-21 22:30:35 -08:00 committed by TensorFlower Gardener
parent 33a99f926c
commit a6fee27f74
2 changed files with 32 additions and 11 deletions

View File

@@ -34,15 +34,29 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
return %1 : tensor<*xf32> return %1 : tensor<*xf32>
} }
// Tests the case where an inference opportunity relies on folding.
// CHECK-LABEL: func @simple_folding
func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[CONV]]) {{.*}} : (tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: return %[[CAST]] : tensor<?x?x?x?xf32>
// tf.Shape of a statically-shaped operand is foldable to a constant; the pass
// is expected to fold it (the CHECK'd tf.Const above) before inference.
%0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
// With the folded constant as its sizes operand, Conv2DBackpropInput's result
// should refine from tensor<?x?x?x?xf32> to tensor<1x1x1x1xf32>; a tf.Cast
// back to the original unranked result type is expected to keep the function
// signature intact (see the CHECK lines above).
%1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) {
padding = "VALID", strides = [1, 1, 1, 1]
} : (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
// Tests the case where an op's shape function returns non-fully-defined shapes. // Tests the case where an op's shape function returns non-fully-defined shapes.
// CHECK-LABEL: func @op_non_fully_defined_shape_fn // CHECK-LABEL: func @op_non_fully_defined_shape_fn
func @op_non_fully_defined_shape_fn() -> tensor<?xi32> { func @op_non_fully_defined_shape_fn(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK: tf.BroadcastGradientArgs // CHECK: tf.BroadcastGradientArgs
// CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>) // CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
%2:2 = "tf.BroadcastGradientArgs"(%0, %1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>) %2:2 = "tf.BroadcastGradientArgs"(%arg0, %arg1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
return %2#0 : tensor<?xi32> return %2#0 : tensor<?xi32>
} }

View File

@@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -213,9 +214,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
int64_t max_iteration) { int64_t max_iteration) {
Dialect* tf_dialect = region->getContext()->getRegisteredDialect( MLIRContext* ctx = region->getContext();
TensorFlowDialect::getDialectNamespace()); Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
// An operation folder that is used to attempt folding before inference.
OperationFolder folder(ctx);
bool changed = true; bool changed = true;
// TODO(aminim): we could have a more efficient traversal by guiding the // TODO(aminim): we could have a more efficient traversal by guiding the
// traversal with a worklist and reconsider only the nodes for which an // traversal with a worklist and reconsider only the nodes for which an
// operand type was inferred. This would need to be careful if working on a // operand type was inferred. This would need to be careful if working on a
@@ -225,15 +230,17 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
<< "Shape inference, iteration " << iteration << "\n"); << "Shape inference, iteration " << iteration << "\n");
region->walk([&](Operation* op) { region->walk([&](Operation* op) {
if (op->getDialect() == tf_dialect) if (op->getDialect() != tf_dialect) return;
// Before attempting inference, just try to fold the operation.
if (failed(folder.tryToFold(op)))
changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version); changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
}); });
} }
if (changed) { if (changed) {
region->getParentOp()->emitWarning() return region->getParentOp()->emitWarning()
<< "Shape inference did not reach stable state after " << max_iteration << "Shape inference did not reach stable state after "
<< " iterations"; << max_iteration << " iterations";
return failure();
} }
return success(); return success();
} }