Modify ResourceOpLifting to be more permissive.

1. Don't fail if some resource inputs are not lifted.
2. Don't fail if an op besides read/assign operate on a resource.

Ops like SummaryOps have resource operands and will be handled in subsequent passes.

PiperOrigin-RevId: 326548756
Change-Id: I74138211762836b6eadcac9018db5740782a4380
This commit is contained in:
Ken Franko 2020-08-13 16:29:44 -07:00 committed by TensorFlower Gardener
parent 3c946aab07
commit 43288ecdda
2 changed files with 5 additions and 64 deletions

View File

@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> {
// -----
// Tests that pass fails when there are remaining resource operationss that can
// not be lifted.
func @lifting_failure() -> tensor<*xi32> {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// expected-error @+1 {{has remaining resource inputs that can not be lifted}}
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32>
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// Tests that pass lifts resource reads/writes from a loop, and removed unused
// resources.
@ -347,30 +327,6 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
// -----
// Tests that pass reports error on unsupported ops in loop body.
func @cluster_with_loop() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.cluster"() ( {
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
// expected-error @+1 {{found unsupported operations on resource.}}
"tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
return %read : tensor<f32>
}
// -----
// Tests that pass reports error on unsupported ops in loop cond.
func @cluster_with_loop() -> () {

View File

@ -330,15 +330,6 @@ LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster,
getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(),
captured_values);
for (Value v : captured_values) {
auto tensor_type = v.getType().dyn_cast<TensorType>();
if (!tensor_type) continue;
if (!tensor_type.getElementType().isa<TF::ResourceType>()) continue;
return new_cluster.emitOpError()
<< "has remaining resource inputs that can not be lifted";
}
return success();
}
@ -361,29 +352,23 @@ LogicalResult FindResourceArgUseInfo(
ResourceArgUseInfo info;
info.used = false;
info.updated = false;
bool do_not_touch = false;
bool read_or_assigned = false;
for (auto user : arg.getUsers()) {
if (user == return_op) continue;
info.used = true;
if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
info.used = true;
read_or_assigned = true;
info.data_type = read.getType();
continue;
}
if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
info.used = true;
read_or_assigned = true;
info.updated = true;
info.data_type = assign.value().getType();
continue;
}
if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
// Stacks will be handled by a separate pass.
do_not_touch = true;
break;
}
user->emitOpError("found unsupported operations on resource.");
return failure();
}
if (!do_not_touch) (*result)[arg.getArgNumber()] = info;
if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info;
}
return success();
}