[TF:MLIR] Enhance promote-resources-to-arguments pass to handle resource
accesses inside control flow. Use resource lifting to functionalize control flow statements. Change the pass ordering in ConvertMLIRToXlaComputation to perform control flow legalization after promoting resources to arguments. PiperOrigin-RevId: 294772980 Change-Id: I8e4b89d7c4c090fd473e579baf0424188ec26e59
This commit is contained in:
parent
dcc5a469be
commit
057e8630bc
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure
|
||||
// RUN: tf-opt %s -split-input-file -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure
|
||||
|
||||
// One resource, one read.
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<f32>) -> tensor<2xf32>
|
||||
@ -88,12 +88,10 @@ func @main() -> tensor<2xf32> {
|
||||
// -----
|
||||
|
||||
// A resource is passed into tf.If
|
||||
// expected-error @+1 {{potential nested resource accesses in function}}
|
||||
func @cond_false(%arg0: tensor<!tf.resource<tensor<f32>>>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
return %arg1 : tensor<f32>
|
||||
}
|
||||
|
||||
// expected-error @+1 {{potential nested resource accesses in function}}
|
||||
func @cond_true(%arg0: tensor<!tf.resource<tensor<f32>>>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
@ -101,6 +99,7 @@ func @cond_true(%arg0: tensor<!tf.resource<tensor<f32>>>, %arg1: tensor<f32>) ->
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<f32>) -> tensor<2xf32>
|
||||
func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} {
|
||||
%0 = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||
|
@ -24,13 +24,6 @@ limitations under the License.
|
||||
// . Compound resource operations have already been decomposed.
|
||||
// . Dead functions have already been removed, as resource arguments in dead
|
||||
// functions can cause the pass to fail.
|
||||
//
|
||||
// TODO(bixia): This pass currently reports any error when it sees ResourceType
|
||||
// as function arguments. That is, this pass assumes resource reads/writes in
|
||||
// functions called by the main function, such as through TF IfOp and WhileOp,
|
||||
// have already been functionalized. This functionalization can be achieved by
|
||||
// either finishing cl/281636304 or enhancing PromoteResourcesToArguments
|
||||
// here.
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
@ -42,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
@ -200,7 +194,8 @@ void PromoteResourcesToArgsPass::runOnModule() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(VerifyNoPotentialNestedResourceAccesses(module)) ||
|
||||
if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) ||
|
||||
failed(VerifyNoPotentialNestedResourceAccesses(module)) ||
|
||||
failed(PromoteResourcesToArguments(main_func))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
@ -211,9 +211,12 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||
bool use_tuple_args, bool return_tuple) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
|
||||
tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
|
||||
tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
|
||||
// LegalizeTFControlFlow encapsulates arguments for control flow operations
|
||||
// with a tuple argument which break the assumption of resource lifting
|
||||
// inside PromoteResourcesToArgs.
|
||||
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
|
||||
// We need to run LegalizeTFPass 2 times because first
|
||||
// LegalizeTFPass(allow_partial_conversion=true) can expose more graph pruning
|
||||
// and canonicalization opportunities that are necessary for the second
|
||||
|
Loading…
Reference in New Issue
Block a user