[MLIR] Fix resource op lifting to be more permissive with unsupported operations

- If a resource use is seen in an unsupported operation, skip lifting that
  resource instead of failing the transformation

PiperOrigin-RevId: 331146896
Change-Id: Ie018a9aaed29a66b4e875b43ee9c9dd646694bc4
This commit is contained in:
Rahul Joshi 2020-09-11 07:50:29 -07:00 committed by TensorFlower Gardener
parent ef9971b6d7
commit 457066c8c6
2 changed files with 36 additions and 4 deletions

View File

@ -961,5 +961,31 @@ func @if_region_with_store_in_both(%arg0: tensor<i1>) {
}
// -----
// Make sure unsupported resources are handled correctly. If a resource is used
// in an unsupported op, resource op lifting should skip lifting that resource.
// So for the below test, the IR should stay unchanged.
// CHECK-LABEL: func @test_unsupported_resource_op
func @test_unsupported_resource_op() -> tensor<*xi32> {
// CHECK: "tf.VarHandleOp"
// CHECK: "tf_device.cluster"() ( {
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.SomeResourceOperation"
// CHECK: "tf.SomeComputation"
// CHECK: tf_device.return
// CHECK: {cluster_attr = "cluster_attr"}
// CHECK: return
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
"tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> ()
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
return %1 : tensor<*xi32>
}

View File

@ -309,8 +309,9 @@ LogicalResult RegionResourceHoister::Analyze() {
num_new_results_ = op_->getNumResults();
for (auto resource : all_resources) {
ResourceInfo& info = resources_[resource];
ResourceInfo info;
llvm::BitVector written_regions(op_->getNumRegions());
bool unsupported_use = false;
for (OpOperand& use : resource.getUses()) {
Operation* user = use.getOwner();
// If the user is not in one of the regions, we are not interested in it.
@ -334,9 +335,8 @@ LogicalResult RegionResourceHoister::Analyze() {
auto read = dyn_cast<TF::ReadVariableOp>(user);
auto write = dyn_cast<TF::AssignVariableOp>(user);
if (!read && !write) {
return op_->emitError(
"Unsupported use of resource variable in operation ")
<< user->getName().getStringRef();
unsupported_use = true;
break;
}
if (read && !info.is_read) {
@ -353,6 +353,10 @@ LogicalResult RegionResourceHoister::Analyze() {
}
}
// If the resource is used in an op that we do not understand, skip
// lifting for that resource.
if (unsupported_use) continue;
info.is_written_all = written_regions.count() == op_->getNumRegions();
// If the resource is written in some but not all regions, we would need
@ -371,6 +375,8 @@ LogicalResult RegionResourceHoister::Analyze() {
written_resources_.insert(resource);
if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
}
resources_.insert({resource, info});
}
return success();
}