Implement a pass for lifting variables.
This pass creates GlobalTensorOp for each variable from function arguments and converts the function arguments to the corresponding saved model arguments. PiperOrigin-RevId: 316275316 Change-Id: Ib854dcca14e7bd527bb006683f9a97484cefea63
This commit is contained in:
parent
cfe8ae1f51
commit
8e6cb257fe
@ -397,64 +397,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_lib",
|
||||
srcs = [
|
||||
"transforms/lift_variables.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/lift_variables.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal_impl",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_pass",
|
||||
hdrs = [
|
||||
"transforms/lift_variables_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_lib",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_test_pass",
|
||||
hdrs = [
|
||||
"transforms/lift_variables_test_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_lib",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_passes",
|
||||
srcs = [
|
||||
@ -575,11 +517,9 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tensorflow_test_passes",
|
||||
srcs = [
|
||||
"transforms/lift_variables_test_pass_registration.cc",
|
||||
"transforms/lower_tf_pass.cc",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_test_pass",
|
||||
":lower_tf_lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -1,61 +0,0 @@
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Freezing VarHandleOp ops.
|
||||
|
||||
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: func @serving_default(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Freezing shared VarHandleOp ops.
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
|
||||
func @f2(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: func @f(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CHECK: func @f2(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
@ -1,33 +0,0 @@
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test case: Invalid session.
|
||||
// expected-error @+1 {{'module' op no session provided}}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No errors on no resource arguments.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK-LABEL: @serving_default
|
||||
func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
}
|
@ -1,183 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using llvm::SmallSet;
|
||||
using ::tensorflow::Device;
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::mutex_lock;
|
||||
using ::tensorflow::ResourceHandle;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::StatusOr;
|
||||
using ::tensorflow::Tensor;
|
||||
using ::tensorflow::Var;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
|
||||
constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
|
||||
|
||||
LogicalResult LiftVariablesFromSession(
|
||||
ModuleOp module, Session* session,
|
||||
const SmallSet<StringRef, 4>& resource_names) {
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
MLIRContext* context = module.getContext();
|
||||
|
||||
if (!session) return module.emitOpError() << "no session provided";
|
||||
|
||||
// Read all resource variables from the session.
|
||||
std::vector<std::string> variable_names;
|
||||
variable_names.reserve(resource_names.size());
|
||||
for (StringRef name : resource_names) variable_names.push_back(name.str());
|
||||
|
||||
std::vector<Tensor> resource_tensors;
|
||||
Status status = session->Run(
|
||||
/*inputs=*/{}, variable_names,
|
||||
/*target_node_names=*/{}, &resource_tensors);
|
||||
if (!status.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to run the provided session: " << status.error_message();
|
||||
}
|
||||
|
||||
const DeviceMgr* device_manager;
|
||||
if (!(session->LocalDeviceManager(&device_manager).ok())) {
|
||||
return module.emitOpError() << "failed to get local device manager";
|
||||
}
|
||||
|
||||
// Read all underlying tensors of the variables from the session.
|
||||
std::vector<Tensor> tensors;
|
||||
tensors.reserve(resource_tensors.size());
|
||||
for (const Tensor& resource_tensor : resource_tensors) {
|
||||
if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
|
||||
tensors.push_back(resource_tensor);
|
||||
continue;
|
||||
}
|
||||
|
||||
const ResourceHandle& resource_handle =
|
||||
resource_tensor.scalar<ResourceHandle>()();
|
||||
|
||||
Device* device;
|
||||
if (!(device_manager->LookupDevice(resource_handle.device(), &device)
|
||||
.ok())) {
|
||||
return module.emitOpError() << "failed to look up device";
|
||||
}
|
||||
|
||||
tensorflow::Var* var_ptr;
|
||||
if (!(device->resource_manager()
|
||||
->Lookup(resource_handle.container(), resource_handle.name(),
|
||||
&var_ptr)
|
||||
.ok())) {
|
||||
return module.emitOpError() << "failed to look up resource value";
|
||||
}
|
||||
tensorflow::core::RefCountPtr<Var> var(var_ptr);
|
||||
|
||||
// The variable tensor is already loaded into corresponding device's
|
||||
// resource manager when we load the saved model using LoadSavedModel().
|
||||
// Here we just read its value.
|
||||
mutex_lock ml(*var->mu());
|
||||
tensors.push_back(*var->tensor());
|
||||
}
|
||||
|
||||
for (const auto iter : llvm::zip(resource_names, tensors)) {
|
||||
const StringRef name = std::get<0>(iter);
|
||||
const Tensor& tensor = std::get<1>(iter);
|
||||
|
||||
// Create tensor attribute for this variable.
|
||||
StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(tensor, &builder);
|
||||
if (!tensor_attr_or.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to convert tensor (name: " << name.str() << ")";
|
||||
}
|
||||
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
|
||||
|
||||
builder.create<tf_saved_model::GlobalTensorOp>(
|
||||
NameLoc::get(builder.getIdentifier(name.str()), context),
|
||||
builder.getStringAttr(name), tensor_attr,
|
||||
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult LiftVariables(ModuleOp module, Session* session) {
|
||||
MLIRContext* context = module.getContext();
|
||||
mlir::Builder builder(context);
|
||||
Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr);
|
||||
|
||||
SmallSet<StringRef, 4> resource_names;
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
|
||||
auto resource_arg =
|
||||
func.getArgAttrOfType<StringAttr>(i, kResourceNameArgAttr);
|
||||
if (!resource_arg) continue;
|
||||
|
||||
StringRef resource_name = resource_arg.getValue();
|
||||
auto flat_symbol_ref_attr =
|
||||
FlatSymbolRefAttr::get(resource_name, context);
|
||||
|
||||
// Add the corresponding `tf_saved_model.bound_input` attribute.
|
||||
func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr);
|
||||
|
||||
resource_names.insert(flat_symbol_ref_attr.getValue());
|
||||
|
||||
// Remove the existing `tf.resource_name` attribute.
|
||||
func.removeArgAttr(i, resource_name_id);
|
||||
}
|
||||
}
|
||||
|
||||
if (resource_names.empty()) return success();
|
||||
|
||||
return LiftVariablesFromSession(module, session, resource_names);
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
@ -1,33 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
// Creates GlobalTensorOp for each variable from function arguments and converts
|
||||
// them to the corresponding saved model arguments.
|
||||
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
@ -1,57 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
// This pass takes care of finding all variables from the function arguments and
|
||||
// converting them to the corresponding global tensors, that will be located out
|
||||
// of function. Also it converts resource arguments from function types to the
|
||||
// corresponding saved model arguments accordingly.
|
||||
class LiftVariablesPass
|
||||
: public PassWrapper<LiftVariablesPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
explicit LiftVariablesPass(::tensorflow::Session* session)
|
||||
: session_(session) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
::tensorflow::Session* session_;
|
||||
};
|
||||
|
||||
// Creates as pass that creates GlobalTensorOp for each variable from function
|
||||
// arguments and converts the function arguments to the corresponding saved
|
||||
// model arguments.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
|
||||
::tensorflow::Session* session);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
@ -1,146 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::Tensor;
|
||||
|
||||
// FakeSession is for testing only.
|
||||
class FakeSession : public tensorflow::Session {
|
||||
public:
|
||||
FakeSession() {}
|
||||
~FakeSession() override = default;
|
||||
|
||||
Status Create(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
Status Extend(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status Close() override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* response) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status LocalDeviceManager(
|
||||
const tensorflow::DeviceMgr** deviceMgrPtr) override {
|
||||
// This method returns a null device manager without making an error.
|
||||
// Users of this method will be notified since it will have a fake data.
|
||||
*deviceMgrPtr = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs) override {
|
||||
tensorflow::RunMetadata run_metadata;
|
||||
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
|
||||
outputs, &run_metadata);
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata) override {
|
||||
return Run(run_options, inputs, output_names, target_nodes, outputs,
|
||||
run_metadata, tensorflow::thread::ThreadPoolOptions());
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata,
|
||||
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
|
||||
override {
|
||||
for (const std::string& output_name : output_names) {
|
||||
Tensor output;
|
||||
if (output_name == "dense/bias") {
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})));
|
||||
} else if (output_name == "dense/kernel") {
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})));
|
||||
} else {
|
||||
// Create a scalar float tensor.
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesTestPass
|
||||
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
LiftVariablesTestPass() { session_ = new FakeSession(); }
|
||||
|
||||
~LiftVariablesTestPass() override { delete session_; }
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
Session* session_;
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesInvalidSessionTestPass
|
||||
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
// Pass an invalid session argument, which is a nullptr.
|
||||
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
@ -1,32 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
static PassRegistration<LiftVariablesTestPass> lift_variables_test_pass(
|
||||
"tf-saved-model-lift-variables-test",
|
||||
"Lift variables and save them as global tensors");
|
||||
|
||||
static PassRegistration<LiftVariablesInvalidSessionTestPass>
|
||||
lift_variables_invalid_session_test_pass(
|
||||
"tf-saved-model-lift-variables-invalid-session-test",
|
||||
"Lift variables and save them as global tensors with an invalid "
|
||||
"session");
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
Loading…
x
Reference in New Issue
Block a user