From b0a397fceac50f7d6c9f18b66789f275c4a7ba6c Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 12 Sep 2017 12:16:59 -0700 Subject: [PATCH] eager: Remove unnecessary TFE_Context argument to TFE_OpSetDevice. PiperOrigin-RevId: 168417999 --- tensorflow/c/eager/c_api.cc | 6 +++--- tensorflow/c/eager/c_api.h | 7 ++----- tensorflow/python/eager/pywrap_tfe_src.cc | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 56f7303f70f..5fa3da940ce 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -289,11 +289,11 @@ static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device, } } -void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, const char* device_name, - TF_Status* status) { +void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->session->device_mgr->LookupDevice(device_name, &d); + status->status = + op->ctx->session->device_mgr->LookupDevice(device_name, &d); if (!status->status.ok()) return; } TFE_OpSetDeviceHelper(op, d, status); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a54d206a307..814bf03fa25 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -100,11 +100,8 @@ TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_func TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); -// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context -// parameter. Instead, the TFE_Context should be captured when creating the -// TFE_Op. -TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, - const char* device_name, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, + TF_Status* status); TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 655e3ec8491..323ae8328a4 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -297,7 +297,7 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, TF_Status* out_status) { TFE_Op* op = TFE_NewOp(ctx, op_name, out_status); if (TF_GetCode(out_status) != TF_OK) return; - TFE_OpSetDevice(op, ctx, device_name, out_status); + TFE_OpSetDevice(op, device_name, out_status); if (TF_GetCode(out_status) == TF_OK) { for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK; ++i) {