Run placer before importing function to mlir.

PiperOrigin-RevId: 344157130
Change-Id: I263351164321a31c8d59c6681507f37b511ce275
This commit is contained in:
Chuanhao Zhuge 2020-11-24 17:01:22 -08:00 committed by TensorFlower Gardener
parent 8637177e25
commit b4aa28ebbc
5 changed files with 61 additions and 12 deletions

View File

@ -41,6 +41,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core/common_runtime:core_cpu_base_no_ops",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],

View File

@ -30,6 +30,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"
@ -111,8 +113,24 @@ std::string ImportFunction(const std::string &functiondef_proto,
}
const std::string &function_name = functiondef.signature().name();
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
if (fdef == nullptr) {
s = tensorflow::errors::NotFound("Cannot find function ", function_name);
Set_TF_Status_from_Status(status, s);
return "// error";
}
std::unique_ptr<tensorflow::FunctionBody> fbody;
s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
&fbody);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
mlir::MLIRContext context;
auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";

View File

@ -3693,23 +3693,16 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
}
// Imports a single TensorFlow function into an MLIR module by delegating to
// GraphDefImporter::Convert with `graph_as_function` enabled. The names of the
// function body's control-return nodes are forwarded as control outputs.
//
// NOTE(review): this is a diff hunk (23 lines before, 16 after) shown without
// +/- markers, so removed and added lines are interleaved below. The
// `mlir::StringRef name` parameter, the flib_def.Find lookup, the
// FunctionDefToBodyHelper call and the `flib_def, specs, name);` argument line
// appear to be the removed side; the `fbody` parameter and the
// `fbody->fdef.signature().name()` argument appear to be the added side.
// Confirm against the raw diff before relying on this.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
mlir::MLIRContext* context) {
// Look up the FunctionDef by name and build a FunctionBody from it; a missing
// function is reported as NotFound.
const tensorflow::FunctionDef* fdef = flib_def.Find(name.str());
if (fdef == nullptr)
return tensorflow::errors::NotFound("Cannot find function ", name.str());
std::unique_ptr<tensorflow::FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
&flib_def, &fbody));
// No debug info is supplied for the import; a default-constructed
// GraphDebugInfo is passed instead.
tensorflow::GraphDebugInfo dummy_debug_info;
tensorflow::GraphImportConfig specs;
specs.graph_as_function = true;
// Record every control-return node of the function body as a control output
// of the import.
for (const auto* control_ret_node : fbody->control_ret_nodes)
specs.control_outputs.push_back(control_ret_node->name());
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
flib_def, specs, name);
flib_def, specs,
// The function name is read from the body's own FunctionDef signature.
fbody->fdef.signature().name());
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(

View File

@ -51,7 +51,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
// Given a Function, returns an MLIR module containing the graph, expressed with
// the tf_executor dialect.
//
// Parameters, as used by the definition:
//   fbody    - function body; its graph is imported, its control-return nodes
//              become control outputs, and its FunctionDef signature supplies
//              the function name.
//   flib_def - function library forwarded to the graph importer.
//   context  - MLIR context forwarded to the graph importer.
//
// NOTE(review): this diff hunk lists both the `mlir::StringRef name` and the
// `const FunctionBody* fbody` parameter lines without +/- markers; presumably
// the former is the removed line and the latter the added one. Confirm against
// the raw diff.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
mlir::MLIRContext* context);
// Given a SavedModel, returns a MLIR module containing the functions, expressed

View File

@ -31,6 +31,8 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
@ -38,6 +40,7 @@ class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(SoftDevicePlacementTest, self).setUp()
context._reset_context()
context.ensure_initialized()
config.set_soft_device_placement(enabled=True)
context.context().log_device_placement = True
@ -109,6 +112,36 @@ class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
self.assertIn('CPU:0', a.device)
self.assertIn('CPU:0', a.backing_device)
# Checks the device reported for the result of a function wrapped with
# def_function.function that carries no explicit device annotation.
def testPlacedToDeviceInFunction(self):
@def_function.function
def f():
# Build a 32x32 random matrix and multiply it by itself.
a = random_ops.random_uniform([32, 32])
return math_ops.matmul(a, a)
# The expected device depends on whether any physical GPU is listed.
gpus = config.list_physical_devices('GPU')
if not gpus:
# No GPU listed: the result is expected to report CPU:0.
self.assertIn('CPU:0', f().device)
else:
# At least one GPU listed: the result is expected to report GPU:0.
self.assertIn('GPU:0', f().device)
# Checks the device reported for the result of a function that requests a
# non-existent device ('GPU:42') for part of its computation. The test is
# disabled under TFRT; see the bug reference in the decorator message.
@test_util.disable_tfrt('b/173726713: Support properly inserting device at '
'tf_to_corert lowering.')
def testUnknownDeviceInFunction(self):
@def_function.function
def f():
with ops.device('GPU:42'):
# With placer, the unknown GPU:42 will be replaced with GPU:0 (the
# assertions below expect CPU:0 instead when no GPU is listed).
a = constant_op.constant(1) + constant_op.constant(2)
return a + constant_op.constant(2)
# The expected device depends on whether any physical GPU is listed.
gpus = config.list_physical_devices('GPU')
if not gpus:
# No GPU listed: the result is expected to report CPU:0.
self.assertIn('CPU:0', f().device)
else:
# At least one GPU listed: the result is expected to report GPU:0.
self.assertIn('GPU:0', f().device)
class HardDevicePlacementTest(test.TestCase, parameterized.TestCase):
@ -158,6 +191,7 @@ class ClusterPlacementTest(test.TestCase):
workers, _ = test_util.create_local_cluster(2, 0)
remote.connect_to_remote_host([workers[0].target, workers[1].target])
@test_util.disable_tfrt('remote host not supported yet.')
def testNotFullySpecifiedTask(self):
a = constant_op.constant(1)
b = constant_op.constant(2)
@ -165,6 +199,7 @@ class ClusterPlacementTest(test.TestCase):
c = a + b
self.assertIn('/job:worker/replica:0/task:0', c.device)
@test_util.disable_tfrt('remote host not supported yet.')
def testRemoteUnknownDevice(self):
a = constant_op.constant(1)
b = constant_op.constant(2)
@ -175,6 +210,7 @@ class ClusterPlacementTest(test.TestCase):
del c
self.assertIn('unknown device', cm.exception.message)
@test_util.disable_tfrt('remote host not supported yet.')
def testUnknownDeviceInFunctionReturnUnknowDevice(self):
@def_function.function
@ -188,6 +224,7 @@ class ClusterPlacementTest(test.TestCase):
else:
self.assertIn('GPU:0', f().device)
@test_util.disable_tfrt('remote host not supported yet.')
def testUnknownDeviceInFunction(self):
@def_function.function