Only reset TFE_Op from the same TFE_Context

PiperOrigin-RevId: 287201238
Change-Id: Id1732edd95aa470f2952de2620e3ce33fe4bfcb0
This commit is contained in:
Gaurav Jain 2019-12-26 10:28:58 -08:00 committed by TensorFlower Gardener
parent 107bd3e40b
commit 400e246b7a
3 changed files with 37 additions and 27 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
auto create_or_reset =
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
if (op_to_reset) {
status->status = op_to_reset->Reset(ctx, name, is_function, types,
raw_device_name, inference_ctx);
return op_to_reset;
} else {
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}
};
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) { if (!is_function) {
const tensorflow::OpDef* op_def; const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def); status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
return create_or_reset(false, new TFE_OpInferenceContext(op_def)); inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} } else if (!ctx->context->FindFunctionByName(name)) {
if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound( status->status = tensorflow::errors::NotFound(
"'", name, "'", name,
"' is neither a type of a primitive operation nor a name " "' is neither a type of a primitive operation nor a name "
@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
"registered in the binary running in this process."); "registered in the binary running in this process.");
return nullptr; return nullptr;
} }
return create_or_reset(true, nullptr);
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
} }

View File

@ -125,24 +125,26 @@ struct TFE_OpInferenceContext {
struct TFE_Op { struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function, TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t, const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx) std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: operation(ctx->context, op, is_function, t), : ctx(ctx),
inference_ctx(inference_ctx) {} operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() { void Clear() {
operation.Clear(); operation.Clear();
inference_ctx.reset(); inference_ctx.reset();
} }
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function, tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t, const tensorflow::AttrTypeMap* t,
const char* raw_device_name, const char* raw_device_name,
TFE_OpInferenceContext* infer_ctx) { std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx.reset(infer_ctx); inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name, return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr); nullptr);
} }
TFE_Context* ctx;
tensorflow::EagerOperation operation; tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx; std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
}; };

View File

@ -72,11 +72,15 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
TFE_Op* maybe_op = ReleaseThreadLocalOp(); TFE_Op* maybe_op = ReleaseThreadLocalOp();
if (maybe_op) { if (maybe_op) {
TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op); TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
return maybe_op; if (status->status.ok()) {
} else { return maybe_op;
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status, }
nullptr); // Delete op and create a fresh one
delete maybe_op;
} }
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
nullptr);
} }
void ReturnOp(TFE_Op* object) { void ReturnOp(TFE_Op* object) {