Explicitly reject ops with symbol ref attributes in the fallback legalization pass

We don't attempt legalization for ops with symbol reference attribute even if they are in allow-list. Xla op kernels for these ops compile the function to HLO on-demand which won't work in our case as it may contain unsupported ops in the fallback nor we provide XlaCompiler to the kernel. Also, these patterns are used from a function pass.

This is just an extra check which might be useful if we switch to deny list instead of allow list or avoid crashing if an op with symbol ref is added to the allow list.

PiperOrigin-RevId: 352074881
Change-Id: Ided5dc736b61aa25bd5b3960059ee4949086f54b
This commit is contained in:
Smit Hinsu 2021-01-15 13:39:01 -08:00 committed by TensorFlower Gardener
parent 102e1f9855
commit 4db28856db
2 changed files with 27 additions and 0 deletions

View File

@ -339,6 +339,26 @@ func @xla_svd(%arg0: tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor
return %s, %u, %v : tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>
}
func @abs_impl(%arg0: f32) -> f32 {
return %arg0 : f32
}
// This test verifies that legalization for ops with symbol reference attribute
// is not attempted even if they are in allow-list. XLA op kernels for these
// ops compile the function to HLO on-demand which won't work in our case as it
// may contain unsupported ops in the fallback nor we provide XlaCompiler to
// the kernel. Using a allowed op Abs to protect against future addition of a
// new op with a symbol ref.
// CHECK-LABEL: @abs_with_symbol_ref
func @abs_with_symbol_ref(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: tf.Abs
// expected-remark@+1 {{ops with symbol references are not supported}}
%0 = "tf.Abs"(%arg0) {_body = @abs_impl} : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -410,6 +410,13 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
}
}
for (const auto& attr : op_->getAttrs()) {
if (attr.second.isa<SymbolRefAttr>()) {
return op_->emitRemark()
<< "ops with symbol references are not supported";
}
}
auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
if (!nodedef_or.ok()) {