From 89a1d3f4e9dcea93bff6cdcd469e5e10c26376bc Mon Sep 17 00:00:00 2001
From: Jaesung Chung
Date: Tue, 16 Jun 2020 16:24:49 -0700
Subject: [PATCH] Implement a pass for lifting variables.

This pass creates GlobalTensorOp for each variable from function
arguments and converts the function arguments to the corresponding
saved model arguments.

This change fixes the Kokoro builds by avoiding the common_runtime
dependency.
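
Illustrative sketch (mirroring the included test, not an exact transcript):
a resource argument annotated with `tf.resource_name` is rewritten into a
module-level global tensor plus a `tf_saved_model.bound_input` binding:

  // Before:
  func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"})

  // After:
  "tf_saved_model.global_tensor"() {sym_name = "dense/kernel", ...} : () -> ()
  func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"})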

PiperOrigin-RevId: 316779884
Change-Id: I07c83bf12486748e4350717d94928a75bad92342
---
 tensorflow/compiler/mlir/tensorflow/BUILD     |  69 +++++++
 .../tests/tf_saved_model_lift_variables.mlir  |  61 ++++++
 ..._model_lift_variables_invalid_session.mlir |  33 ++++
 .../tensorflow/transforms/lift_variables.cc   | 183 ++++++++++++++++++
 .../tensorflow/transforms/lift_variables.h    |  33 ++++
 .../transforms/lift_variables_pass.h          |  57 ++++++
 .../transforms/lift_variables_test_pass.h     | 146 ++++++++++++++
 .../lift_variables_test_pass_registration.cc  |  32 +++
 8 files changed, 614 insertions(+)
 create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir
 create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc

diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 9e5688cd230..904ccb7e820 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -397,6 +397,73 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "lift_variables_lib",
+    srcs = [
+        "transforms/lift_variables.cc",
+    ],
+    hdrs = [
+        "transforms/lift_variables.h",
+    ],
+    deps = [
+        ":convert_tensor",
+        ":tensorflow",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:threadpool_options",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "lift_variables_pass",
+    hdrs = [
+        "transforms/lift_variables_pass.h",
+    ],
+    deps = [
+        ":lift_variables_lib",
+        ":tensorflow",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "lift_variables_test_pass",
+    hdrs = [
+        "transforms/lift_variables_test_pass.h",
+    ],
+    deps = [
+        ":lift_variables_lib",
+        ":tensorflow",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/platform:threadpool_options",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "tensorflow_passes",
     srcs = [
@@ -520,9 +587,11 @@ cc_library(
 cc_library(
     name = "tensorflow_test_passes",
     srcs = [
+        "transforms/lift_variables_test_pass_registration.cc",
         "transforms/lower_tf_pass.cc",
     ],
     deps = [
+        ":lift_variables_test_pass",
         ":lower_tf_lib",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir
new file mode 100644
index 00000000000..0c04a0d738c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir
@@ -0,0 +1,61 @@
+// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
+
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: Freezing VarHandleOp ops.
+
+  func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
+    attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
+    %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
+    %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
+    %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
+    %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
+    return %4 : tensor<100x50xf32>
+  }
+  // CHECK: "tf_saved_model.global_tensor"()
+  // CHECK: sym_name = "dense/kernel"
+  // CHECK: "tf_saved_model.global_tensor"()
+  // CHECK: sym_name = "dense/bias"
+  // CHECK: func @serving_default(
+  // CHECK:   %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
+  // CHECK:   %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
+}
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+
+  // Test case: Freezing shared VarHandleOp ops.
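+  // Both functions below reference the same two variables, so the pass is
+  // expected to create each global tensor only once and to bind both
+  // functions to the shared tensors, as the CHECK lines verify.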
+
+  func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
+    attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} {
+    %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
+    %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
+    %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
+    %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
+    %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
+    return %4 : tensor<100x50xf32>
+  }
+
+  func @f2(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
+    attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} {
+    %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
+    %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
+    %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
+    %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
+    %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
+    return %4 : tensor<100x50xf32>
+  }
+  // CHECK: "tf_saved_model.global_tensor"()
+  // CHECK: sym_name = "dense/kernel"
+  // CHECK: "tf_saved_model.global_tensor"()
+  // CHECK: sym_name = "dense/bias"
+  // CHECK: func @f(
+  // CHECK:   %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
+  // CHECK:   %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
+
+  // CHECK: func @f2(
+  // CHECK:   %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
+  // CHECK:   %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir
new file mode 100644
index 00000000000..17244d8481a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables_invalid_session.mlir
@@ -0,0 +1,33 @@
+// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail
+
+// Test case: Invalid session.
+// expected-error @+1 {{'module' op no session provided}}
+module attributes {tf_saved_model.semantics} {
+
+  func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
+    attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
+    %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
+    %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
+    %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
+    %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
+    return %4 : tensor<100x50xf32>
+  }
+}
+
+// -----
+
+// Test case: No errors on no resource arguments.
+module attributes {tf_saved_model.semantics} {
+
+  // CHECK-LABEL: @serving_default
+  func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
+    attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
+    %1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
+    %2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
+    %3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
+    %4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
+    return %4 : tensor<100x50xf32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
new file mode 100644
index 00000000000..6686b340be9
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
@@ -0,0 +1,183 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/UseDefLists.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/threadpool_options.h"
+#include "tensorflow/core/public/session.h"
+
+namespace mlir {
+namespace tf_saved_model {
+
+using llvm::SmallSet;
+using ::tensorflow::Device;
+using ::tensorflow::DeviceMgr;
+using ::tensorflow::mutex_lock;
+using ::tensorflow::ResourceHandle;
+using ::tensorflow::Session;
+using ::tensorflow::Status;
+using ::tensorflow::StatusOr;
+using ::tensorflow::Tensor;
+using ::tensorflow::Var;
+
+namespace {
+
+constexpr char kResourceNameArgAttr[] = "tf.resource_name";
+constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
+
+LogicalResult LiftVariablesFromSession(
+    ModuleOp module, Session* session,
+    const SmallSet<StringRef, 4>& resource_names) {
+  OpBuilder builder(module.getBodyRegion());
+  MLIRContext* context = module.getContext();
+
+  if (!session) return module.emitOpError() << "no session provided";
+
+  // Read all resource variables from the session.
+  std::vector<std::string> variable_names;
+  variable_names.reserve(resource_names.size());
+  for (StringRef name : resource_names) variable_names.push_back(name.str());
+
+  std::vector<Tensor> resource_tensors;
+  Status status = session->Run(
+      /*inputs=*/{}, variable_names,
+      /*target_node_names=*/{}, &resource_tensors);
+  if (!status.ok()) {
+    return module.emitOpError()
+           << "failed to run the provided session: " << status.error_message();
+  }
+
+  const DeviceMgr* device_manager;
+  if (!(session->LocalDeviceManager(&device_manager).ok())) {
+    return module.emitOpError() << "failed to get local device manager";
+  }
+
+  // Read all underlying tensors of the variables from the session.
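+  // A tensor of type DT_RESOURCE is resolved to the variable's underlying
+  // value through the owning device's resource manager; tensors of any other
+  // dtype are assumed to already hold the variable value and are used as-is.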
+  std::vector<Tensor> tensors;
+  tensors.reserve(resource_tensors.size());
+  for (const Tensor& resource_tensor : resource_tensors) {
+    if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
+      tensors.push_back(resource_tensor);
+      continue;
+    }
+
+    const ResourceHandle& resource_handle =
+        resource_tensor.scalar<ResourceHandle>()();
+
+    Device* device;
+    if (!(device_manager->LookupDevice(resource_handle.device(), &device)
+              .ok())) {
+      return module.emitOpError() << "failed to look up device";
+    }
+
+    tensorflow::Var* var_ptr;
+    if (!(device->resource_manager()
+              ->Lookup(resource_handle.container(), resource_handle.name(),
+                       &var_ptr)
+              .ok())) {
+      return module.emitOpError() << "failed to look up resource value";
+    }
+    tensorflow::core::RefCountPtr<Var> var(var_ptr);
+
+    // The variable tensor is already loaded into the corresponding device's
+    // resource manager when we load the saved model using LoadSavedModel().
+    // Here we just read its value.
+    mutex_lock ml(*var->mu());
+    tensors.push_back(*var->tensor());
+  }
+
+  for (const auto iter : llvm::zip(resource_names, tensors)) {
+    const StringRef name = std::get<0>(iter);
+    const Tensor& tensor = std::get<1>(iter);
+
+    // Create a tensor attribute for this variable.
+    StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(tensor, &builder);
+    if (!tensor_attr_or.ok()) {
+      return module.emitOpError()
+             << "failed to convert tensor (name: " << name.str() << ")";
+    }
+    ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
+
+    builder.create<tf_saved_model::GlobalTensorOp>(
+        NameLoc::get(builder.getIdentifier(name.str()), context),
+        builder.getStringAttr(name), tensor_attr,
+        TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
+  }
+
+  return success();
+}
+
+}  // namespace
+
+LogicalResult LiftVariables(ModuleOp module, Session* session) {
+  MLIRContext* context = module.getContext();
+  mlir::Builder builder(context);
+  Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr);
+
+  SmallSet<StringRef, 4> resource_names;
+
+  for (FuncOp func : module.getOps<FuncOp>()) {
+    for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
+      auto resource_arg =
+          func.getArgAttrOfType<StringAttr>(i, kResourceNameArgAttr);
+      if (!resource_arg) continue;
+
+      StringRef resource_name = resource_arg.getValue();
+      auto flat_symbol_ref_attr =
+          FlatSymbolRefAttr::get(resource_name, context);
+
+      // Add the corresponding `tf_saved_model.bound_input` attribute.
+      func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr);
+
+      resource_names.insert(flat_symbol_ref_attr.getValue());
+
+      // Remove the existing `tf.resource_name` attribute.
+      func.removeArgAttr(i, resource_name_id);
+    }
+  }
+
+  if (resource_names.empty()) return success();
+
+  return LiftVariablesFromSession(module, session, resource_names);
+}
+
+}  // namespace tf_saved_model
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h
new file mode 100644
index 00000000000..12dc787fbcf
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/core/public/session.h"
+
+namespace mlir {
+namespace tf_saved_model {
+
+// Creates a GlobalTensorOp for each variable found in the function arguments
+// and converts the function arguments to the corresponding saved model
+// arguments.
+LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
+
+}  // namespace tf_saved_model
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h
new file mode 100644
index 00000000000..0eaee959c77
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h
@@ -0,0 +1,57 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
+#include "tensorflow/core/public/session.h"
+
+namespace mlir {
+namespace tf_saved_model {
+
+// This pass finds all variables among the function arguments, converts them
+// to global tensors placed at module level, and rewrites the resource
+// arguments in the function signatures into the corresponding saved model
+// (bound input) arguments.
+class LiftVariablesPass
+    : public PassWrapper<LiftVariablesPass, OperationPass<ModuleOp>> {
+ public:
+  explicit LiftVariablesPass(::tensorflow::Session* session)
+      : session_(session) {}
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    if (failed(LiftVariables(module, session_))) signalPassFailure();
+  }
+
+ private:
+  ::tensorflow::Session* session_;
+};
+
+// Creates a pass that creates a GlobalTensorOp for each variable from the
+// function arguments and converts the function arguments to the corresponding
+// saved model arguments.
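+//
+// Usage sketch (illustrative only; assumes a `session` obtained from e.g.
+// LoadSavedModel):
+//
+//   mlir::PassManager pm(module.getContext());
+//   pm.addPass(mlir::tf_saved_model::CreateLiftVariablesPass(session));
+//   if (failed(pm.run(module))) { /* handle the error */ }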
+std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
+    ::tensorflow::Session* session);
+
+}  // namespace tf_saved_model
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h
new file mode 100644
index 00000000000..faecdf04368
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h
@@ -0,0 +1,146 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
+
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/threadpool_options.h"
+#include "tensorflow/core/public/session.h"
+
+namespace mlir {
+namespace tf_saved_model {
+
+using ::tensorflow::DeviceMgr;
+using ::tensorflow::Session;
+using ::tensorflow::Status;
+using ::tensorflow::Tensor;
+
+// FakeSession is for testing only.
+class FakeSession : public tensorflow::Session {
+ public:
+  FakeSession() {}
+  ~FakeSession() override = default;
+
+  Status Create(const tensorflow::GraphDef& graph) override {
+    return tensorflow::errors::Unimplemented("not available");
+  }
+  Status Extend(const tensorflow::GraphDef& graph) override {
+    return tensorflow::errors::Unimplemented("not available");
+  }
+
+  Status Close() override {
+    return tensorflow::errors::Unimplemented("not available");
+  }
+
+  Status ListDevices(
+      std::vector<tensorflow::DeviceAttributes>* response) override {
+    return tensorflow::errors::Unimplemented("not available");
+  }
+
+  Status LocalDeviceManager(
+      const tensorflow::DeviceMgr** deviceMgrPtr) override {
+    // This method returns a null device manager without signaling an error,
+    // so callers will notice that they are dealing with fake data.
+    *deviceMgrPtr = nullptr;
+    return Status::OK();
+  }
+
+  Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
+             const std::vector<std::string>& output_names,
+             const std::vector<std::string>& target_nodes,
+             std::vector<Tensor>* outputs) override {
+    tensorflow::RunMetadata run_metadata;
+    return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
+               outputs, &run_metadata);
+  }
+
+  Status Run(const tensorflow::RunOptions& run_options,
+             const std::vector<std::pair<std::string, Tensor>>& inputs,
+             const std::vector<std::string>& output_names,
+             const std::vector<std::string>& target_nodes,
+             std::vector<Tensor>* outputs,
+             tensorflow::RunMetadata* run_metadata) override {
+    return Run(run_options, inputs, output_names, target_nodes, outputs,
+               run_metadata, tensorflow::thread::ThreadPoolOptions());
+  }
+
+  Status Run(const tensorflow::RunOptions& run_options,
+             const std::vector<std::pair<std::string, Tensor>>& inputs,
+             const std::vector<std::string>& output_names,
+             const std::vector<std::string>& target_nodes,
+             std::vector<Tensor>* outputs,
+             tensorflow::RunMetadata* run_metadata,
+             const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
+      override {
+    for (const std::string& output_name : output_names) {
+      Tensor output;
+      if (output_name == "dense/bias") {
+        outputs->push_back(
+            Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})));
+      } else if (output_name == "dense/kernel") {
+        outputs->push_back(
+            Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})));
+      } else {
+        // Create a scalar float tensor.
+        outputs->push_back(
+            Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})));
+      }
+    }
+    return Status::OK();
+  }
+};
+
+// This pass is only available in the tf-opt binary for testing.
+class LiftVariablesTestPass
+    : public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
+ public:
+  LiftVariablesTestPass() { session_ = new FakeSession(); }
+
+  ~LiftVariablesTestPass() override { delete session_; }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    if (failed(LiftVariables(module, session_))) signalPassFailure();
+  }
+
+ private:
+  Session* session_;
+};
+
+// This pass is only available in the tf-opt binary for testing.
+class LiftVariablesInvalidSessionTestPass
+    : public PassWrapper<LiftVariablesInvalidSessionTestPass,
+                         OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    // Pass an invalid session argument, which is a nullptr.
+    if (failed(LiftVariables(module, /*session=*/nullptr)))
+      signalPassFailure();
+  }
+};
+
+}  // namespace tf_saved_model
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc
new file mode 100644
index 00000000000..19c367c6d46
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass_registration.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h"
+
+namespace mlir {
+namespace tf_saved_model {
+
+static PassRegistration<LiftVariablesTestPass> lift_variables_test_pass(
+    "tf-saved-model-lift-variables-test",
+    "Lift variables and save them as global tensors");
+
+static PassRegistration<LiftVariablesInvalidSessionTestPass>
+    lift_variables_invalid_session_test_pass(
+        "tf-saved-model-lift-variables-invalid-session-test",
+        "Lift variables and save them as global tensors with an invalid "
+        "session");
+
+}  // namespace tf_saved_model
+}  // namespace mlir