Add experimental C API to access EagerContext context ID.

PiperOrigin-RevId: 317476439
Change-Id: I9e97bce61cf526695f0c903b5f4f837116fef455
This commit is contained in:
A. Unique TensorFlower 2020-06-20 11:28:17 -07:00 committed by TensorFlower Gardener
parent 6d2ce43b03
commit e647a3b425
3 changed files with 17 additions and 0 deletions

View File

@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
context->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->GetContextId();
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);

View File

@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
// Returns the context_id from the EagerContext which is used by the
// EagerService to maintain consistency between client and worker. The
// context_id is initialized with a dummy value and is later set when the worker
// is initialized (either locally or remotely). The context_id can change during
// the process lifetime although this should cause the worker to be
// reinitialized (e.g. cleared caches) as well.
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
// -----------------------------------------------------------------------------
// Cancellation APIs.

View File

@ -461,6 +461,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TFE_ContextClearCaches", [](py::handle& o) {
TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
});
m.def("TFE_GetContextId", [](py::handle& ctx) {
return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
});
m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
return TFE_ContextGetDevicePlacementPolicy(
tensorflow::InputTFE_Context(ctx));