Add a new pass for promoting VarHandle ops to TF saved model arguments
PiperOrigin-RevId: 315275908 Change-Id: Icbc5c032bd9474d279fecf48267665025a53c1bf
This commit is contained in:
parent
9429a94225
commit
74fb47ccd2
@ -1,5 +1,4 @@
|
|||||||
// Run a pass for promoting tf.VarHandleOps to function arguments in a format of TensorFlowSavedModelDialect.
|
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
|
||||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
|
|
||||||
|
|
||||||
// Tests main function with multiple blocks.
|
// Tests main function with multiple blocks.
|
||||||
|
|
||||||
@ -12,27 +11,24 @@ func @main() {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
"tf_saved_model.global_tensor"() {sym_name = "x", type = tensor<f32>, value = dense<1.67482901> : tensor<f32>} : () -> ()
|
|
||||||
"tf_saved_model.global_tensor"() {sym_name = "y", type = tensor<i32>, value = dense<0> : tensor<i32>} : () -> ()
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @no_args
|
// CHECK-LABEL: func @no_args
|
||||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"})
|
||||||
// CHECK-NOT: "tf.VarHandleOp"
|
// CHECK-NOT: "tf.VarHandleOp"
|
||||||
func @no_args() {
|
func @no_args() {
|
||||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @some_args
|
// CHECK-LABEL: func @some_args
|
||||||
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"})
|
||||||
// CHECK-NOT: "tf.VarHandleOp"
|
// CHECK-NOT: "tf.VarHandleOp"
|
||||||
func @some_args(%arg0: tensor<i1>) {
|
func @some_args(%arg0: tensor<i1>) {
|
||||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @unique_vars
|
// CHECK-LABEL: func @unique_vars
|
||||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf_saved_model.bound_input = @y})
|
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf.resource_name = "y"})
|
||||||
// CHECK-NOT: "tf.VarHandleOp"
|
// CHECK-NOT: "tf.VarHandleOp"
|
||||||
func @unique_vars() {
|
func @unique_vars() {
|
||||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||||
@ -41,7 +37,7 @@ func @unique_vars() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @duplicate_vars
|
// CHECK-LABEL: func @duplicate_vars
|
||||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
|
||||||
// CHECK-NOT: "tf.VarHandleOp"
|
// CHECK-NOT: "tf.VarHandleOp"
|
||||||
func @duplicate_vars() {
|
func @duplicate_vars() {
|
||||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||||
@ -50,7 +46,7 @@ func @duplicate_vars() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @duplicate_vars_with_users
|
// CHECK-LABEL: func @duplicate_vars_with_users
|
||||||
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
|
||||||
// CHECK: "tf.ReadVariableOp"(%arg1)
|
// CHECK: "tf.ReadVariableOp"(%arg1)
|
||||||
// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0)
|
// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0)
|
||||||
// CHECK-NOT: "tf.VarHandleOp"
|
// CHECK-NOT: "tf.VarHandleOp"
|
||||||
|
@ -95,11 +95,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
|
|||||||
// of their aliasing output arguments.
|
// of their aliasing output arguments.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
|
||||||
|
|
||||||
// Creates a pass that promotes tf.VarHandleOp to to resource arguments of where
|
// Creates a pass that promotes tf.VarHandleOp to resource arguments for all
|
||||||
// resource names are `tf_saved_model.bound_input` symbol argument attributes
|
// functions.
|
||||||
// for all functions.
|
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
CreatePromoteVarHandlesToSavedModelArgsPass();
|
|
||||||
|
|
||||||
// Creates a pass that converts readonly reference variables to the
|
// Creates a pass that converts readonly reference variables to the
|
||||||
// corresponding resource variables.
|
// corresponding resource variables.
|
||||||
|
@ -389,18 +389,15 @@ void PromoteResourcesToArgsPass::runOnOperation() {
|
|||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// This pass is for promoting Varhandle ops to tf_saved_model.bound_input
|
class PromoteVarHandlesToArgsPass
|
||||||
// attributes, which are required for TensorFlowSavedModelDialect.
|
: public PassWrapper<PromoteVarHandlesToArgsPass, OperationPass<ModuleOp>> {
|
||||||
class PromoteVarHandlesToSavedModelArgsPass
|
|
||||||
: public PassWrapper<PromoteVarHandlesToSavedModelArgsPass,
|
|
||||||
OperationPass<ModuleOp>> {
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
|
void PromoteVarHandlesToArgsPass::runOnOperation() {
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
|
MLIRContext* context = module.getContext();
|
||||||
for (auto function : module.getOps<FuncOp>()) {
|
for (auto function : module.getOps<FuncOp>()) {
|
||||||
if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
|
if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
|
||||||
|
|
||||||
@ -409,15 +406,13 @@ void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
|
|||||||
&var_handle_shared_names);
|
&var_handle_shared_names);
|
||||||
|
|
||||||
// Add resource names for each `tf.VarHandleOp` that were promoted to
|
// Add resource names for each `tf.VarHandleOp` that were promoted to
|
||||||
// saved model arguments.
|
// resource arguments.
|
||||||
const int var_handle_args_offset =
|
const int var_handle_args_offset =
|
||||||
function.getNumArguments() - var_handle_shared_names.size();
|
function.getNumArguments() - var_handle_shared_names.size();
|
||||||
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) {
|
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
|
||||||
auto symbol_ref =
|
|
||||||
SymbolRefAttr::get(var_name_and_index.value(), &getContext());
|
|
||||||
function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
|
function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
|
||||||
"tf_saved_model.bound_input", symbol_ref);
|
kResourceNameArgAttr,
|
||||||
}
|
StringAttr::get(var_name_and_index.value(), context));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -427,19 +422,17 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
|
|||||||
return std::make_unique<PromoteResourcesToArgsPass>();
|
return std::make_unique<PromoteResourcesToArgsPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
|
||||||
CreatePromoteVarHandlesToSavedModelArgsPass() {
|
return std::make_unique<PromoteVarHandlesToArgsPass>();
|
||||||
return std::make_unique<PromoteVarHandlesToSavedModelArgsPass>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<PromoteResourcesToArgsPass> pass(
|
static PassRegistration<PromoteResourcesToArgsPass> pass(
|
||||||
"tf-promote-resources-to-args",
|
"tf-promote-resources-to-args",
|
||||||
"Promote resources reads/writes to function inputs/outputs.");
|
"Promote resources reads/writes to function inputs/outputs.");
|
||||||
|
|
||||||
static PassRegistration<PromoteVarHandlesToSavedModelArgsPass> saved_model_pass(
|
static PassRegistration<PromoteVarHandlesToArgsPass> var_handle_pass(
|
||||||
"tf-saved-model-promote-var-handles-to-args",
|
"tf-promote-var-handles-to-args",
|
||||||
"Promote tf.VarHandleOps to function arguments in a format of "
|
"Promote tf.VarHandleOps to function arguments.");
|
||||||
"TensorFlowSavedModelDialect.");
|
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
Loading…
Reference in New Issue
Block a user