From e647a3b425ea63ff5e2e2338815ca4aea188c619 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Sat, 20 Jun 2020 11:28:17 -0700 Subject: [PATCH] Add experimental C API to access EagerContext context ID. PiperOrigin-RevId: 317476439 Change-Id: I9e97bce61cf526695f0c903b5f4f837116fef455 --- tensorflow/c/eager/c_api_experimental.cc | 6 ++++++ tensorflow/c/eager/c_api_experimental.h | 8 ++++++++ tensorflow/python/tfe_wrapper.cc | 3 +++ 3 files changed, 17 insertions(+) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 9937fd7551f..7390cf243be 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { context->SetShouldStoreGraphs(false); } +uint64_t TFE_GetContextId(TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + return context->GetContextId(); +} + void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, int64_t value) { cell->cell.IncrementBy(value); diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 1b8efe61ee0..1af76c01154 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, bool use_tfrt); +// Returns the context_id from the EagerContext which is used by the +// EagerService to maintain consistency between client and worker. The +// context_id is initialized with a dummy value and is later set when the worker +// is initialized (either locally or remotely). The context_id can change during +// the process lifetime although this should cause the worker to be +// reinitialized (e.g. cleared caches) as well. +TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx); + // ----------------------------------------------------------------------------- // Cancellation APIs. diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 00137f6f492..80cce331353 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -461,6 +461,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { m.def("TFE_ContextClearCaches", [](py::handle& o) { TFE_ContextClearCaches(tensorflow::InputTFE_Context(o)); }); + m.def("TFE_GetContextId", [](py::handle& ctx) { + return TFE_GetContextId(tensorflow::InputTFE_Context(ctx)); + }); m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) { return TFE_ContextGetDevicePlacementPolicy( tensorflow::InputTFE_Context(ctx));