Explicitly reject ops with symbol ref attributes in the fallback legalization pass
We don't attempt legalization for ops with symbol reference attribute even if they are in allow-list. Xla op kernels for these ops compile the function to HLO on-demand which won't work in our case as it may contain unsupported ops in the fallback nor we provide XlaCompiler to the kernel. Also, these patterns are used from a function pass. This is just an extra check which might be useful if we switch to deny list instead of allow list or avoid crashing if an op with symbol ref is added to the allow list. PiperOrigin-RevId: 352074881 Change-Id: Ided5dc736b61aa25bd5b3960059ee4949086f54b
This commit is contained in:
parent
102e1f9855
commit
4db28856db
@ -339,6 +339,26 @@ func @xla_svd(%arg0: tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor
|
||||
return %s, %u, %v : tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>
|
||||
}
|
||||
|
||||
func @abs_impl(%arg0: f32) -> f32 {
|
||||
return %arg0 : f32
|
||||
}
|
||||
|
||||
// This test verifies that legalization for ops with symbol reference attribute
|
||||
// is not attempted even if they are in allow-list. XLA op kernels for these
|
||||
// ops compile the function to HLO on-demand which won't work in our case as it
|
||||
// may contain unsupported ops in the fallback nor we provide XlaCompiler to
|
||||
// the kernel. Using a allowed op Abs to protect against future addition of a
|
||||
// new op with a symbol ref.
|
||||
|
||||
// CHECK-LABEL: @abs_with_symbol_ref
|
||||
func @abs_with_symbol_ref(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: tf.Abs
|
||||
// expected-remark@+1 {{ops with symbol references are not supported}}
|
||||
%0 = "tf.Abs"(%arg0) {_body = @abs_impl} : (tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||
// available but doesn't support this instance.
|
||||
}
|
||||
|
@ -410,6 +410,13 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& attr : op_->getAttrs()) {
|
||||
if (attr.second.isa<SymbolRefAttr>()) {
|
||||
return op_->emitRemark()
|
||||
<< "ops with symbol references are not supported";
|
||||
}
|
||||
}
|
||||
|
||||
auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
|
||||
op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
|
||||
if (!nodedef_or.ok()) {
|
||||
|
Loading…
Reference in New Issue
Block a user