diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir index 8b8a070cfab..925062ea4ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir @@ -1,4 +1,5 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure +// Run a pass for promoting tf.VarHandleOps to function arguments in a format of TensorFlowSavedModelDialect. +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure // Tests main function with multiple blocks. @@ -11,24 +12,27 @@ func @main() { // ----- +"tf_saved_model.global_tensor"() {sym_name = "x", type = tensor, value = dense<1.67482901> : tensor} : () -> () +"tf_saved_model.global_tensor"() {sym_name = "y", type = tensor, value = dense<0> : tensor} : () -> () + // CHECK-LABEL: func @no_args -// CHECK-SAME: (%arg0: tensor {tf.resource_name = "x"}) +// CHECK-SAME: (%arg0: tensor>> {tf_saved_model.bound_input = @x}) // CHECK-NOT: "tf.VarHandleOp" func @no_args() { - %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> return } // CHECK-LABEL: func @some_args -// CHECK-SAME: (%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) +// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf_saved_model.bound_input = @x}) // CHECK-NOT: "tf.VarHandleOp" func @some_args(%arg0: tensor) { - %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> return } // CHECK-LABEL: func @unique_vars -// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}, %arg1: tensor>> {tf.resource_name = "y"}) +// CHECK-SAME: (%arg0: tensor>> {tf_saved_model.bound_input = @x}, %arg1: tensor>> {tf_saved_model.bound_input = @y}) // CHECK-NOT: "tf.VarHandleOp" func @unique_vars() { %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> @@ -37,7 +41,7 @@ func @unique_vars() { } // CHECK-LABEL: func @duplicate_vars -// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}) +// CHECK-SAME: (%arg0: tensor>> {tf_saved_model.bound_input = @x}) // CHECK-NOT: "tf.VarHandleOp" func @duplicate_vars() { %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> @@ -46,7 +50,7 @@ func @duplicate_vars() { } // CHECK-LABEL: func @duplicate_vars_with_users -// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf.resource_name = "x"}) +// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf_saved_model.bound_input = @x}) // CHECK: "tf.ReadVariableOp"(%arg1) // CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0) // CHECK-NOT: "tf.VarHandleOp" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 5c140ddd6aa..93d7af96c1e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -91,9 +91,11 @@ std::unique_ptr> CreateResourceDeviceInferencePass(); // of their aliasing output arguments. std::unique_ptr> CreatePromoteResourcesToArgsPass(); -// Creates a pass that promotes tf.VarHandleOp to resource arguments for all -// functions. -std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); +// Creates a pass that promotes tf.VarHandleOp to to resource arguments of where +// resource names are `tf_saved_model.bound_input` symbol argument attributes +// for all functions. +std::unique_ptr> +CreatePromoteVarHandlesToSavedModelArgsPass(); // Creates a pass that converts readonly reference variables to the // corresponding resource variables. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index cece23b4750..0d331686c46 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -389,15 +389,18 @@ void PromoteResourcesToArgsPass::runOnOperation() { return signalPassFailure(); } -class PromoteVarHandlesToArgsPass - : public PassWrapper> { +// This pass is for promoting Varhandle ops to tf_saved_model.bound_input +// attributes, which are required for TensorFlowSavedModelDialect. +class PromoteVarHandlesToSavedModelArgsPass + : public PassWrapper> { public: void runOnOperation() override; }; -void PromoteVarHandlesToArgsPass::runOnOperation() { +void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() { ModuleOp module = getOperation(); - MLIRContext* context = module.getContext(); + for (auto function : module.getOps()) { if (failed(CheckSingleBlockFunction(function))) return signalPassFailure(); @@ -406,13 +409,15 @@ void PromoteVarHandlesToArgsPass::runOnOperation() { &var_handle_shared_names); // Add resource names for each `tf.VarHandleOp` that were promoted to - // resource arguments. + // saved model arguments. const int var_handle_args_offset = function.getNumArguments() - var_handle_shared_names.size(); - for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) + for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) { + auto symbol_ref = + SymbolRefAttr::get(var_name_and_index.value(), &getContext()); function.setArgAttr(var_name_and_index.index() + var_handle_args_offset, - kResourceNameArgAttr, - StringAttr::get(var_name_and_index.value(), context)); + "tf_saved_model.bound_input", symbol_ref); + } } } @@ -422,17 +427,19 @@ std::unique_ptr> CreatePromoteResourcesToArgsPass() { return std::make_unique(); } -std::unique_ptr> CreatePromoteVarHandlesToArgsPass() { - return std::make_unique(); +std::unique_ptr> +CreatePromoteVarHandlesToSavedModelArgsPass() { + return std::make_unique(); } static PassRegistration pass( "tf-promote-resources-to-args", "Promote resources reads/writes to function inputs/outputs."); -static PassRegistration var_handle_pass( - "tf-promote-var-handles-to-args", - "Promote tf.VarHandleOps to function arguments."); +static PassRegistration saved_model_pass( + "tf-saved-model-promote-var-handles-to-args", + "Promote tf.VarHandleOps to function arguments in a format of " + "TensorFlowSavedModelDialect."); } // namespace TF } // namespace mlir