Avoid refining return types when used with unknown callee constraints
The return type of function cannot be refined without considering the callee of the function. The difference is fine for TF dialect ops, but not for std.call (for example). Given this currently doesn't use a worklist across the Module and the pass refines functions in single loop, don't update return types unless known safe. Also use "can refine" in checking if can fold (rather than static type check) and enable refining shape even where element type not castable. This does incur an additional uses lookup per function refinement, which is not ideal and could be avoided by moving to worklist & operating on module internally instead. PiperOrigin-RevId: 339870845 Change-Id: Ie9fdafa0e961dbf66878fd9767d5a3685fbd00f6
This commit is contained in:
parent
dcd0f24105
commit
dfa308ca5d
@ -440,6 +440,22 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
return %arg0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// Test not updating call site if a std.call is used.
|
||||
// CHECK-LABEL: func @call_partitioned_call2(
|
||||
// CHECK-SAME: -> tensor<*xi32>
|
||||
func @call_partitioned_call2() -> tensor<*xi32> {
|
||||
// CHECK: () -> tensor<*xi32>
|
||||
%0 = call @partitioned_called_func2() : () -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
// CHECK-LABEL: func @partitioned_called_func2(
|
||||
// CHECK-SAME: -> tensor<*xi32>
|
||||
func @partitioned_called_func2() -> (tensor<*xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = tensor_cast %0 : tensor<1xi32> to tensor<*xi32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_list_refine
|
||||
func @tensor_list_refine() {
|
||||
tf_executor.graph {
|
||||
@ -501,16 +517,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// CHECK-LABEL: cast_at_end(%arg0:
|
||||
// CHECK-SAME: tensor<16x194x199x4xui8>, tensor<16x194x199x4xi8>, tensor<*xi8>
|
||||
func @cast_at_end(%arg0: tensor<16x194x199x4xf32>, %arg1: tensor<16x194x199x4xi8>) -> (tensor<*xui8>, tensor<*xi8>, tensor<*xi8>) {
|
||||
// CHECK: %[[CAST_RESULT_0:.*]] = "tf.Cast"(%arg0)
|
||||
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<16x194x199x4xui8>
|
||||
%27 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<16x194x199x4xf32>) -> tensor<*xui8>
|
||||
// CHECK: %[[CAST_RESULT_1:.*]] = "tf.Cast"(%arg0)
|
||||
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<16x194x199x4xi8>
|
||||
// CHECK: %[[CAST_RESULT_2:.*]] = "tf.Cast"(%[[CAST_RESULT_1]])
|
||||
// CHECK-SAME: (tensor<16x194x199x4xi8>) -> tensor<*xi8>
|
||||
%28 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<16x194x199x4xf32>) -> tensor<*xi8>
|
||||
// CHECK: %[[CAST_RESULT_2:.*]] = "tf.Cast"(%arg0)
|
||||
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<*xi8>
|
||||
// CHECK: %[[ADDI:.*]] = addi %[[CAST_RESULT_2]], %[[CAST_RESULT_2]]
|
||||
%2 = addi %28, %28 : tensor<*xi8>
|
||||
// CHECK: %[[CAST_RESULT_0:.*]] = "tf.Cast"(%arg0)
|
||||
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<16x194x199x4xui8>
|
||||
// CHECK: %[[CAST_RESULT_1:.*]] = "tf.Cast"(%arg0)
|
||||
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<16x194x199x4xi8>
|
||||
// CHECK: return %[[CAST_RESULT_0]], %[[CAST_RESULT_1]], %[[ADDI]]
|
||||
return %27, %28, %2 : tensor<*xui8>, tensor<*xi8>, tensor<*xi8>
|
||||
}
|
||||
|
@ -67,100 +67,13 @@ using tensorflow::shape_inference::ShapeHandle;
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
Optional<TypeRange> InferShapeForFunctionReturnType(FuncOp func) {
|
||||
// Find any return ops.
|
||||
SmallVector<ReturnOp, 4> return_ops;
|
||||
for (Block& block : func) {
|
||||
if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
|
||||
return_ops.push_back(return_op);
|
||||
}
|
||||
}
|
||||
|
||||
// Right now we only handle the case of a single return op.
|
||||
// To handle multiple return ops, we would need to look at all their shapes
|
||||
// and come up with a common shape and insert appropriate casts.
|
||||
if (return_ops.size() != 1) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Find the return type.
|
||||
auto return_op = return_ops.front();
|
||||
|
||||
// Manually fold tf.Cast that precedes the return instruction and only differs
|
||||
// in shape refinement level.
|
||||
for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
|
||||
Operation* arg_defining_op = arg_op.get().getDefiningOp();
|
||||
if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
|
||||
// Shape inference should not change the element type.
|
||||
if (cast_op.SrcT() != cast_op.DstT()) continue;
|
||||
// We only refine the result shape if the result a dynamic shape, the
|
||||
// input has static shape, and the two shapes are compatible.
|
||||
auto has_static_shape = [](const Value value) {
|
||||
auto shaped_type = value.getType().dyn_cast<ShapedType>();
|
||||
return shaped_type && shaped_type.hasStaticShape();
|
||||
};
|
||||
Value input = cast_op.x();
|
||||
Value result = cast_op.y();
|
||||
if (!has_static_shape(input) || has_static_shape(result) ||
|
||||
failed(verifyCompatibleShape(input.getType(), result.getType())))
|
||||
continue;
|
||||
|
||||
arg_op.set(cast_op.x());
|
||||
if (cast_op.y().use_empty()) cast_op.erase();
|
||||
}
|
||||
}
|
||||
|
||||
return TypeRange(return_op.getOperandTypes());
|
||||
}
|
||||
|
||||
// Returns if the shape inference pass supports an op outside the TF dialect.
|
||||
bool IsSupportedNonTFOp(Operation* op) {
|
||||
return isa<ReturnOp, tf_device::ReturnOp, tf_device::ClusterOp,
|
||||
tf_device::LaunchOp, tf_executor::EnterOp, tf_executor::ExitOp,
|
||||
tf_executor::FetchOp, tf_executor::GraphOp, tf_executor::IslandOp,
|
||||
tf_executor::LoopCondOp, tf_executor::MergeOp,
|
||||
tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
|
||||
tf_executor::SwitchOp, tf_executor::YieldOp>(op);
|
||||
}
|
||||
|
||||
// Returns whether a cast back would need to be inserted, e.g., whether the
|
||||
// operation of which use is an operand allows for shape refinement without
|
||||
// a cast.
|
||||
bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
|
||||
return use.getOwner()->getDialect() != tf_dialect &&
|
||||
!IsSupportedNonTFOp(use.getOwner());
|
||||
}
|
||||
|
||||
// Updates the result of an operation to a new inferred type. Also inserts
|
||||
// tf.Cast operation for uses that are incompatible with the new type.
|
||||
void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type,
|
||||
Operation* op, Value result) {
|
||||
// A tf.Cast operation is lazily created on the first use requires a cast.
|
||||
TF::CastOp cast_op;
|
||||
auto get_cast_op = [&]() {
|
||||
if (!cast_op) {
|
||||
OpBuilder b(op);
|
||||
b.setInsertionPointAfter(op);
|
||||
cast_op = b.create<TF::CastOp>(op->getLoc(), result.getType(), result,
|
||||
/*truncate=*/b.getBoolAttr(false));
|
||||
}
|
||||
return Value(cast_op);
|
||||
};
|
||||
// First insert cast back for uses that need a cast and then
|
||||
// update the type.
|
||||
for (OpOperand& use : make_early_inc_range(result.getUses())) {
|
||||
if (NeedsCastBack(use, tf_dialect)) use.set(get_cast_op());
|
||||
}
|
||||
|
||||
result.setType(new_type);
|
||||
}
|
||||
|
||||
// Returns whether type can be further refined.
|
||||
bool CanBeRefined(Type type) {
|
||||
auto shape_type = type.dyn_cast<ShapedType>();
|
||||
return shape_type &&
|
||||
(!shape_type.hasStaticShape() ||
|
||||
shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
|
||||
shape_type.getElementType().isa<TF::TensorFlowTypeWithSubtype>());
|
||||
}
|
||||
|
||||
// Returns whether `original_type` type can be refined with
|
||||
@ -179,6 +92,122 @@ bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
|
||||
!element_type_with_subtype.GetSubtypes().empty();
|
||||
}
|
||||
|
||||
Optional<TypeRange> InferShapeForFunctionReturnType(FuncOp func,
|
||||
Dialect* tf_dialect) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
|
||||
<< "\n");
|
||||
|
||||
// Find any return ops.
|
||||
SmallVector<ReturnOp, 4> return_ops;
|
||||
for (Block& block : func) {
|
||||
if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
|
||||
return_ops.push_back(return_op);
|
||||
}
|
||||
}
|
||||
|
||||
// Right now we only handle the case of a single return op.
|
||||
// To handle multiple return ops, we would need to look at all their shapes
|
||||
// and come up with a common shape and insert appropriate casts.
|
||||
if (return_ops.size() != 1) return None;
|
||||
|
||||
// Find the return type.
|
||||
auto return_op = return_ops.front();
|
||||
|
||||
// Avoid refining result type if not used by TF dialect op. This can be
|
||||
// relaxed once we move to a work queue, but at the moment this can result
|
||||
// in invalid modules (in particular when a std.call is used but we've
|
||||
// already processed the function where the call is made from before this).
|
||||
auto uses = mlir::SymbolTable::getSymbolUses(
|
||||
func.getOperation(), func.getParentOfType<ModuleOp>());
|
||||
if (uses) {
|
||||
for (auto use : *uses) {
|
||||
if (use.getUser()->getDialect() != tf_dialect) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Skipping refing return type of function "
|
||||
"given non-TF dialect use\n");
|
||||
return TypeRange(return_op.getOperandTypes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manually fold tf.Cast that precedes the return instruction and only differs
|
||||
// in shape refinement level.
|
||||
for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
|
||||
Operation* arg_defining_op = arg_op.get().getDefiningOp();
|
||||
if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
|
||||
Value input = cast_op.x();
|
||||
Value result = cast_op.y();
|
||||
if (!CanRefineTypeWith(result.getType(), input.getType())) continue;
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::errs() << "\tfolding & updating return type ";
|
||||
cast_op.getResult().getType().print(llvm::errs());
|
||||
cast_op.getOperand().getType().print(llvm::errs() << " to ");
|
||||
llvm::errs() << "\n";
|
||||
});
|
||||
|
||||
// Shape inference should not change the element type.
|
||||
if (HasCompatibleElementTypes(input.getType(), result.getType())) {
|
||||
arg_op.set(cast_op.x());
|
||||
} else {
|
||||
OpBuilder b(return_op.getOperation());
|
||||
auto type = RankedTensorType::get(
|
||||
input.getType().cast<TensorType>().getShape(),
|
||||
result.getType().cast<TensorType>().getElementType());
|
||||
auto new_cast_op =
|
||||
b.create<TF::CastOp>(return_op.getLoc(), type, input,
|
||||
/*truncate=*/b.getBoolAttr(false));
|
||||
arg_op.set(new_cast_op);
|
||||
}
|
||||
if (cast_op.y().use_empty()) cast_op.erase();
|
||||
}
|
||||
}
|
||||
|
||||
return TypeRange(return_op.getOperandTypes());
|
||||
}
|
||||
|
||||
// Returns if the shape inference pass supports an op outside the TF dialect.
|
||||
bool IsSupportedNonTFOp(Operation* op) {
|
||||
return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
|
||||
tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
|
||||
tf_executor::GraphOp, tf_executor::IslandOp,
|
||||
tf_executor::LoopCondOp, tf_executor::MergeOp,
|
||||
tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
|
||||
tf_executor::SwitchOp, tf_executor::YieldOp>(op);
|
||||
}
|
||||
|
||||
// Returns whether a cast back would need to be inserted, e.g., whether the
|
||||
// operation of which use is an operand allows for shape refinement without
|
||||
// a cast.
|
||||
bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
|
||||
return use.getOwner()->getDialect() != tf_dialect &&
|
||||
!IsSupportedNonTFOp(use.getOwner());
|
||||
}
|
||||
|
||||
// Updates the result of an operation to a new inferred type. Also inserts
|
||||
// tf.Cast operation for uses that are incompatible with the new type.
|
||||
void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type,
|
||||
Value result) {
|
||||
// A tf.Cast operation is lazily created on the first use requires a cast.
|
||||
TF::CastOp cast_op;
|
||||
auto get_cast_op = [&]() {
|
||||
if (!cast_op) {
|
||||
Operation* op = result.getDefiningOp();
|
||||
OpBuilder b(op);
|
||||
b.setInsertionPointAfter(op);
|
||||
cast_op = b.create<TF::CastOp>(op->getLoc(), result.getType(), result,
|
||||
/*truncate=*/b.getBoolAttr(false));
|
||||
}
|
||||
return Value(cast_op);
|
||||
};
|
||||
// First insert cast back for uses that need a cast and then
|
||||
// update the type.
|
||||
for (OpOperand& use : make_early_inc_range(result.getUses())) {
|
||||
if (NeedsCastBack(use, tf_dialect)) use.set(get_cast_op());
|
||||
}
|
||||
|
||||
result.setType(new_type);
|
||||
}
|
||||
|
||||
// Refines the type of `result` of `op` using the type `potential_refined_type`.
|
||||
// Return true if the type was changed.
|
||||
bool RefineResultType(Operation* op, Value result,
|
||||
@ -187,7 +216,7 @@ bool RefineResultType(Operation* op, Value result,
|
||||
return false;
|
||||
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(op->getDialect(),
|
||||
potential_refined_type, op, result);
|
||||
potential_refined_type, result);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -197,6 +226,7 @@ bool InferShapeForCall(CallOpInterface call_op) {
|
||||
FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
|
||||
if (!func) return false;
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName());
|
||||
Operation* op = call_op.getOperation();
|
||||
bool changed = false;
|
||||
// Map each of the results of the call to the returned type of the
|
||||
@ -205,6 +235,7 @@ bool InferShapeForCall(CallOpInterface call_op) {
|
||||
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
|
||||
changed;
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n");
|
||||
|
||||
return changed;
|
||||
}
|
||||
@ -232,8 +263,7 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
|
||||
ranked_op_type.getShape(),
|
||||
result.getType().cast<ShapedType>().getElementType());
|
||||
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect, new_type, op,
|
||||
op.getResult());
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect, new_type, op.getResult());
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -289,7 +319,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
|
||||
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
|
||||
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(
|
||||
op->getDialect(), std::get<1>(result), op, std::get<0>(result));
|
||||
op->getDialect(), std::get<1>(result), std::get<0>(result));
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
@ -528,13 +558,14 @@ class ShapeInference {
|
||||
// whether any result type changed.
|
||||
bool InferShapeForNonTFDialectOperation(Operation* op);
|
||||
|
||||
Dialect* const tf_dialect_;
|
||||
|
||||
private:
|
||||
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
||||
// e.g., first element of OpResult produced) to an Attribute if the ValuePort
|
||||
// corresponds to a constant value.
|
||||
ValuePortResultMap results_;
|
||||
int64_t graph_version_;
|
||||
Dialect* tf_dialect_;
|
||||
|
||||
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
|
||||
// SCCP pass instead.
|
||||
@ -543,10 +574,9 @@ class ShapeInference {
|
||||
|
||||
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
|
||||
bool propagate_caller_callee_constants)
|
||||
: graph_version_(graph_version),
|
||||
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
|
||||
tf_dialect_ = context->getLoadedDialect<TensorFlowDialect>();
|
||||
}
|
||||
: tf_dialect_(context->getLoadedDialect<TensorFlowDialect>()),
|
||||
graph_version_(graph_version),
|
||||
propagate_caller_callee_constants_(propagate_caller_callee_constants) {}
|
||||
|
||||
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
||||
InferenceContext* ic) {
|
||||
@ -633,8 +663,7 @@ bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
|
||||
.isa<TF::TensorFlowRefType>())
|
||||
continue;
|
||||
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, operand_type, op,
|
||||
result);
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, operand_type, result);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
@ -666,7 +695,7 @@ bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
|
||||
|
||||
auto new_type = RankedTensorType::get(operand_type.getShape(),
|
||||
result_type.getElementType());
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, new_type, op, result);
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, new_type, result);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
@ -790,7 +819,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
inferred_type = UnrankedTensorType::get(inferred.getElementType());
|
||||
|
||||
if (op_result.getType() == inferred_type) continue;
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, inferred_type, op,
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, inferred_type,
|
||||
op_result);
|
||||
changed = true;
|
||||
}
|
||||
@ -831,7 +860,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
|
||||
continue;
|
||||
}
|
||||
|
||||
auto new_return_types = InferShapeForFunctionReturnType(func);
|
||||
auto new_return_types = InferShapeForFunctionReturnType(func, tf_dialect_);
|
||||
if (new_return_types)
|
||||
func.setType(FunctionType::get(input_types, new_return_types.getValue(),
|
||||
func.getContext()));
|
||||
@ -1015,7 +1044,7 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
|
||||
if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
|
||||
if (std::get<0>(result).getType() == eattr.getType()) continue;
|
||||
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, eattr.getType(), op,
|
||||
UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, eattr.getType(),
|
||||
std::get<0>(result));
|
||||
}
|
||||
}
|
||||
@ -1085,7 +1114,8 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
||||
return failure();
|
||||
// TODO(b/156276510): Verify that it is always fine to refine a function's
|
||||
// return type, as long as we do not change the argument shapes.
|
||||
if (auto return_types = InferShapeForFunctionReturnType(func)) {
|
||||
if (auto return_types =
|
||||
InferShapeForFunctionReturnType(func, context.tf_dialect_)) {
|
||||
func.setType(FunctionType::get(func.getType().getInputs(),
|
||||
return_types.getValue(),
|
||||
func.getContext()));
|
||||
@ -1133,7 +1163,8 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto return_types = InferShapeForFunctionReturnType(func);
|
||||
auto return_types =
|
||||
InferShapeForFunctionReturnType(func, context.tf_dialect_);
|
||||
func.setType(FunctionType::get(new_arg_types,
|
||||
return_types.hasValue()
|
||||
? return_types.getValue()
|
||||
|
Loading…
Reference in New Issue
Block a user