Allow dynamically configuring device placement
Enable setting soft device placement, as well as device placement logging, dynamically at runtime. This required making the device placement policy part of the kernel cache key. We also fix the logging so that, in eager mode, an execution is still logged when the kernel is retrieved from the kernel cache, and we log closer to the actual op execution so that nothing is logged before all checks have passed.

PiperOrigin-RevId: 311271808
Change-Id: I9765228894f84a3447cc03332a2559f6d933165b
commit d5b3ec27d1 (parent 088fc3a9b5)
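A minimal usage sketch of the new behavior, assuming the public wrappers tf.config.set_soft_device_placement and tf.debugging.set_log_device_placement route through the context setters changed below:

# Sketch: both settings can now be toggled after the eager context exists.
import tensorflow as tf

tf.config.set_soft_device_placement(True)    # ops may fall back to a supported device
tf.debugging.set_log_device_placement(True)  # placement of subsequent ops is logged

a = tf.constant(1) + tf.constant(2)  # logs: "Executing op AddV2 in device ..."

tf.debugging.set_log_device_placement(False)  # dynamically turn logging off again
b = tf.constant(3) + tf.constant(4)  # not logged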
@@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
       std::move(tensor_handles), context, &handle);
   return tensorflow::wrap(handle);
 }
+
+void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
+                                       TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  context->SetAllowSoftPlacement(enable);
+}
+
+void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
+                                      TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  context->SetLogDevicePlacement(enable);
+}
@@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
     TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
     TF_Status* status);
 
+// Configure soft device placement policy for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
+                                                      unsigned char enable,
+                                                      TF_Status* status);
+
+// Configure device placement policy logging for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
+                                                     unsigned char enable,
+                                                     TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
@@ -300,7 +300,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
 
   bool LogDevicePlacement() const { return log_device_placement_; }
+  void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
   bool AllowSoftPlacement() const { return allow_soft_placement_; }
+  void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
   bool LogMemory() const { return log_memory_; }
 
   Rendezvous* GetRendezvous() const { return rendezvous_; }
@@ -625,9 +627,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   mutex metadata_mu_;
   RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
   GraphCollector graph_collector_;
-  // TODO(fishx): Allow update following two bool after context creation.
-  const bool log_device_placement_;
-  const bool allow_soft_placement_;
+  std::atomic<bool> log_device_placement_;
+  std::atomic<bool> allow_soft_placement_;
 
   // Information related to step containers.
   std::atomic<int> num_active_steps_;
@@ -365,6 +365,9 @@ Status GetOrCreateKernelAndDevice(
   Device* device = absl::get<Device*>(op->Device());
 
   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
+  // Include soft placement policy in cache key since the placement strategy
+  // can change and thus affect which kernel is picked.
+  cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
 
   std::vector<Device*> input_dev_ptrs;
   absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
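The flag has to be folded into the key because the same op and device string can now resolve to different kernels as the policy is toggled at runtime. A hypothetical repro on a machine with no GPU, mirroring the config test changes further down:

import tensorflow as tf

def mod():
  with tf.device('/device:GPU:0'):
    return tf.constant(7.0) % 2.0

tf.config.set_soft_device_placement(True)
mod()  # no GPU available: soft placement silently picks a CPU kernel

tf.config.set_soft_device_placement(False)
try:
  # Were the flag not part of the cache key, the CPU kernel cached above
  # could be reused here; with this change the lookup misses and placement
  # fails as expected.
  mod()
except tf.errors.InvalidArgumentError:
  pass  # expected: GPU:0 is unknown and fallback is disabled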
@@ -488,13 +491,6 @@ Status GetOrCreateKernelAndDevice(
                << KernelsRegisteredForOp(op->Name());
     op->SetDevice(device);
   }
-  if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
-    string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
-                                 DeviceNameOrUnspecified(device));
-    if (!logging::LogToListeners(msg)) {
-      LOG(INFO) << msg;
-    }
-  }
 
   FunctionLibraryRuntime* flr =
       device == nullptr ? nullptr : ctx.func_lib(device);
@@ -607,6 +603,14 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
   int num_outputs = kernel->num_outputs();
   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
 
+  if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
+    string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
+                                 kernel->device()->name());
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
+  }
+
   GraphCollector* graph_collector = nullptr;
   if (ctx.ShouldStoreGraphs()) {
     graph_collector = ctx.GetGraphCollector();
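Because the log statement now sits after kernel lookup and input validation, repeated executions of a cached kernel are each logged. A rough illustration of the observable effect:

import tensorflow as tf

tf.debugging.set_log_device_placement(True)
x = tf.constant(1) + tf.constant(2)  # first run: kernel created and logged
y = tf.constant(3) + tf.constant(4)  # kernel cache hit: still logged now
# Each run emits a line like:
#   Executing op AddV2 in device /job:localhost/replica:0/task:0/device:CPU:0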
@@ -841,6 +845,16 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
       ctx.GetContextViewId(), eager_client.get(),
       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
       op->Inputs(), {retvals, num_outputs}));
+
+  if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
+    string msg = strings::StrCat(
+        "Executing op ", op->Name(), " on task ",
+        DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
+  }
+
   Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
@@ -1119,15 +1133,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
     return EagerLocalExecute(op, retvals, num_retvals);
   }
 
-  if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
-    string msg = strings::StrCat(
-        "Executing op ", op->Name(), " on task ",
-        DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
-    if (!logging::LogToListeners(msg)) {
-      LOG(INFO) << msg;
-    }
-  }
-
 #if defined(IS_MOBILE_PLATFORM)
   return errors::Unimplemented(
       "Eager's remote execution is not available on mobile devices.");
@@ -1428,6 +1433,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
     return;
   }
 
+  if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
+    string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
+                                 kernel->device()->name());
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
+  }
+
   GraphCollector* graph_collector = nullptr;
   if (ctx.ShouldStoreGraphs()) {
     graph_collector = ctx.GetGraphCollector();
@@ -1917,6 +1917,9 @@ class SessionTest(test_util.TensorFlowTestCase):
         a = constant_op.constant(1)
         b = constant_op.constant(2)
         c = a + b
+        # Ensure that if the same kernel is executed with the same arguments,
+        # its execution is still logged.
+        d = a + b
     else:
       # Passing the config to the server, but not the session should still
       # result in logging device placement.
@@ -1925,12 +1928,16 @@ class SessionTest(test_util.TensorFlowTestCase):
       a = constant_op.constant(1)
       b = constant_op.constant(2)
       c = a + b
+      d = a + b
       with session.Session(server.target) as sess:
         with CaptureStderr() as log:
-          sess.run(c)
+          c, d = sess.run([c, d])
 
+      self.assertEqual(c, 3)
+      self.assertEqual(d, 3)
       # Ensure that we did log device placement.
-      self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log))
+      add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
+      self.assertEqual(len(add_executions), 2)
 
   @test_util.run_v1_only('b/120545219')
   def testLocalMasterSessionTimeout(self):
@@ -1509,9 +1509,11 @@ class Context(object):
     return self.config.allow_soft_placement
 
   @soft_device_placement.setter
-  def soft_device_placement(self, enabled):
-    self._soft_device_placement = enabled
+  def soft_device_placement(self, enable):
+    if self._context_handle is not None:
+      pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
+
+    self._soft_device_placement = enable
     self._thread_local_data.function_call_options = None
 
   @property
@@ -1519,15 +1521,11 @@ class Context(object):
     return self.config.log_device_placement
 
   @log_device_placement.setter
-  def log_device_placement(self, enabled):
-    if self._log_device_placement == enabled:
-      return
-
+  def log_device_placement(self, enable):
     if self._context_handle is not None:
-      raise RuntimeError(
-          "Device placement logging must be set at program startup")
+      pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
 
-    self._log_device_placement = enabled
+    self._log_device_placement = enable
     self._thread_local_data.function_call_options = None
 
   @property
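With these setters, assigning either property on an already-initialized context forwards straight to the runtime, instead of raising (logging) or only taking effect on the next context build (soft placement). A short sketch against the internal context object (internal API, shown for illustration only):

from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
ctx.soft_device_placement = True  # forwards to TFE_ContextSetSoftDevicePlacement
ctx.log_device_placement = True   # no longer raises RuntimeError after startup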
@@ -1112,5 +1112,4 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
 
 
 if __name__ == '__main__':
-  context.set_log_device_placement(True)
   test.main()
@@ -159,7 +159,6 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
     else:
       self.assertFalse(config.get_soft_device_placement())
 
-    @def_function.function
     def mod():
       with ops.device('/device:GPU:0'):
         a = constant_op.constant(1.0)
@@ -172,8 +171,10 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
         config.get_soft_device_placement(),
         context.context().soft_device_placement)
 
-    # Since soft placement is enabled, the mod operation should work with CPU
+    # Since soft placement is enabled, the mod operation should fall back to
+    # CPU with pure eager execution as well as functions.
     mod()
+    def_function.function(mod)()
 
     config.set_soft_device_placement(False)
     self.assertEqual(config.get_soft_device_placement(), False)
@@ -182,8 +183,11 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
         context.context().soft_device_placement)
 
     # Since soft placement is disabled, the mod operation should fail on GPU
+    # with pure eager execution as well as functions.
     with self.assertRaises(errors.InvalidArgumentError):
       mod()
+    with self.assertRaises(errors.InvalidArgumentError):
+      def_function.function(mod)()
 
   @reset_eager
   def testLogDevicePlacement(self):
@@ -203,12 +207,8 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
 
     context.ensure_initialized()
 
-    with self.assertRaises(RuntimeError):
-      context.set_log_device_placement(True)
-
-    # If the setting the device placement is a no-op, do not throw a runtime
-    # exception.
-    context.set_log_device_placement(False)
+    # Changing device placement logging should not throw an exception.
+    context.set_log_device_placement(True)
 
   @reset_eager
   def testEnableMlirBridge(self):
@@ -488,6 +488,18 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
         // NOTE: different from TFE_ContextSyncExecutors that raises potential
         // errors, deliberately ignore executor statuses in cleanup.
       });
+  m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
+                                      status.get());
+  });
+  m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
+                                     status.get());
+  });
 
   // TFE_Executor logic
   m.def(
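Note that both lambdas pass a throwaway TF_Status that neither C function currently sets. A hypothetical low-level call path through these bindings, bypassing the public wrappers (private attributes, for illustration only):

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
# _handle is the underlying TFE_Context owned by the Python-level context.
pywrap_tfe.TFE_ContextSetSoftDevicePlacement(ctx._handle, True)
pywrap_tfe.TFE_ContextSetLogDevicePlacement(ctx._handle, True)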