Run placer before importing function to MLIR.
PiperOrigin-RevId: 344157130
Change-Id: I263351164321a31c8d59c6681507f37b511ce275
parent 8637177e25
commit b4aa28ebbc
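In outline, ImportFunction no longer hands the MLIR importer a bare function name: it first looks the FunctionDef up in the function library and instantiates it into a FunctionBody, whose concrete Graph is what gets imported (and what the placer can run over first, per the commit title); ConvertFunctionToMlir correspondingly now takes the pre-built body. A minimal sketch of the new call sequence, with the TF_Status plumbing from the diff compressed to TF_CHECK_OK/CHECK for brevity, and functiondef/flib_def assumed from the surrounding function:

// Sketch only: mirrors the new ImportFunction flow shown in the diff below.
const std::string &function_name = functiondef.signature().name();

// Look up the registered FunctionDef in the function library.
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
CHECK(fdef != nullptr) << "Cannot find function " << function_name;

// Instantiate the FunctionDef into a FunctionBody; the body owns a concrete
// Graph, which is what the importer now consumes.
std::unique_ptr<tensorflow::FunctionBody> fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
                                    &flib_def, &fbody));

// The importer takes the instantiated body instead of the function name.
mlir::MLIRContext context;
auto module_or = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
TF_CHECK_OK(module_or.status());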
@@ -41,6 +41,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+        "//tensorflow/core/common_runtime:core_cpu_base_no_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
     ],
@@ -30,6 +30,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
+#include "tensorflow/core/common_runtime/function_body.h"
+#include "tensorflow/core/common_runtime/function_def_utils.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -111,8 +113,24 @@ std::string ImportFunction(const std::string &functiondef_proto,
   }

   const std::string &function_name = functiondef.signature().name();
+
+  const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
+  if (fdef == nullptr) {
+    s = tensorflow::errors::NotFound("Cannot find function ", function_name);
+    Set_TF_Status_from_Status(status, s);
+    return "// error";
+  }
+
+  std::unique_ptr<tensorflow::FunctionBody> fbody;
+  s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
+                              &fbody);
+  if (!s.ok()) {
+    Set_TF_Status_from_Status(status, s);
+    return "// error";
+  }
+
   mlir::MLIRContext context;
-  auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
+  auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
   if (!module.ok()) {
     Set_TF_Status_from_Status(status, module.status());
     return "// error";
@@ -3693,23 +3693,16 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
 }

 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
-    mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
+    const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
     mlir::MLIRContext* context) {
-  const tensorflow::FunctionDef* fdef = flib_def.Find(name.str());
-  if (fdef == nullptr)
-    return tensorflow::errors::NotFound("Cannot find function ", name.str());
-
-  std::unique_ptr<tensorflow::FunctionBody> fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
-                                             &flib_def, &fbody));
-
   tensorflow::GraphDebugInfo dummy_debug_info;
   tensorflow::GraphImportConfig specs;
   specs.graph_as_function = true;
   for (const auto* control_ret_node : fbody->control_ret_nodes)
     specs.control_outputs.push_back(control_ret_node->name());
   return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
-                                   flib_def, specs, name);
+                                   flib_def, specs,
+                                   fbody->fdef.signature().name());
 }

 StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
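For any other caller of ConvertFunctionToMlir, the migration is mechanical: resolve and instantiate the function yourself, then pass the FunctionBody in. A hedged before/after sketch (the function name "my_fn" is hypothetical, and real code should check Find() for nullptr before dereferencing):

// Before: the importer did the library lookup and instantiation itself.
//   auto module = ConvertFunctionToMlir("my_fn", flib_def, &context);

// After: the caller instantiates the body first. This is what lets a caller
// run graph passes (for example the placer) over fbody->graph before import.
std::unique_ptr<tensorflow::FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*flib_def.Find("my_fn"),
                                           tensorflow::AttrSlice(), &flib_def,
                                           &fbody));
auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);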
@@ -51,7 +51,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
 // Given a Function, returns a MLIR module containing the graph, expressed with
 // tf_executor dialect.
 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
-    mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
+    const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
     mlir::MLIRContext* context);

 // Given a SavedModel, returns a MLIR module containing the functions, expressed
@@ -31,6 +31,8 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops


 class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
@@ -38,6 +40,7 @@ class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
   def setUp(self):
     super(SoftDevicePlacementTest, self).setUp()
+    context._reset_context()
     context.ensure_initialized()
     config.set_soft_device_placement(enabled=True)
     context.context().log_device_placement = True

@@ -109,6 +112,36 @@ class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
     self.assertIn('CPU:0', a.device)
     self.assertIn('CPU:0', a.backing_device)

+  def testPlacedToDeviceInFunction(self):
+
+    @def_function.function
+    def f():
+      a = random_ops.random_uniform([32, 32])
+      return math_ops.matmul(a, a)
+
+    gpus = config.list_physical_devices('GPU')
+    if not gpus:
+      self.assertIn('CPU:0', f().device)
+    else:
+      self.assertIn('GPU:0', f().device)
+
+  @test_util.disable_tfrt('b/173726713: Support properly inserting device at '
+                          'tf_to_corert lowering.')
+  def testUnknownDeviceInFunction(self):
+
+    @def_function.function
+    def f():
+      with ops.device('GPU:42'):
+        # With placer, the unknown GPU:42 will be replaced with GPU:0.
+        a = constant_op.constant(1) + constant_op.constant(2)
+      return a + constant_op.constant(2)
+
+    gpus = config.list_physical_devices('GPU')
+    if not gpus:
+      self.assertIn('CPU:0', f().device)
+    else:
+      self.assertIn('GPU:0', f().device)
+

 class HardDevicePlacementTest(test.TestCase, parameterized.TestCase):

@@ -158,6 +191,7 @@ class ClusterPlacementTest(test.TestCase):
     workers, _ = test_util.create_local_cluster(2, 0)
     remote.connect_to_remote_host([workers[0].target, workers[1].target])

+  @test_util.disable_tfrt('remote host not supported yet.')
   def testNotFullySpecifiedTask(self):
     a = constant_op.constant(1)
     b = constant_op.constant(2)
@@ -165,6 +199,7 @@ class ClusterPlacementTest(test.TestCase):
       c = a + b
     self.assertIn('/job:worker/replica:0/task:0', c.device)

+  @test_util.disable_tfrt('remote host not supported yet.')
   def testRemoteUnknownDevice(self):
     a = constant_op.constant(1)
     b = constant_op.constant(2)
@@ -175,6 +210,7 @@ class ClusterPlacementTest(test.TestCase):
       del c
     self.assertIn('unknown device', cm.exception.message)

+  @test_util.disable_tfrt('remote host not supported yet.')
   def testUnknownDeviceInFunctionReturnUnknowDevice(self):

     @def_function.function
|
||||
else:
|
||||
self.assertIn('GPU:0', f().device)
|
||||
|
||||
@test_util.disable_tfrt('remote host not supported yet.')
|
||||
def testUnknownDeviceInFunction(self):
|
||||
|
||||
@def_function.function
|
||||
|