diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 9e5688cd230..904ccb7e820 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -397,6 +397,73 @@ cc_library( ], ) +cc_library( + name = "lift_variables_lib", + srcs = [ + "transforms/lift_variables.cc", + ], + hdrs = [ + "transforms/lift_variables.h", + ], + deps = [ + ":convert_tensor", + ":tensorflow", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:threadpool_options", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "lift_variables_pass", + hdrs = [ + "transforms/lift_variables_pass.h", + ], + deps = [ + ":lift_variables_lib", + ":tensorflow", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "lift_variables_test_pass", + hdrs = [ + "transforms/lift_variables_test_pass.h", + ], + deps = [ + ":lift_variables_lib", + ":tensorflow", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:threadpool_options", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + cc_library( name = "tensorflow_passes", srcs = [ @@ -520,9 +587,11 @@ cc_library( cc_library( name = "tensorflow_test_passes", srcs = [ + "transforms/lift_variables_test_pass_registration.cc", "transforms/lower_tf_pass.cc", ], deps = [ + ":lift_variables_test_pass", ":lower_tf_lib", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir new file mode 100644 index 00000000000..0c04a0d738c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir @@ -0,0 +1,61 @@ +// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail + +module attributes {tf_saved_model.semantics} { + + // Test case: Freezing VarHandleOp ops. + + func @serving_default(%arg0: tensor>> {tf.resource_name = "dense/kernel"}, %arg1: tensor>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor>>) -> tensor<100x50xf32> + %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor>> + %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor>>) -> tensor<50xf32> + %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32> + return %4 : tensor<100x50xf32> + } + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/kernel" + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/bias" + // CHECK: func @serving_default( + // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // Test case: Freezing shared VarHandleOp ops. + + func @f(%arg0: tensor>> {tf.resource_name = "dense/kernel"}, %arg1: tensor>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} { + %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor>>) -> tensor<100x50xf32> + %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor>> + %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor>>) -> tensor<50xf32> + %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32> + return %4 : tensor<100x50xf32> + } + + func @f2(%arg0: tensor>> {tf.resource_name = "dense/kernel"}, %arg1: tensor>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} { + %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor>>) -> tensor<100x50xf32> + %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor>> + %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor>>) -> tensor<50xf32> + %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32> + return %4 : tensor<100x50xf32> + } + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/kernel" + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/bias" + // CHECK: func @f( + // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) + + // CHECK: func @f2( + // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir new file mode 100644 index 00000000000..17244d8481a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir @@ -0,0 +1,33 @@ +// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail + +// Test case: Invalid session. +// expected-error @+1 {{'module' op no session provided}} +module attributes {tf_saved_model.semantics} { + + func @serving_default(%arg0: tensor>> {tf.resource_name = "dense/kernel"}, %arg1: tensor>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor>>) -> tensor<100x50xf32> + %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor>> + %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor>>) -> tensor<50xf32> + %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32> + return %4 : tensor<100x50xf32> + } +} + +// ----- + +// Test case: No errors on no resource arguments. +module attributes {tf_saved_model.semantics} { + + // CHECK-LABEL: @serving_default + func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor>>) -> tensor<100x50xf32> + %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor>> + %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor>>) -> tensor<50xf32> + %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32> + return %4 : tensor<100x50xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc new file mode 100644 index 00000000000..6686b340be9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -0,0 +1,183 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/threadpool_options.h" +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +using llvm::SmallSet; +using ::tensorflow::Device; +using ::tensorflow::DeviceMgr; +using ::tensorflow::mutex_lock; +using ::tensorflow::ResourceHandle; +using ::tensorflow::Session; +using ::tensorflow::Status; +using ::tensorflow::StatusOr; +using ::tensorflow::Tensor; +using ::tensorflow::Var; + +namespace { + +constexpr char kResourceNameArgAttr[] = "tf.resource_name"; +constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input"; + +LogicalResult LiftVariablesFromSession( + ModuleOp module, Session* session, + const SmallSet& resource_names) { + OpBuilder builder(module.getBodyRegion()); + MLIRContext* context = module.getContext(); + + if (!session) return module.emitOpError() << "no session provided"; + + // Read all resource variables from the session. + std::vector variable_names; + variable_names.reserve(resource_names.size()); + for (StringRef name : resource_names) variable_names.push_back(name.str()); + + std::vector resource_tensors; + Status status = session->Run( + /*inputs=*/{}, variable_names, + /*target_node_names=*/{}, &resource_tensors); + if (!status.ok()) { + return module.emitOpError() + << "failed to run the provided session: " << status.error_message(); + } + + const DeviceMgr* device_manager; + if (!(session->LocalDeviceManager(&device_manager).ok())) { + return module.emitOpError() << "failed to get local device manager"; + } + + // Read all underlying tensors of the variables from the session. + std::vector tensors; + tensors.reserve(resource_tensors.size()); + for (const Tensor& resource_tensor : resource_tensors) { + if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) { + tensors.push_back(resource_tensor); + continue; + } + + const ResourceHandle& resource_handle = + resource_tensor.scalar()(); + + Device* device; + if (!(device_manager->LookupDevice(resource_handle.device(), &device) + .ok())) { + return module.emitOpError() << "failed to look up device"; + } + + tensorflow::Var* var_ptr; + if (!(device->resource_manager() + ->Lookup(resource_handle.container(), resource_handle.name(), + &var_ptr) + .ok())) { + return module.emitOpError() << "failed to look up resource value"; + } + tensorflow::core::RefCountPtr var(var_ptr); + + // The variable tensor is already loaded into corresponding device's + // resource manager when we load the saved model using LoadSavedModel(). + // Here we just read its value. + mutex_lock ml(*var->mu()); + tensors.push_back(*var->tensor()); + } + + for (const auto iter : llvm::zip(resource_names, tensors)) { + const StringRef name = std::get<0>(iter); + const Tensor& tensor = std::get<1>(iter); + + // Create tensor attribute for this variable. + StatusOr tensor_attr_or = ConvertTensor(tensor, &builder); + if (!tensor_attr_or.ok()) { + return module.emitOpError() + << "failed to convert tensor (name: " << name.str() << ")"; + } + ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie(); + + builder.create( + NameLoc::get(builder.getIdentifier(name.str()), context), + builder.getStringAttr(name), tensor_attr, + TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr()); + } + + return success(); +} + +} // namespace + +LogicalResult LiftVariables(ModuleOp module, Session* session) { + MLIRContext* context = module.getContext(); + mlir::Builder builder(context); + Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr); + + SmallSet resource_names; + + for (FuncOp func : module.getOps()) { + for (int i = 0, e = func.getNumArguments(); i < e; ++i) { + auto resource_arg = + func.getArgAttrOfType(i, kResourceNameArgAttr); + if (!resource_arg) continue; + + StringRef resource_name = resource_arg.getValue(); + auto flat_symbol_ref_attr = + FlatSymbolRefAttr::get(resource_name, context); + + // Add the corresponding `tf_saved_model.bound_input` attribute. + func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr); + + resource_names.insert(flat_symbol_ref_attr.getValue()); + + // Remove the existing `tf.resource_name` attribute. + func.removeArgAttr(i, resource_name_id); + } + } + + if (resource_names.empty()) return success(); + + return LiftVariablesFromSession(module, session, resource_names); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h new file mode 100644 index 00000000000..12dc787fbcf --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// Creates GlobalTensorOp for each variable from function arguments and converts +// them to the corresponding saved model arguments. +LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h new file mode 100644 index 00000000000..0eaee959c77 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// This pass takes care of finding all variables from the function arguments and +// converting them to the corresponding global tensors, that will be located out +// of function. Also it converts resource arguments from function types to the +// corresponding saved model arguments accordingly. +class LiftVariablesPass + : public PassWrapper> { + public: + explicit LiftVariablesPass(::tensorflow::Session* session) + : session_(session) {} + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(LiftVariables(module, session_))) signalPassFailure(); + } + + private: + ::tensorflow::Session* session_; +}; + +// Creates as pass that creates GlobalTensorOp for each variable from function +// arguments and converts the function arguments to the corresponding saved +// model arguments. +std::unique_ptr> CreateLiftVariablesPass( + ::tensorflow::Session* session); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h new file mode 100644 index 00000000000..faecdf04368 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_ + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/threadpool_options.h" +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +using ::tensorflow::DeviceMgr; +using ::tensorflow::Session; +using ::tensorflow::Status; +using ::tensorflow::Tensor; + +// FakeSession is for testing only. +class FakeSession : public tensorflow::Session { + public: + FakeSession() {} + ~FakeSession() override = default; + + Status Create(const tensorflow::GraphDef& graph) override { + return tensorflow::errors::Unimplemented("not available"); + } + Status Extend(const tensorflow::GraphDef& graph) override { + return tensorflow::errors::Unimplemented("not available"); + } + + Status Close() override { + return tensorflow::errors::Unimplemented("not available"); + } + + Status ListDevices( + std::vector* response) override { + return tensorflow::errors::Unimplemented("not available"); + } + + Status LocalDeviceManager( + const tensorflow::DeviceMgr** deviceMgrPtr) override { + // This method returns a null device manager without making an error. + // Users of this method will be notified since it will have a fake data. + *deviceMgrPtr = nullptr; + return Status::OK(); + } + + Status Run(const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) override { + tensorflow::RunMetadata run_metadata; + return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes, + outputs, &run_metadata); + } + + Status Run(const tensorflow::RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs, + tensorflow::RunMetadata* run_metadata) override { + return Run(run_options, inputs, output_names, target_nodes, outputs, + run_metadata, tensorflow::thread::ThreadPoolOptions()); + } + + Status Run(const tensorflow::RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs, + tensorflow::RunMetadata* run_metadata, + const tensorflow::thread::ThreadPoolOptions& thread_pool_options) + override { + for (const std::string& output_name : output_names) { + Tensor output; + if (output_name == "dense/bias") { + outputs->push_back( + Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50}))); + } else if (output_name == "dense/kernel") { + outputs->push_back( + Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50}))); + } else { + // Create a scalar float tensor. + outputs->push_back( + Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({}))); + } + } + return Status::OK(); + } +}; + +// This pass is only available in the tf-opt binary for testing. +class LiftVariablesTestPass + : public PassWrapper> { + public: + LiftVariablesTestPass() { session_ = new FakeSession(); } + + ~LiftVariablesTestPass() override { delete session_; } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(LiftVariables(module, session_))) signalPassFailure(); + } + + private: + Session* session_; +}; + +// This pass is only available in the tf-opt binary for testing. +class LiftVariablesInvalidSessionTestPass + : public PassWrapper> { + public: + void runOnOperation() override { + ModuleOp module = getOperation(); + // Pass an invalid session argument, which is a nullptr. + if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure(); + } +}; + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc new file mode 100644 index 00000000000..19c367c6d46 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h" + +namespace mlir { +namespace tf_saved_model { + +static PassRegistration lift_variables_test_pass( + "tf-saved-model-lift-variables-test", + "Lift variables and save them as global tensors"); + +static PassRegistration + lift_variables_invalid_session_test_pass( + "tf-saved-model-lift-variables-invalid-session-test", + "Lift variables and save them as global tensors with an invalid " + "session"); + +} // namespace tf_saved_model +} // namespace mlir