Allow dynamically configuring device placement

Enable setting soft device placement as well as logging dynamically.
This required ensuring the device placement policy was part of the cache
key.

Further, we fix the logging to ensure in eager mode if a kernel is
retrieved from the kernel cache, then the execution is still logged. We
also log closer to the actual op execution to avoid logging before all
checks have been done.

PiperOrigin-RevId: 311271808
Change-Id: I9765228894f84a3447cc03332a2559f6d933165b
This commit is contained in:
Gaurav Jain 2020-05-12 23:14:53 -07:00 committed by TensorFlower Gardener
parent 088fc3a9b5
commit d5b3ec27d1
9 changed files with 95 additions and 39 deletions

View File

@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
std::move(tensor_handles), context, &handle); std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle); return tensorflow::wrap(handle);
} }
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                       TF_Status* status) {
  // Unwrap the C handle and flip the soft-placement flag on the underlying
  // EagerContext. `status` is accepted for C-API symmetry; this call cannot
  // fail, so it is never written to.
  tensorflow::ContextFromInterface(tensorflow::unwrap(ctx))
      ->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                      TF_Status* status) {
  // Unwrap the C handle and flip the device-placement-logging flag on the
  // underlying EagerContext. `status` is accepted for C-API symmetry; this
  // call cannot fail, so it is never written to.
  tensorflow::ContextFromInterface(tensorflow::unwrap(ctx))
      ->SetLogDevicePlacement(enable);
}

View File

@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status); TF_Status* status);
// Configure soft device placement policy for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Configure device placement policy logging for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
#ifdef __cplusplus #ifdef __cplusplus
} /* end extern "C" */ } /* end extern "C" */
#endif #endif

View File

@ -300,7 +300,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() const { return log_device_placement_; } bool LogDevicePlacement() const { return log_device_placement_; }
void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
bool AllowSoftPlacement() const { return allow_soft_placement_; } bool AllowSoftPlacement() const { return allow_soft_placement_; }
void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
bool LogMemory() const { return log_memory_; } bool LogMemory() const { return log_memory_; }
Rendezvous* GetRendezvous() const { return rendezvous_; } Rendezvous* GetRendezvous() const { return rendezvous_; }
@ -625,9 +627,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
mutex metadata_mu_; mutex metadata_mu_;
RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_); RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
GraphCollector graph_collector_; GraphCollector graph_collector_;
// TODO(fishx): Allow update following two bool after context creation. std::atomic<bool> log_device_placement_;
const bool log_device_placement_; std::atomic<bool> allow_soft_placement_;
const bool allow_soft_placement_;
// Information related to step containers. // Information related to step containers.
std::atomic<int> num_active_steps_; std::atomic<int> num_active_steps_;

View File

@ -365,6 +365,9 @@ Status GetOrCreateKernelAndDevice(
Device* device = absl::get<Device*>(op->Device()); Device* device = absl::get<Device*>(op->Device());
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName()); Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
// Include soft placement policy in cache key since the placement strategy
// can change and thus affect which kernel is picked.
cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
std::vector<Device*> input_dev_ptrs; std::vector<Device*> input_dev_ptrs;
absl::flat_hash_map<string, const std::vector<string>*> composite_devices; absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
@ -488,13 +491,6 @@ Status GetOrCreateKernelAndDevice(
<< KernelsRegisteredForOp(op->Name()); << KernelsRegisteredForOp(op->Name());
op->SetDevice(device); op->SetDevice(device);
} }
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
DeviceNameOrUnspecified(device));
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
FunctionLibraryRuntime* flr = FunctionLibraryRuntime* flr =
device == nullptr ? nullptr : ctx.func_lib(device); device == nullptr ? nullptr : ctx.func_lib(device);
@ -607,6 +603,14 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
int num_outputs = kernel->num_outputs(); int num_outputs = kernel->num_outputs();
TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel)); TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
kernel->device()->name());
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
GraphCollector* graph_collector = nullptr; GraphCollector* graph_collector = nullptr;
if (ctx.ShouldStoreGraphs()) { if (ctx.ShouldStoreGraphs()) {
graph_collector = ctx.GetGraphCollector(); graph_collector = ctx.GetGraphCollector();
@ -841,6 +845,16 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
ctx.GetContextViewId(), eager_client.get(), ctx.GetContextViewId(), eager_client.get(),
op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
op->Inputs(), {retvals, num_outputs})); op->Inputs(), {retvals, num_outputs}));
if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat(
"Executing op ", op->Name(), " on task ",
DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
Status s = executor.AddOrExecute(std::move(node)); Status s = executor.AddOrExecute(std::move(node));
// Since the operation failed, we need to Unref any outputs that were // Since the operation failed, we need to Unref any outputs that were
// allocated. // allocated.
@ -1119,15 +1133,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
return EagerLocalExecute(op, retvals, num_retvals); return EagerLocalExecute(op, retvals, num_retvals);
} }
if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat(
"Executing op ", op->Name(), " on task ",
DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
#if defined(IS_MOBILE_PLATFORM) #if defined(IS_MOBILE_PLATFORM)
return errors::Unimplemented( return errors::Unimplemented(
"Eager's remote execution is not available on mobile devices."); "Eager's remote execution is not available on mobile devices.");
@ -1428,6 +1433,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
return; return;
} }
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
kernel->device()->name());
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
GraphCollector* graph_collector = nullptr; GraphCollector* graph_collector = nullptr;
if (ctx.ShouldStoreGraphs()) { if (ctx.ShouldStoreGraphs()) {
graph_collector = ctx.GetGraphCollector(); graph_collector = ctx.GetGraphCollector();

View File

@ -1917,6 +1917,9 @@ class SessionTest(test_util.TensorFlowTestCase):
a = constant_op.constant(1) a = constant_op.constant(1)
b = constant_op.constant(2) b = constant_op.constant(2)
c = a + b c = a + b
# Ensure if the same kernel with the same arguments is executed then its
# execution is logged.
d = a + b
else: else:
# Passing the config to the server, but not the session should still # Passing the config to the server, but not the session should still
# result in logging device placement. # result in logging device placement.
@ -1925,12 +1928,16 @@ class SessionTest(test_util.TensorFlowTestCase):
a = constant_op.constant(1) a = constant_op.constant(1)
b = constant_op.constant(2) b = constant_op.constant(2)
c = a + b c = a + b
d = a + b
with session.Session(server.target) as sess: with session.Session(server.target) as sess:
with CaptureStderr() as log: with CaptureStderr() as log:
sess.run(c) c, d = sess.run([c, d])
self.assertEqual(c, 3)
self.assertEqual(d, 3)
# Ensure that we did log device placement. # Ensure that we did log device placement.
self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log)) add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
self.assertEqual(len(add_executions), 2)
@test_util.run_v1_only('b/120545219') @test_util.run_v1_only('b/120545219')
def testLocalMasterSessionTimeout(self): def testLocalMasterSessionTimeout(self):

View File

@ -1509,9 +1509,11 @@ class Context(object):
return self.config.allow_soft_placement return self.config.allow_soft_placement
@soft_device_placement.setter @soft_device_placement.setter
def soft_device_placement(self, enabled): def soft_device_placement(self, enable):
self._soft_device_placement = enabled if self._context_handle is not None:
pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
self._soft_device_placement = enable
self._thread_local_data.function_call_options = None self._thread_local_data.function_call_options = None
@property @property
@ -1519,15 +1521,11 @@ class Context(object):
return self.config.log_device_placement return self.config.log_device_placement
@log_device_placement.setter @log_device_placement.setter
def log_device_placement(self, enabled): def log_device_placement(self, enable):
if self._log_device_placement == enabled:
return
if self._context_handle is not None: if self._context_handle is not None:
raise RuntimeError( pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
"Device placement logging must be set at program startup")
self._log_device_placement = enabled self._log_device_placement = enable
self._thread_local_data.function_call_options = None self._thread_local_data.function_call_options = None
@property @property

View File

@ -1112,5 +1112,4 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
if __name__ == '__main__': if __name__ == '__main__':
context.set_log_device_placement(True)
test.main() test.main()

View File

@ -159,7 +159,6 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
else: else:
self.assertFalse(config.get_soft_device_placement()) self.assertFalse(config.get_soft_device_placement())
@def_function.function
def mod(): def mod():
with ops.device('/device:GPU:0'): with ops.device('/device:GPU:0'):
a = constant_op.constant(1.0) a = constant_op.constant(1.0)
@ -172,8 +171,10 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
config.get_soft_device_placement(), config.get_soft_device_placement(),
context.context().soft_device_placement) context.context().soft_device_placement)
# Since soft placement is enabled, the mod operation should work with CPU # Since soft placement is enabled, the mod operation should fallback to CPU
# with pure eager execution as well as functions
mod() mod()
def_function.function(mod)()
config.set_soft_device_placement(False) config.set_soft_device_placement(False)
self.assertEqual(config.get_soft_device_placement(), False) self.assertEqual(config.get_soft_device_placement(), False)
@ -182,8 +183,11 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
context.context().soft_device_placement) context.context().soft_device_placement)
# Since soft placement is disabled, the mod operation should fail on GPU # Since soft placement is disabled, the mod operation should fail on GPU
# with pure eager execution as well as functions
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
mod() mod()
with self.assertRaises(errors.InvalidArgumentError):
def_function.function(mod)()
@reset_eager @reset_eager
def testLogDevicePlacement(self): def testLogDevicePlacement(self):
@ -203,12 +207,8 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
context.ensure_initialized() context.ensure_initialized()
with self.assertRaises(RuntimeError): # Changing the device placement should not throw an exception
context.set_log_device_placement(True) context.set_log_device_placement(True)
# If the setting the device placement is a no-op, do not throw a runtime
# exception.
context.set_log_device_placement(False)
@reset_eager @reset_eager
def testEnableMlirBridge(self): def testEnableMlirBridge(self):

View File

@ -488,6 +488,18 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// NOTE: different from TFE_ContextSyncExecutors that raises potential // NOTE: different from TFE_ContextSyncExecutors that raises potential
// errors, deliberately ignore executor statuses in cleanup. // errors, deliberately ignore executor statuses in cleanup.
}); });
// Python binding: forwards to the C API to enable/disable soft device
// placement on the eager context.
// NOTE(review): `status` is allocated and passed through but never inspected
// or raised to Python — any error reported by the C API is silently dropped.
// Confirm this is intentional (the C function currently never sets status).
m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
                                    status.get());
});
// Python binding: forwards to the C API to enable/disable device placement
// logging on the eager context.
// BUG FIX: this binding previously called TFE_ContextSetSoftDevicePlacement
// (copy-paste error), so tf.debugging.set_log_device_placement() silently
// toggled soft placement instead of logging. Forward to the matching C API.
// NOTE(review): as with the soft-placement binding, `status` is never
// inspected; the C function currently never sets it.
m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
                                   status.get());
});
// TFE_Executor logic // TFE_Executor logic
m.def( m.def(