From 400e246b7aed17ef2ea0590019ce2336405c56ed Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Thu, 26 Dec 2019 10:28:58 -0800 Subject: [PATCH] Only reset TFE_Op from the same TFE_Context PiperOrigin-RevId: 287201238 Change-Id: Id1732edd95aa470f2952de2620e3ce33fe4bfcb0 --- tensorflow/c/eager/c_api_internal.cc | 38 +++++++++++++---------- tensorflow/c/eager/c_api_internal.h | 14 +++++---- tensorflow/python/eager/pywrap_tfe_src.cc | 12 ++++--- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/tensorflow/c/eager/c_api_internal.cc b/tensorflow/c/eager/c_api_internal.cc index f6092715e17..4f3de479ba7 100644 --- a/tensorflow/c/eager/c_api_internal.cc +++ b/tensorflow/c/eager/c_api_internal.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/host_info.h" TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, @@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, if (!status->status.ok()) { return nullptr; } - auto create_or_reset = - [&op_to_reset, &ctx, &name, &types, &raw_device_name, &status]( - bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* { - if (op_to_reset) { - status->status = op_to_reset->Reset(ctx, name, is_function, types, - raw_device_name, inference_ctx); - return op_to_reset; - } else { - TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx); - status->status = new_op->operation.SetDeviceName(raw_device_name); - return new_op; - } - }; + if (op_to_reset && op_to_reset->ctx != ctx) { + status->status = tensorflow::errors::Internal( + "Cannot reset a TFE_Op from another TFE_Context"); + return nullptr; + } + + std::unique_ptr inference_ctx; if (!is_function) { const tensorflow::OpDef* op_def; status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def); if (!status->status.ok()) { return nullptr; } - return create_or_reset(false, new TFE_OpInferenceContext(op_def)); - } - if (!ctx->context->FindFunctionByName(name)) { + inference_ctx.reset(new TFE_OpInferenceContext(op_def)); + } else if (!ctx->context->FindFunctionByName(name)) { status->status = tensorflow::errors::NotFound( "'", name, "' is neither a type of a primitive operation nor a name " @@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, "registered in the binary running in this process."); return nullptr; } - return create_or_reset(true, nullptr); + + if (op_to_reset) { + status->status = op_to_reset->Reset( + name, is_function, types, raw_device_name, std::move(inference_ctx)); + return op_to_reset; + } + + TFE_Op* new_op = + new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx)); + status->status = new_op->operation.SetDeviceName(raw_device_name); + return new_op; } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 29106e2998d..df192913b72 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -125,24 +125,26 @@ struct TFE_OpInferenceContext { struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, bool is_function, const tensorflow::AttrTypeMap* t, - TFE_OpInferenceContext* inference_ctx) - : operation(ctx->context, op, is_function, t), - inference_ctx(inference_ctx) {} + std::unique_ptr inference_ctx) + : ctx(ctx), + operation(ctx->context, op, is_function, t), + inference_ctx(std::move(inference_ctx)) {} void Clear() { operation.Clear(); inference_ctx.reset(); } - tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function, + tensorflow::Status Reset(const char* op, bool is_function, const tensorflow::AttrTypeMap* t, const char* raw_device_name, - TFE_OpInferenceContext* infer_ctx) { - inference_ctx.reset(infer_ctx); + std::unique_ptr infer_ctx) { + inference_ctx = std::move(infer_ctx); return operation.Reset(ctx->context, op, is_function, t, raw_device_name, nullptr); } + TFE_Context* ctx; tensorflow::EagerOperation operation; std::unique_ptr inference_ctx; }; diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index f5508d76583..8fe4b6ac5eb 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -72,11 +72,15 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, TFE_Op* maybe_op = ReleaseThreadLocalOp(); if (maybe_op) { TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op); - return maybe_op; - } else { - return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status, - nullptr); + if (status->status.ok()) { + return maybe_op; + } + // Delete op and create a fresh one + delete maybe_op; } + + return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status, + nullptr); } void ReturnOp(TFE_Op* object) {