Fix an issue that start_profiler_server complains 'AssertionError: Context must be initialized first.'
PiperOrigin-RevId: 253093414
This commit is contained in:
parent
bbd5f591af
commit
6c5d79930c
@ -42,10 +42,8 @@ bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
|
|||||||
|
|
||||||
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
|
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
|
||||||
|
|
||||||
void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler,
|
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
|
||||||
TF_Buffer* buf, TF_Status* status) {
|
TF_Status* status) {
|
||||||
TFE_ContextAsyncWait(ctx, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
string content;
|
string content;
|
||||||
status->status = profiler->profiler->SerializeToString(&content);
|
status->status = profiler->profiler->SerializeToString(&content);
|
||||||
void* data = tensorflow::port::Malloc(content.length());
|
void* data = tensorflow::port::Malloc(content.length());
|
||||||
|
@ -40,8 +40,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
|
|||||||
|
|
||||||
// The output string is a binary string of tensorflow.tpu.Trace. User can write
|
// The output string is a binary string of tensorflow.tpu.Trace. User can write
|
||||||
// the string to file for offline analysis by tensorboard.
|
// the string to file for offline analysis by tensorboard.
|
||||||
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
|
||||||
TFE_Profiler* profiler,
|
|
||||||
TF_Buffer* buf,
|
TF_Buffer* buf,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
@ -72,7 +72,11 @@ void ExecuteWithProfiling(bool async) {
|
|||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
ASSERT_EQ(1, num_retvals);
|
ASSERT_EQ(1, num_retvals);
|
||||||
TF_Buffer* profiler_result = TF_NewBuffer();
|
TF_Buffer* profiler_result = TF_NewBuffer();
|
||||||
TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status);
|
if (async) {
|
||||||
|
TFE_ContextAsyncWait(ctx, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
}
|
||||||
|
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
|
||||||
TFE_DeleteProfiler(profiler);
|
TFE_DeleteProfiler(profiler);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
profiler::Trace profile_proto;
|
profiler::Trace profile_proto;
|
||||||
|
@ -71,9 +71,9 @@ def start():
|
|||||||
with _profiler_lock:
|
with _profiler_lock:
|
||||||
if _profiler is not None:
|
if _profiler is not None:
|
||||||
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
||||||
context.ensure_initialized()
|
|
||||||
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
|
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
|
||||||
if context.default_execution_mode == context.EAGER_MODE:
|
if context.default_execution_mode == context.EAGER_MODE:
|
||||||
|
context.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
|
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
|
||||||
profiler_context,
|
profiler_context,
|
||||||
context.context()._handle) # pylint: disable=protected-access
|
context.context()._handle) # pylint: disable=protected-access
|
||||||
@ -101,9 +101,10 @@ def stop():
|
|||||||
if _profiler is None:
|
if _profiler is None:
|
||||||
raise ProfilerNotRunningError(
|
raise ProfilerNotRunningError(
|
||||||
'Cannot stop profiling. No profiler is running.')
|
'Cannot stop profiling. No profiler is running.')
|
||||||
|
if context.default_execution_mode == context.EAGER_MODE:
|
||||||
|
context.async_wait()
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
pywrap_tensorflow.TFE_ProfilerSerializeToString(
|
pywrap_tensorflow.TFE_ProfilerSerializeToString(
|
||||||
context.context()._handle, # pylint: disable=protected-access
|
|
||||||
_profiler,
|
_profiler,
|
||||||
buffer_)
|
buffer_)
|
||||||
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||||
@ -162,6 +163,7 @@ def start_profiler_server(port):
|
|||||||
"""
|
"""
|
||||||
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
|
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
|
||||||
if context.default_execution_mode == context.EAGER_MODE:
|
if context.default_execution_mode == context.EAGER_MODE:
|
||||||
|
context.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
|
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
|
||||||
profiler_context,
|
profiler_context,
|
||||||
context.context()._handle) # pylint: disable=protected-access
|
context.context()._handle) # pylint: disable=protected-access
|
||||||
|
Loading…
Reference in New Issue
Block a user