Only reset TFE_Op from the same TFE_Context

PiperOrigin-RevId: 287201238
Change-Id: Id1732edd95aa470f2952de2620e3ce33fe4bfcb0
This commit is contained in:
Gaurav Jain 2019-12-26 10:28:58 -08:00 committed by TensorFlower Gardener
parent 107bd3e40b
commit 400e246b7a
3 changed files with 37 additions and 27 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
if (!status->status.ok()) {
return nullptr;
}
auto create_or_reset =
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
if (op_to_reset) {
status->status = op_to_reset->Reset(ctx, name, is_function, types,
raw_device_name, inference_ctx);
return op_to_reset;
} else {
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}
};
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
return create_or_reset(false, new TFE_OpInferenceContext(op_def));
}
if (!ctx->context->FindFunctionByName(name)) {
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
"registered in the binary running in this process.");
return nullptr;
}
return create_or_reset(true, nullptr);
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@ -125,24 +125,26 @@ struct TFE_OpInferenceContext {
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx)
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: ctx(ctx),
operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
TFE_OpInferenceContext* infer_ctx) {
inference_ctx.reset(infer_ctx);
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};

View File

@ -72,11 +72,15 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
TFE_Op* maybe_op = ReleaseThreadLocalOp();
if (maybe_op) {
TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
return maybe_op;
} else {
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
nullptr);
if (status->status.ok()) {
return maybe_op;
}
// Delete op and create a fresh one
delete maybe_op;
}
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
nullptr);
}
void ReturnOp(TFE_Op* object) {