Modify ResourceOpLifting to be more permissive.
1. Don't fail if some resource inputs are not lifted. 2. Don't fail if an op besides read/assign operate on a resource. Ops like SummaryOps have resource operands and will be handled in subsequent passes. PiperOrigin-RevId: 326548756 Change-Id: I74138211762836b6eadcac9018db5740782a4380
This commit is contained in:
parent
3c946aab07
commit
43288ecdda
@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass fails when there are remaining resource operationss that can
|
||||
// not be lifted.
|
||||
|
||||
func @lifting_failure() -> tensor<*xi32> {
|
||||
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// expected-error @+1 {{has remaining resource inputs that can not be lifted}}
|
||||
%1 = "tf_device.cluster"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32>
|
||||
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
|
||||
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass lifts resource reads/writes from a loop, and removed unused
|
||||
// resources.
|
||||
|
||||
@ -347,30 +327,6 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass reports error on unsupported ops in loop body.
|
||||
|
||||
func @cluster_with_loop() -> () {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.cluster"() ( {
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
// expected-error @+1 {{found unsupported operations on resource.}}
|
||||
"tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
return %read : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass reports error on unsupported ops in loop cond.
|
||||
|
||||
func @cluster_with_loop() -> () {
|
||||
|
@ -330,15 +330,6 @@ LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster,
|
||||
getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(),
|
||||
captured_values);
|
||||
|
||||
for (Value v : captured_values) {
|
||||
auto tensor_type = v.getType().dyn_cast<TensorType>();
|
||||
if (!tensor_type) continue;
|
||||
if (!tensor_type.getElementType().isa<TF::ResourceType>()) continue;
|
||||
|
||||
return new_cluster.emitOpError()
|
||||
<< "has remaining resource inputs that can not be lifted";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -361,29 +352,23 @@ LogicalResult FindResourceArgUseInfo(
|
||||
ResourceArgUseInfo info;
|
||||
info.used = false;
|
||||
info.updated = false;
|
||||
bool do_not_touch = false;
|
||||
bool read_or_assigned = false;
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (user == return_op) continue;
|
||||
info.used = true;
|
||||
if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
|
||||
info.used = true;
|
||||
read_or_assigned = true;
|
||||
info.data_type = read.getType();
|
||||
continue;
|
||||
}
|
||||
if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
|
||||
info.used = true;
|
||||
read_or_assigned = true;
|
||||
info.updated = true;
|
||||
info.data_type = assign.value().getType();
|
||||
continue;
|
||||
}
|
||||
if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
|
||||
// Stacks will be handled by a separate pass.
|
||||
do_not_touch = true;
|
||||
break;
|
||||
}
|
||||
user->emitOpError("found unsupported operations on resource.");
|
||||
return failure();
|
||||
}
|
||||
if (!do_not_touch) (*result)[arg.getArgNumber()] = info;
|
||||
if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user