Allow dynamically configuring device placement

Enable setting soft device placement, as well as device placement logging,
dynamically after the context has been created. This required ensuring the
device placement policy is part of the kernel cache key.

Further, we fix the logging so that in eager mode the execution is still
logged even when a kernel is retrieved from the kernel cache. We also log
closer to the actual op execution, to avoid logging before all checks have
completed.

PiperOrigin-RevId: 311271808
Change-Id: I9765228894f84a3447cc03332a2559f6d933165b
Gaurav Jain 2020-05-12 23:14:53 -07:00 committed by TensorFlower Gardener
parent 088fc3a9b5
commit d5b3ec27d1
9 changed files with 95 additions and 39 deletions
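As background, a minimal sketch of the user-facing behavior this change enables (illustrative only, not part of the commit), using the public tf.config / tf.debugging APIs in TF 2.x eager mode:

import tensorflow as tf

a = tf.constant(1)
b = tf.constant(2)

# Previously, changing these settings after the eager context was created
# either raised an error (device placement logging) or was not picked up by
# the running context (soft placement). Now they can be flipped at any point
# and apply to subsequent executions.
tf.debugging.set_log_device_placement(True)
tf.config.set_soft_device_placement(True)
c = a + b  # execution is logged and soft placement is in effect

tf.debugging.set_log_device_placement(False)
d = a + b  # no longer logged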


@@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle);
}
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
}


@@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status);
// Configure soft device placement policy for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Configure device placement policy logging for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
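An illustrative sketch of the low-level path from Python to these C setters via the pywrap_tfe bindings added later in this commit. This pokes at internal modules and a private handle attribute; the supported surface is tf.config.set_soft_device_placement and tf.debugging.set_log_device_placement.

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
# Forward the settings to the already-created native context; per the header
# comments above, they apply to any subsequent op executions.
pywrap_tfe.TFE_ContextSetSoftDevicePlacement(ctx._handle, True)
pywrap_tfe.TFE_ContextSetLogDevicePlacement(ctx._handle, True)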


@@ -300,7 +300,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() const { return log_device_placement_; }
void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
bool AllowSoftPlacement() const { return allow_soft_placement_; }
void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
bool LogMemory() const { return log_memory_; }
Rendezvous* GetRendezvous() const { return rendezvous_; }
@@ -625,9 +627,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
mutex metadata_mu_;
RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
GraphCollector graph_collector_;
// TODO(fishx): Allow update following two bool after context creation.
const bool log_device_placement_;
const bool allow_soft_placement_;
std::atomic<bool> log_device_placement_;
std::atomic<bool> allow_soft_placement_;
// Information related to step containers.
std::atomic<int> num_active_steps_;
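The switch from const bool to std::atomic<bool> is presumably needed because the two flags can now be flipped from Python while other threads may be executing eager ops that read them. A minimal Python sketch of that scenario, assuming the public APIs:

import threading
import tensorflow as tf

def run_ops():
    for _ in range(100):
        tf.constant(1) + tf.constant(2)

t = threading.Thread(target=run_ops)
t.start()
# Toggling the placement flags while another thread executes ops is the
# pattern the atomics above are meant to make safe.
tf.debugging.set_log_device_placement(True)
tf.config.set_soft_device_placement(False)
t.join()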


@@ -365,6 +365,9 @@ Status GetOrCreateKernelAndDevice(
Device* device = absl::get<Device*>(op->Device());
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
// Include soft placement policy in cache key since the placement strategy
// can change and thus affect which kernel is picked.
cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
std::vector<Device*> input_dev_ptrs;
absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
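Why the policy belongs in the cache key: with soft placement enabled, an op requesting an unavailable device can be soft-placed and its kernel cached; if the policy is later disabled, that cached kernel must not be silently reused. A sketch of the scenario, assuming a host with no GPU:

import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)

tf.config.set_soft_device_placement(True)
with tf.device('/device:GPU:0'):
    a + b  # soft-placed onto the CPU AddV2 kernel and cached

tf.config.set_soft_device_placement(False)
with tf.device('/device:GPU:0'):
    try:
        # Same op, attrs and requested device as above; only the policy
        # changed. With the policy folded into the cache key, this lookup
        # misses the cached CPU kernel and placement fails as expected.
        a + b
    except tf.errors.InvalidArgumentError as e:
        print('placement failed:', e)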
@@ -488,13 +491,6 @@ Status GetOrCreateKernelAndDevice(
<< KernelsRegisteredForOp(op->Name());
op->SetDevice(device);
}
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
DeviceNameOrUnspecified(device));
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
FunctionLibraryRuntime* flr =
device == nullptr ? nullptr : ctx.func_lib(device);
@@ -607,6 +603,14 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
int num_outputs = kernel->num_outputs();
TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
kernel->device()->name());
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
GraphCollector* graph_collector = nullptr;
if (ctx.ShouldStoreGraphs()) {
graph_collector = ctx.GetGraphCollector();
@@ -841,6 +845,16 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
ctx.GetContextViewId(), eager_client.get(),
op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
op->Inputs(), {retvals, num_outputs}));
if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat(
"Executing op ", op->Name(), " on task ",
DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
Status s = executor.AddOrExecute(std::move(node));
// Since the operation failed, we need to Unref any outputs that were
// allocated.
@@ -1119,15 +1133,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
return EagerLocalExecute(op, retvals, num_retvals);
}
if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat(
"Executing op ", op->Name(), " on task ",
DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
#if defined(IS_MOBILE_PLATFORM)
return errors::Unimplemented(
"Eager's remote execution is not available on mobile devices.");
@@ -1428,6 +1433,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
return;
}
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
kernel->device()->name());
if (!logging::LogToListeners(msg)) {
LOG(INFO) << msg;
}
}
GraphCollector* graph_collector = nullptr;
if (ctx.ShouldStoreGraphs()) {
graph_collector = ctx.GetGraphCollector();
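The net effect of moving the log statements out of GetOrCreateKernelAndDevice and into the execute paths: a kernel-cache hit still produces a log line, and the line is emitted only after input validation, just before the op actually runs (with remote execution logged in EagerRemoteExecute rather than EagerExecute). An eager-mode sketch of the visible behavior:

import tensorflow as tf

tf.debugging.set_log_device_placement(True)
a, b = tf.constant(1), tf.constant(2)
c = a + b  # logs "Executing op AddV2 in device ..."
d = a + b  # kernel-cache hit; previously skipped, now logged as well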


@@ -1917,6 +1917,9 @@ class SessionTest(test_util.TensorFlowTestCase):
a = constant_op.constant(1)
b = constant_op.constant(2)
c = a + b
# Ensure that if the same kernel is executed with the same arguments, its
# execution is still logged.
d = a + b
else:
# Passing the config to the server, but not the session should still
# result in logging device placement.
@@ -1925,12 +1928,16 @@ class SessionTest(test_util.TensorFlowTestCase):
a = constant_op.constant(1)
b = constant_op.constant(2)
c = a + b
d = a + b
with session.Session(server.target) as sess:
with CaptureStderr() as log:
sess.run(c)
c, d = sess.run([c, d])
self.assertEqual(c, 3)
self.assertEqual(d, 3)
# Ensure that we did log device placement.
self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log))
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
self.assertEqual(len(add_executions), 2)
@test_util.run_v1_only('b/120545219')
def testLocalMasterSessionTimeout(self):


@@ -1509,9 +1509,11 @@ class Context(object):
return self.config.allow_soft_placement
@soft_device_placement.setter
def soft_device_placement(self, enabled):
self._soft_device_placement = enabled
def soft_device_placement(self, enable):
if self._context_handle is not None:
pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
self._soft_device_placement = enable
self._thread_local_data.function_call_options = None
@property
@@ -1519,15 +1521,11 @@ class Context(object):
return self.config.log_device_placement
@log_device_placement.setter
def log_device_placement(self, enabled):
if self._log_device_placement == enabled:
return
def log_device_placement(self, enable):
if self._context_handle is not None:
raise RuntimeError(
"Device placement logging must be set at program startup")
pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
self._log_device_placement = enabled
self._log_device_placement = enable
self._thread_local_data.function_call_options = None
@property
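Behaviorally, the setters above now remember the value and, when the native context already exists, forward it immediately instead of raising. A sketch using the internal eager context object (the public tf.config / tf.debugging wrappers follow the same path):

from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
# Before this change, assigning log_device_placement after initialization
# raised "Device placement logging must be set at program startup"; both
# settings can now be updated on a live context.
ctx.log_device_placement = True
ctx.soft_device_placement = False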


@@ -1112,5 +1112,4 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
context.set_log_device_placement(True)
test.main()


@@ -159,7 +159,6 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
else:
self.assertFalse(config.get_soft_device_placement())
@def_function.function
def mod():
with ops.device('/device:GPU:0'):
a = constant_op.constant(1.0)
@@ -172,8 +171,10 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
config.get_soft_device_placement(),
context.context().soft_device_placement)
# Since soft placement is enabled, the mod operation should work with CPU
# Since soft placement is enabled, the mod operation should fall back to CPU
# with pure eager execution as well as functions
mod()
def_function.function(mod)()
config.set_soft_device_placement(False)
self.assertEqual(config.get_soft_device_placement(), False)
@@ -182,8 +183,11 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
context.context().soft_device_placement)
# Since soft placement is disabled, the mod operation should fail on GPU
# with pure eager execution as well as functions
with self.assertRaises(errors.InvalidArgumentError):
mod()
with self.assertRaises(errors.InvalidArgumentError):
def_function.function(mod)()
@reset_eager
def testLogDevicePlacement(self):
@@ -203,12 +207,8 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
context.ensure_initialized()
with self.assertRaises(RuntimeError):
context.set_log_device_placement(True)
# If the setting the device placement is a no-op, do not throw a runtime
# exception.
context.set_log_device_placement(False)
# Changing the device placement should not throw an exception
context.set_log_device_placement(True)
@reset_eager
def testEnableMlirBridge(self):
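For reference, the same behavior the soft-placement test above exercises, written against the public API (an illustrative sketch, assuming a host with no GPU):

import tensorflow as tf

def mod():
    with tf.device('/device:GPU:0'):
        a = tf.constant(1.0)
        b = tf.constant(2.0)
        return tf.math.mod(a, b)

tf.config.set_soft_device_placement(True)
mod()               # pure eager: falls back to the CPU kernel
tf.function(mod)()  # inside a function: also falls back

tf.config.set_soft_device_placement(False)
try:
    mod()           # no GPU and no soft placement: placement error
except tf.errors.InvalidArgumentError:
    pass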


@@ -488,6 +488,18 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// NOTE: different from TFE_ContextSyncExecutors that raises potential
// errors, deliberately ignore executor statuses in cleanup.
});
m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
status.get());
});
m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
status.get());
});
// TFE_Executor logic
m.def(