From 74fb47ccd26da99e57a14fccf7561e7ba7bcb000 Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Mon, 8 Jun 2020 08:04:26 -0700 Subject: [PATCH] Add a new pass for promoting VarHandle ops to TF saved model arguments PiperOrigin-RevId: 315275908 Change-Id: Icbc5c032bd9474d279fecf48267665025a53c1bf --- .../tests/promote_var_handles_to_args.mlir | 20 +++++------ .../mlir/tensorflow/transforms/passes.h | 8 ++--- .../transforms/promote_resources_to_args.cc | 33 ++++++++----------- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir index 925062ea4ff..8b8a070cfab 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir @@ -1,5 +1,4 @@ -// Run a pass for promoting tf.VarHandleOps to function arguments in a format of TensorFlowSavedModelDialect. -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure // Tests main function with multiple blocks. @@ -12,27 +11,24 @@ func @main() { // ----- -"tf_saved_model.global_tensor"() {sym_name = "x", type = tensor, value = dense<1.67482901> : tensor} : () -> () -"tf_saved_model.global_tensor"() {sym_name = "y", type = tensor, value = dense<0> : tensor} : () -> () - // CHECK-LABEL: func @no_args -// CHECK-SAME: (%arg0: tensor>> {tf_saved_model.bound_input = @x}) +// CHECK-SAME: (%arg0: tensor {tf.resource_name = "x"}) // CHECK-NOT: "tf.VarHandleOp" func @no_args() { - %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor return } // CHECK-LABEL: func @some_args -// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf_saved_model.bound_input = @x}) +// CHECK-SAME: (%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) // CHECK-NOT: "tf.VarHandleOp" func @some_args(%arg0: tensor) { - %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor return } // CHECK-LABEL: func @unique_vars -// CHECK-SAME: (%arg0: tensor>> {tf_saved_model.bound_input = @x}, %arg1: tensor>> {tf_saved_model.bound_input = @y}) +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}, %arg1: tensor>> {tf.resource_name = "y"}) // CHECK-NOT: "tf.VarHandleOp" func @unique_vars() { %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> @@ -41,7 +37,7 @@ func @unique_vars() { } // CHECK-LABEL: func @duplicate_vars -// CHECK-SAME: (%arg0: tensor>> {tf_saved_model.bound_input = @x}) +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}) // CHECK-NOT: "tf.VarHandleOp" func @duplicate_vars() { %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> @@ -50,7 +46,7 @@ func @duplicate_vars() { } // CHECK-LABEL: func @duplicate_vars_with_users -// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf_saved_model.bound_input = @x}) +// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf.resource_name = "x"}) // CHECK: "tf.ReadVariableOp"(%arg1) // CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0) // CHECK-NOT: "tf.VarHandleOp" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 3973eb60707..08c95bd8b0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -95,11 +95,9 @@ std::unique_ptr> CreateResourceDeviceInferencePass(); // of their aliasing output arguments. std::unique_ptr> CreatePromoteResourcesToArgsPass(); -// Creates a pass that promotes tf.VarHandleOp to to resource arguments of where -// resource names are `tf_saved_model.bound_input` symbol argument attributes -// for all functions. -std::unique_ptr> -CreatePromoteVarHandlesToSavedModelArgsPass(); +// Creates a pass that promotes tf.VarHandleOp to resource arguments for all +// functions. +std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); // Creates a pass that converts readonly reference variables to the // corresponding resource variables. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index 0d331686c46..cece23b4750 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -389,18 +389,15 @@ void PromoteResourcesToArgsPass::runOnOperation() { return signalPassFailure(); } -// This pass is for promoting Varhandle ops to tf_saved_model.bound_input -// attributes, which are required for TensorFlowSavedModelDialect. -class PromoteVarHandlesToSavedModelArgsPass - : public PassWrapper> { +class PromoteVarHandlesToArgsPass + : public PassWrapper> { public: void runOnOperation() override; }; -void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() { +void PromoteVarHandlesToArgsPass::runOnOperation() { ModuleOp module = getOperation(); - + MLIRContext* context = module.getContext(); for (auto function : module.getOps()) { if (failed(CheckSingleBlockFunction(function))) return signalPassFailure(); @@ -409,15 +406,13 @@ void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() { &var_handle_shared_names); // Add resource names for each `tf.VarHandleOp` that were promoted to - // saved model arguments. + // resource arguments. const int var_handle_args_offset = function.getNumArguments() - var_handle_shared_names.size(); - for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) { - auto symbol_ref = - SymbolRefAttr::get(var_name_and_index.value(), &getContext()); + for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) function.setArgAttr(var_name_and_index.index() + var_handle_args_offset, - "tf_saved_model.bound_input", symbol_ref); - } + kResourceNameArgAttr, + StringAttr::get(var_name_and_index.value(), context)); } } @@ -427,19 +422,17 @@ std::unique_ptr> CreatePromoteResourcesToArgsPass() { return std::make_unique(); } -std::unique_ptr> -CreatePromoteVarHandlesToSavedModelArgsPass() { - return std::make_unique(); +std::unique_ptr> CreatePromoteVarHandlesToArgsPass() { + return std::make_unique(); } static PassRegistration pass( "tf-promote-resources-to-args", "Promote resources reads/writes to function inputs/outputs."); -static PassRegistration saved_model_pass( - "tf-saved-model-promote-var-handles-to-args", - "Promote tf.VarHandleOps to function arguments in a format of " - "TensorFlowSavedModelDialect."); +static PassRegistration var_handle_pass( + "tf-promote-var-handles-to-args", + "Promote tf.VarHandleOps to function arguments."); } // namespace TF } // namespace mlir