[MLIR] Fix resource op lifting to be more permissive with unsupported operations
- If a resource use is seen in an unsupported operation, skip lifting that resource instead of failing the transformation PiperOrigin-RevId: 331146896 Change-Id: Ie018a9aaed29a66b4e875b43ee9c9dd646694bc4
This commit is contained in:
parent
ef9971b6d7
commit
457066c8c6
@ -961,5 +961,31 @@ func @if_region_with_store_in_both(%arg0: tensor<i1>) {
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Make sure unsupported resources are handled correctly. If a resource is used
|
||||
// in an unsupported op, resource op lifting should skip lifting that resource.
|
||||
// So for the below test, the IR should stay unchanged.
|
||||
// CHECK-LABEL: func @test_unsupported_resource_op
|
||||
func @test_unsupported_resource_op() -> tensor<*xi32> {
|
||||
// CHECK: "tf.VarHandleOp"
|
||||
// CHECK: "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
// CHECK: "tf.SomeResourceOperation"
|
||||
// CHECK: "tf.SomeComputation"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: {cluster_attr = "cluster_attr"}
|
||||
// CHECK: return
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf_device.cluster"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
"tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> ()
|
||||
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
|
||||
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -309,8 +309,9 @@ LogicalResult RegionResourceHoister::Analyze() {
|
||||
num_new_results_ = op_->getNumResults();
|
||||
|
||||
for (auto resource : all_resources) {
|
||||
ResourceInfo& info = resources_[resource];
|
||||
ResourceInfo info;
|
||||
llvm::BitVector written_regions(op_->getNumRegions());
|
||||
bool unsupported_use = false;
|
||||
for (OpOperand& use : resource.getUses()) {
|
||||
Operation* user = use.getOwner();
|
||||
// If the user is not in one of the regions, we are not interested in it.
|
||||
@ -334,9 +335,8 @@ LogicalResult RegionResourceHoister::Analyze() {
|
||||
auto read = dyn_cast<TF::ReadVariableOp>(user);
|
||||
auto write = dyn_cast<TF::AssignVariableOp>(user);
|
||||
if (!read && !write) {
|
||||
return op_->emitError(
|
||||
"Unsupported use of resource variable in operation ")
|
||||
<< user->getName().getStringRef();
|
||||
unsupported_use = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (read && !info.is_read) {
|
||||
@ -353,6 +353,10 @@ LogicalResult RegionResourceHoister::Analyze() {
|
||||
}
|
||||
}
|
||||
|
||||
// If the resource is used in an op that we do not understand, skip
|
||||
// lifting for that resource.
|
||||
if (unsupported_use) continue;
|
||||
|
||||
info.is_written_all = written_regions.count() == op_->getNumRegions();
|
||||
|
||||
// If the resource is written in some but not all regions, we would need
|
||||
@ -371,6 +375,8 @@ LogicalResult RegionResourceHoister::Analyze() {
|
||||
written_resources_.insert(resource);
|
||||
if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
|
||||
}
|
||||
|
||||
resources_.insert({resource, info});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user