[Shape Inference] Try to fold the operation before attempting to infer its shape.
If the operation can be folded to a constant, it is trivial to infer the shape afterwards. This also enables other cases that may rely on folding for the inference to be successful, e.g. tf.Shape+tf.Conv2DBackpropInput. PiperOrigin-RevId: 281904440 Change-Id: Id7374b6f20750799b97d2419aa0113d39c30ace3
This commit is contained in:
parent
33a99f926c
commit
a6fee27f74
@ -34,15 +34,29 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// Tests the case where an inference opportunity relies on folding.
|
||||
|
||||
// CHECK-LABEL: func @simple_folding
|
||||
func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
|
||||
// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
|
||||
// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[CONV]]) {{.*}} : (tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: return %[[CAST]] : tensor<?x?x?x?xf32>
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
|
||||
%1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) {
|
||||
padding = "VALID", strides = [1, 1, 1, 1]
|
||||
} : (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32>
|
||||
return %1 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// Tests the case where an op's shape function returns non-fully-defined shapes.
|
||||
|
||||
// CHECK-LABEL: func @op_non_fully_defined_shape_fn
|
||||
func @op_non_fully_defined_shape_fn() -> tensor<?xi32> {
|
||||
%0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
func @op_non_fully_defined_shape_fn(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<?xi32> {
|
||||
// CHECK: tf.BroadcastGradientArgs
|
||||
// CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
|
||||
%2:2 = "tf.BroadcastGradientArgs"(%0, %1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
|
||||
%2:2 = "tf.BroadcastGradientArgs"(%arg0, %arg1) {T = "tfdtype$DT_INT32", name = "BroadcastGradientArgs"} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<?xi32>, tensor<?xi32>)
|
||||
return %2#0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
@ -213,9 +214,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
|
||||
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
Dialect* tf_dialect = region->getContext()->getRegisteredDialect(
|
||||
TensorFlowDialect::getDialectNamespace());
|
||||
MLIRContext* ctx = region->getContext();
|
||||
Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
|
||||
|
||||
// An operation folder that is used to attempt folding before inference.
|
||||
OperationFolder folder(ctx);
|
||||
bool changed = true;
|
||||
|
||||
// TODO(aminim): we could have a more efficient traversal by guiding the
|
||||
// traversal with a worklist and reconsider only the nodes for which an
|
||||
// operand type was inferred. This would need to be careful if working on a
|
||||
@ -225,15 +230,17 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Shape inference, iteration " << iteration << "\n");
|
||||
region->walk([&](Operation* op) {
|
||||
if (op->getDialect() == tf_dialect)
|
||||
if (op->getDialect() != tf_dialect) return;
|
||||
|
||||
// Before attempting inference, just try to fold the operation.
|
||||
if (failed(folder.tryToFold(op)))
|
||||
changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
|
||||
});
|
||||
}
|
||||
if (changed) {
|
||||
region->getParentOp()->emitWarning()
|
||||
<< "Shape inference did not reach stable state after " << max_iteration
|
||||
<< " iterations";
|
||||
return failure();
|
||||
return region->getParentOp()->emitWarning()
|
||||
<< "Shape inference did not reach stable state after "
|
||||
<< max_iteration << " iterations";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user