Implement a pass for lifting variables.
This pass creates GlobalTensorOp for each variable from function arguments and converts the function arguments to the corresponding saved model arguments. This change fixes the Kokoro builds by avoiding the common_runtime dependency. PiperOrigin-RevId: 316779884 Change-Id: I07c83bf12486748e4350717d94928a75bad92342
This commit is contained in:
parent
fad7b3a33b
commit
89a1d3f4e9
|
@ -397,6 +397,73 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_lib",
|
||||
srcs = [
|
||||
"transforms/lift_variables.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/lift_variables.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_pass",
|
||||
hdrs = [
|
||||
"transforms/lift_variables_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_lib",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_test_pass",
|
||||
hdrs = [
|
||||
"transforms/lift_variables_test_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_lib",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_passes",
|
||||
srcs = [
|
||||
|
@ -520,9 +587,11 @@ cc_library(
|
|||
cc_library(
|
||||
name = "tensorflow_test_passes",
|
||||
srcs = [
|
||||
"transforms/lift_variables_test_pass_registration.cc",
|
||||
"transforms/lower_tf_pass.cc",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_test_pass",
|
||||
":lower_tf_lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Freezing VarHandleOp ops.
|
||||
|
||||
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: func @serving_default(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Freezing shared VarHandleOp ops.
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
|
||||
func @f2(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: func @f(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CHECK: func @f2(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test case: Invalid session.
|
||||
// expected-error @+1 {{'module' op no session provided}}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No errors on no resource arguments.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK-LABEL: @serving_default
|
||||
func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using llvm::SmallSet;
|
||||
using ::tensorflow::Device;
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::mutex_lock;
|
||||
using ::tensorflow::ResourceHandle;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::StatusOr;
|
||||
using ::tensorflow::Tensor;
|
||||
using ::tensorflow::Var;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
|
||||
constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
|
||||
|
||||
LogicalResult LiftVariablesFromSession(
|
||||
ModuleOp module, Session* session,
|
||||
const SmallSet<StringRef, 4>& resource_names) {
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
MLIRContext* context = module.getContext();
|
||||
|
||||
if (!session) return module.emitOpError() << "no session provided";
|
||||
|
||||
// Read all resource variables from the session.
|
||||
std::vector<std::string> variable_names;
|
||||
variable_names.reserve(resource_names.size());
|
||||
for (StringRef name : resource_names) variable_names.push_back(name.str());
|
||||
|
||||
std::vector<Tensor> resource_tensors;
|
||||
Status status = session->Run(
|
||||
/*inputs=*/{}, variable_names,
|
||||
/*target_node_names=*/{}, &resource_tensors);
|
||||
if (!status.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to run the provided session: " << status.error_message();
|
||||
}
|
||||
|
||||
const DeviceMgr* device_manager;
|
||||
if (!(session->LocalDeviceManager(&device_manager).ok())) {
|
||||
return module.emitOpError() << "failed to get local device manager";
|
||||
}
|
||||
|
||||
// Read all underlying tensors of the variables from the session.
|
||||
std::vector<Tensor> tensors;
|
||||
tensors.reserve(resource_tensors.size());
|
||||
for (const Tensor& resource_tensor : resource_tensors) {
|
||||
if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
|
||||
tensors.push_back(resource_tensor);
|
||||
continue;
|
||||
}
|
||||
|
||||
const ResourceHandle& resource_handle =
|
||||
resource_tensor.scalar<ResourceHandle>()();
|
||||
|
||||
Device* device;
|
||||
if (!(device_manager->LookupDevice(resource_handle.device(), &device)
|
||||
.ok())) {
|
||||
return module.emitOpError() << "failed to look up device";
|
||||
}
|
||||
|
||||
tensorflow::Var* var_ptr;
|
||||
if (!(device->resource_manager()
|
||||
->Lookup(resource_handle.container(), resource_handle.name(),
|
||||
&var_ptr)
|
||||
.ok())) {
|
||||
return module.emitOpError() << "failed to look up resource value";
|
||||
}
|
||||
tensorflow::core::RefCountPtr<Var> var(var_ptr);
|
||||
|
||||
// The variable tensor is already loaded into corresponding device's
|
||||
// resource manager when we load the saved model using LoadSavedModel().
|
||||
// Here we just read its value.
|
||||
mutex_lock ml(*var->mu());
|
||||
tensors.push_back(*var->tensor());
|
||||
}
|
||||
|
||||
for (const auto iter : llvm::zip(resource_names, tensors)) {
|
||||
const StringRef name = std::get<0>(iter);
|
||||
const Tensor& tensor = std::get<1>(iter);
|
||||
|
||||
// Create tensor attribute for this variable.
|
||||
StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(tensor, &builder);
|
||||
if (!tensor_attr_or.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to convert tensor (name: " << name.str() << ")";
|
||||
}
|
||||
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
|
||||
|
||||
builder.create<tf_saved_model::GlobalTensorOp>(
|
||||
NameLoc::get(builder.getIdentifier(name.str()), context),
|
||||
builder.getStringAttr(name), tensor_attr,
|
||||
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult LiftVariables(ModuleOp module, Session* session) {
|
||||
MLIRContext* context = module.getContext();
|
||||
mlir::Builder builder(context);
|
||||
Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr);
|
||||
|
||||
SmallSet<StringRef, 4> resource_names;
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
|
||||
auto resource_arg =
|
||||
func.getArgAttrOfType<StringAttr>(i, kResourceNameArgAttr);
|
||||
if (!resource_arg) continue;
|
||||
|
||||
StringRef resource_name = resource_arg.getValue();
|
||||
auto flat_symbol_ref_attr =
|
||||
FlatSymbolRefAttr::get(resource_name, context);
|
||||
|
||||
// Add the corresponding `tf_saved_model.bound_input` attribute.
|
||||
func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr);
|
||||
|
||||
resource_names.insert(flat_symbol_ref_attr.getValue());
|
||||
|
||||
// Remove the existing `tf.resource_name` attribute.
|
||||
func.removeArgAttr(i, resource_name_id);
|
||||
}
|
||||
}
|
||||
|
||||
if (resource_names.empty()) return success();
|
||||
|
||||
return LiftVariablesFromSession(module, session, resource_names);
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
|
@ -0,0 +1,33 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
// Creates GlobalTensorOp for each variable from function arguments and converts
|
||||
// them to the corresponding saved model arguments.
|
||||
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
|
@ -0,0 +1,57 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
// This pass takes care of finding all variables from the function arguments and
|
||||
// converting them to the corresponding global tensors, that will be located out
|
||||
// of function. Also it converts resource arguments from function types to the
|
||||
// corresponding saved model arguments accordingly.
|
||||
class LiftVariablesPass
|
||||
: public PassWrapper<LiftVariablesPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
explicit LiftVariablesPass(::tensorflow::Session* session)
|
||||
: session_(session) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
::tensorflow::Session* session_;
|
||||
};
|
||||
|
||||
// Creates as pass that creates GlobalTensorOp for each variable from function
|
||||
// arguments and converts the function arguments to the corresponding saved
|
||||
// model arguments.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
|
||||
::tensorflow::Session* session);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
|
@ -0,0 +1,146 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::Tensor;
|
||||
|
||||
// FakeSession is for testing only.
|
||||
class FakeSession : public tensorflow::Session {
|
||||
public:
|
||||
FakeSession() {}
|
||||
~FakeSession() override = default;
|
||||
|
||||
Status Create(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
Status Extend(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status Close() override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* response) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status LocalDeviceManager(
|
||||
const tensorflow::DeviceMgr** deviceMgrPtr) override {
|
||||
// This method returns a null device manager without making an error.
|
||||
// Users of this method will be notified since it will have a fake data.
|
||||
*deviceMgrPtr = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs) override {
|
||||
tensorflow::RunMetadata run_metadata;
|
||||
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
|
||||
outputs, &run_metadata);
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata) override {
|
||||
return Run(run_options, inputs, output_names, target_nodes, outputs,
|
||||
run_metadata, tensorflow::thread::ThreadPoolOptions());
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata,
|
||||
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
|
||||
override {
|
||||
for (const std::string& output_name : output_names) {
|
||||
Tensor output;
|
||||
if (output_name == "dense/bias") {
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})));
|
||||
} else if (output_name == "dense/kernel") {
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})));
|
||||
} else {
|
||||
// Create a scalar float tensor.
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesTestPass
|
||||
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
LiftVariablesTestPass() { session_ = new FakeSession(); }
|
||||
|
||||
~LiftVariablesTestPass() override { delete session_; }
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
Session* session_;
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesInvalidSessionTestPass
|
||||
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
// Pass an invalid session argument, which is a nullptr.
|
||||
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
|
@ -0,0 +1,32 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
static PassRegistration<LiftVariablesTestPass> lift_variables_test_pass(
|
||||
"tf-saved-model-lift-variables-test",
|
||||
"Lift variables and save them as global tensors");
|
||||
|
||||
static PassRegistration<LiftVariablesInvalidSessionTestPass>
|
||||
lift_variables_invalid_session_test_pass(
|
||||
"tf-saved-model-lift-variables-invalid-session-test",
|
||||
"Lift variables and save them as global tensors with an invalid "
|
||||
"session");
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
Loading…
Reference in New Issue