Implement a pass for lifting variables.

This pass creates GlobalTensorOp for each variable from function
arguments and converts the function arguments to the corresponding saved
model arguments.

This change fixes the Kokoro builds by avoiding the common_runtime
dependency.

PiperOrigin-RevId: 316779884
Change-Id: I07c83bf12486748e4350717d94928a75bad92342
This commit is contained in:
Jaesung Chung 2020-06-16 16:24:49 -07:00 committed by TensorFlower Gardener
parent fad7b3a33b
commit 89a1d3f4e9
8 changed files with 614 additions and 0 deletions

View File

@ -397,6 +397,73 @@ cc_library(
],
)
cc_library(
name = "lift_variables_lib",
srcs = [
"transforms/lift_variables.cc",
],
hdrs = [
"transforms/lift_variables.h",
],
deps = [
":convert_tensor",
":tensorflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:threadpool_options",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "lift_variables_pass",
hdrs = [
"transforms/lift_variables_pass.h",
],
deps = [
":lift_variables_lib",
":tensorflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "lift_variables_test_pass",
hdrs = [
"transforms/lift_variables_test_pass.h",
],
deps = [
":lift_variables_lib",
":tensorflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:threadpool_options",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "tensorflow_passes",
srcs = [
@ -520,9 +587,11 @@ cc_library(
cc_library(
name = "tensorflow_test_passes",
srcs = [
"transforms/lift_variables_test_pass_registration.cc",
"transforms/lower_tf_pass.cc",
],
deps = [
":lift_variables_test_pass",
":lower_tf_lib",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -0,0 +1,61 @@
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
module attributes {tf_saved_model.semantics} {
// Test case: Freezing VarHandleOp ops.
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: func @serving_default(
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Freezing shared VarHandleOp ops.
func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
func @f2(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: func @f(
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
// CHECK: func @f2(
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}

View File

@ -0,0 +1,33 @@
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail
// Test case: Invalid session.
// expected-error @+1 {{'module' op no session provided}}
module attributes {tf_saved_model.semantics} {
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
}
// -----
// Test case: No errors on no resource arguments.
module attributes {tf_saved_model.semantics} {
// CHECK-LABEL: @serving_default
func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
}

View File

@ -0,0 +1,183 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
using llvm::SmallSet;
using ::tensorflow::Device;
using ::tensorflow::DeviceMgr;
using ::tensorflow::mutex_lock;
using ::tensorflow::ResourceHandle;
using ::tensorflow::Session;
using ::tensorflow::Status;
using ::tensorflow::StatusOr;
using ::tensorflow::Tensor;
using ::tensorflow::Var;
namespace {
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
LogicalResult LiftVariablesFromSession(
ModuleOp module, Session* session,
const SmallSet<StringRef, 4>& resource_names) {
OpBuilder builder(module.getBodyRegion());
MLIRContext* context = module.getContext();
if (!session) return module.emitOpError() << "no session provided";
// Read all resource variables from the session.
std::vector<std::string> variable_names;
variable_names.reserve(resource_names.size());
for (StringRef name : resource_names) variable_names.push_back(name.str());
std::vector<Tensor> resource_tensors;
Status status = session->Run(
/*inputs=*/{}, variable_names,
/*target_node_names=*/{}, &resource_tensors);
if (!status.ok()) {
return module.emitOpError()
<< "failed to run the provided session: " << status.error_message();
}
const DeviceMgr* device_manager;
if (!(session->LocalDeviceManager(&device_manager).ok())) {
return module.emitOpError() << "failed to get local device manager";
}
// Read all underlying tensors of the variables from the session.
std::vector<Tensor> tensors;
tensors.reserve(resource_tensors.size());
for (const Tensor& resource_tensor : resource_tensors) {
if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
tensors.push_back(resource_tensor);
continue;
}
const ResourceHandle& resource_handle =
resource_tensor.scalar<ResourceHandle>()();
Device* device;
if (!(device_manager->LookupDevice(resource_handle.device(), &device)
.ok())) {
return module.emitOpError() << "failed to look up device";
}
tensorflow::Var* var_ptr;
if (!(device->resource_manager()
->Lookup(resource_handle.container(), resource_handle.name(),
&var_ptr)
.ok())) {
return module.emitOpError() << "failed to look up resource value";
}
tensorflow::core::RefCountPtr<Var> var(var_ptr);
// The variable tensor is already loaded into corresponding device's
// resource manager when we load the saved model using LoadSavedModel().
// Here we just read its value.
mutex_lock ml(*var->mu());
tensors.push_back(*var->tensor());
}
for (const auto iter : llvm::zip(resource_names, tensors)) {
const StringRef name = std::get<0>(iter);
const Tensor& tensor = std::get<1>(iter);
// Create tensor attribute for this variable.
StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(tensor, &builder);
if (!tensor_attr_or.ok()) {
return module.emitOpError()
<< "failed to convert tensor (name: " << name.str() << ")";
}
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
builder.create<tf_saved_model::GlobalTensorOp>(
NameLoc::get(builder.getIdentifier(name.str()), context),
builder.getStringAttr(name), tensor_attr,
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
}
return success();
}
} // namespace
LogicalResult LiftVariables(ModuleOp module, Session* session) {
MLIRContext* context = module.getContext();
mlir::Builder builder(context);
Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr);
SmallSet<StringRef, 4> resource_names;
for (FuncOp func : module.getOps<FuncOp>()) {
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
auto resource_arg =
func.getArgAttrOfType<StringAttr>(i, kResourceNameArgAttr);
if (!resource_arg) continue;
StringRef resource_name = resource_arg.getValue();
auto flat_symbol_ref_attr =
FlatSymbolRefAttr::get(resource_name, context);
// Add the corresponding `tf_saved_model.bound_input` attribute.
func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr);
resource_names.insert(flat_symbol_ref_attr.getValue());
// Remove the existing `tf.resource_name` attribute.
func.removeArgAttr(i, resource_name_id);
}
}
if (resource_names.empty()) return success();
return LiftVariablesFromSession(module, session, resource_names);
}
} // namespace tf_saved_model
} // namespace mlir

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
// Creates GlobalTensorOp for each variable from function arguments and converts
// them to the corresponding saved model arguments.
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
// This pass takes care of finding all variables from the function arguments and
// converting them to the corresponding global tensors, that will be located out
// of function. Also it converts resource arguments from function types to the
// corresponding saved model arguments accordingly.
class LiftVariablesPass
: public PassWrapper<LiftVariablesPass, OperationPass<ModuleOp>> {
public:
explicit LiftVariablesPass(::tensorflow::Session* session)
: session_(session) {}
void runOnOperation() override {
ModuleOp module = getOperation();
if (failed(LiftVariables(module, session_))) signalPassFailure();
}
private:
::tensorflow::Session* session_;
};
// Creates as pass that creates GlobalTensorOp for each variable from function
// arguments and converts the function arguments to the corresponding saved
// model arguments.
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
::tensorflow::Session* session);
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_

View File

@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
using ::tensorflow::DeviceMgr;
using ::tensorflow::Session;
using ::tensorflow::Status;
using ::tensorflow::Tensor;
// FakeSession is for testing only.
class FakeSession : public tensorflow::Session {
public:
FakeSession() {}
~FakeSession() override = default;
Status Create(const tensorflow::GraphDef& graph) override {
return tensorflow::errors::Unimplemented("not available");
}
Status Extend(const tensorflow::GraphDef& graph) override {
return tensorflow::errors::Unimplemented("not available");
}
Status Close() override {
return tensorflow::errors::Unimplemented("not available");
}
Status ListDevices(
std::vector<tensorflow::DeviceAttributes>* response) override {
return tensorflow::errors::Unimplemented("not available");
}
Status LocalDeviceManager(
const tensorflow::DeviceMgr** deviceMgrPtr) override {
// This method returns a null device manager without making an error.
// Users of this method will be notified since it will have a fake data.
*deviceMgrPtr = nullptr;
return Status::OK();
}
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs) override {
tensorflow::RunMetadata run_metadata;
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
outputs, &run_metadata);
}
Status Run(const tensorflow::RunOptions& run_options,
const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs,
tensorflow::RunMetadata* run_metadata) override {
return Run(run_options, inputs, output_names, target_nodes, outputs,
run_metadata, tensorflow::thread::ThreadPoolOptions());
}
Status Run(const tensorflow::RunOptions& run_options,
const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs,
tensorflow::RunMetadata* run_metadata,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
override {
for (const std::string& output_name : output_names) {
Tensor output;
if (output_name == "dense/bias") {
outputs->push_back(
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})));
} else if (output_name == "dense/kernel") {
outputs->push_back(
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})));
} else {
// Create a scalar float tensor.
outputs->push_back(
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})));
}
}
return Status::OK();
}
};
// This pass is only available in the tf-opt binary for testing.
class LiftVariablesTestPass
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
public:
LiftVariablesTestPass() { session_ = new FakeSession(); }
~LiftVariablesTestPass() override { delete session_; }
void runOnOperation() override {
ModuleOp module = getOperation();
if (failed(LiftVariables(module, session_))) signalPassFailure();
}
private:
Session* session_;
};
// This pass is only available in the tf-opt binary for testing.
class LiftVariablesInvalidSessionTestPass
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
OperationPass<ModuleOp>> {
public:
void runOnOperation() override {
ModuleOp module = getOperation();
// Pass an invalid session argument, which is a nullptr.
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
}
};
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h"
namespace mlir {
namespace tf_saved_model {
static PassRegistration<LiftVariablesTestPass> lift_variables_test_pass(
"tf-saved-model-lift-variables-test",
"Lift variables and save them as global tensors");
static PassRegistration<LiftVariablesInvalidSessionTestPass>
lift_variables_invalid_session_test_pass(
"tf-saved-model-lift-variables-invalid-session-test",
"Lift variables and save them as global tensors with an invalid "
"session");
} // namespace tf_saved_model
} // namespace mlir