Add a new pass for promoting VarHandle ops to TF saved model arguments

PiperOrigin-RevId: 314415654
Change-Id: I0dccc4dd9c3c20625950816d341ce910384d3906
This commit is contained in:
Jaesung Chung 2020-06-02 15:25:48 -07:00 committed by TensorFlower Gardener
parent e410915945
commit 96dfb6f4a8
3 changed files with 37 additions and 24 deletions

View File

@ -1,4 +1,5 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure // Run a pass for promoting tf.VarHandleOps to function arguments in a format of TensorFlowSavedModelDialect.
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
// Tests main function with multiple blocks. // Tests main function with multiple blocks.
@ -11,24 +12,27 @@ func @main() {
// ----- // -----
"tf_saved_model.global_tensor"() {sym_name = "x", type = tensor<f32>, value = dense<1.67482901> : tensor<f32>} : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "y", type = tensor<i32>, value = dense<0> : tensor<i32>} : () -> ()
// CHECK-LABEL: func @no_args // CHECK-LABEL: func @no_args
// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"}) // CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.VarHandleOp"
func @no_args() { func @no_args() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource> %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
return return
} }
// CHECK-LABEL: func @some_args // CHECK-LABEL: func @some_args
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"}) // CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.VarHandleOp"
func @some_args(%arg0: tensor<i1>) { func @some_args(%arg0: tensor<i1>) {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource> %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
return return
} }
// CHECK-LABEL: func @unique_vars // CHECK-LABEL: func @unique_vars
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf.resource_name = "y"}) // CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf_saved_model.bound_input = @y})
// CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.VarHandleOp"
func @unique_vars() { func @unique_vars() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>> %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
@ -37,7 +41,7 @@ func @unique_vars() {
} }
// CHECK-LABEL: func @duplicate_vars // CHECK-LABEL: func @duplicate_vars
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"}) // CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.VarHandleOp"
func @duplicate_vars() { func @duplicate_vars() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>> %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
@ -46,7 +50,7 @@ func @duplicate_vars() {
} }
// CHECK-LABEL: func @duplicate_vars_with_users // CHECK-LABEL: func @duplicate_vars_with_users
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"}) // CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK: "tf.ReadVariableOp"(%arg1) // CHECK: "tf.ReadVariableOp"(%arg1)
// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0) // CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0)
// CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.VarHandleOp"

View File

@ -91,9 +91,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
// of their aliasing output arguments. // of their aliasing output arguments.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass(); std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
// Creates a pass that promotes tf.VarHandleOp to resource arguments for all // Creates a pass that promotes tf.VarHandleOp to to resource arguments of where
// functions. // resource names are `tf_saved_model.bound_input` symbol argument attributes
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass(); // for all functions.
std::unique_ptr<OperationPass<ModuleOp>>
CreatePromoteVarHandlesToSavedModelArgsPass();
// Creates a pass that converts readonly reference variables to the // Creates a pass that converts readonly reference variables to the
// corresponding resource variables. // corresponding resource variables.

View File

@ -389,15 +389,18 @@ void PromoteResourcesToArgsPass::runOnOperation() {
return signalPassFailure(); return signalPassFailure();
} }
class PromoteVarHandlesToArgsPass // This pass is for promoting Varhandle ops to tf_saved_model.bound_input
: public PassWrapper<PromoteVarHandlesToArgsPass, OperationPass<ModuleOp>> { // attributes, which are required for TensorFlowSavedModelDialect.
class PromoteVarHandlesToSavedModelArgsPass
: public PassWrapper<PromoteVarHandlesToSavedModelArgsPass,
OperationPass<ModuleOp>> {
public: public:
void runOnOperation() override; void runOnOperation() override;
}; };
void PromoteVarHandlesToArgsPass::runOnOperation() { void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
ModuleOp module = getOperation(); ModuleOp module = getOperation();
MLIRContext* context = module.getContext();
for (auto function : module.getOps<FuncOp>()) { for (auto function : module.getOps<FuncOp>()) {
if (failed(CheckSingleBlockFunction(function))) return signalPassFailure(); if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
@ -406,13 +409,15 @@ void PromoteVarHandlesToArgsPass::runOnOperation() {
&var_handle_shared_names); &var_handle_shared_names);
// Add resource names for each `tf.VarHandleOp` that were promoted to // Add resource names for each `tf.VarHandleOp` that were promoted to
// resource arguments. // saved model arguments.
const int var_handle_args_offset = const int var_handle_args_offset =
function.getNumArguments() - var_handle_shared_names.size(); function.getNumArguments() - var_handle_shared_names.size();
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) {
auto symbol_ref =
SymbolRefAttr::get(var_name_and_index.value(), &getContext());
function.setArgAttr(var_name_and_index.index() + var_handle_args_offset, function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
kResourceNameArgAttr, "tf_saved_model.bound_input", symbol_ref);
StringAttr::get(var_name_and_index.value(), context)); }
} }
} }
@ -422,17 +427,19 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
return std::make_unique<PromoteResourcesToArgsPass>(); return std::make_unique<PromoteResourcesToArgsPass>();
} }
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() { std::unique_ptr<OperationPass<ModuleOp>>
return std::make_unique<PromoteVarHandlesToArgsPass>(); CreatePromoteVarHandlesToSavedModelArgsPass() {
return std::make_unique<PromoteVarHandlesToSavedModelArgsPass>();
} }
static PassRegistration<PromoteResourcesToArgsPass> pass( static PassRegistration<PromoteResourcesToArgsPass> pass(
"tf-promote-resources-to-args", "tf-promote-resources-to-args",
"Promote resources reads/writes to function inputs/outputs."); "Promote resources reads/writes to function inputs/outputs.");
static PassRegistration<PromoteVarHandlesToArgsPass> var_handle_pass( static PassRegistration<PromoteVarHandlesToSavedModelArgsPass> saved_model_pass(
"tf-promote-var-handles-to-args", "tf-saved-model-promote-var-handles-to-args",
"Promote tf.VarHandleOps to function arguments."); "Promote tf.VarHandleOps to function arguments in a format of "
"TensorFlowSavedModelDialect.");
} // namespace TF } // namespace TF
} // namespace mlir } // namespace mlir