Only reset TFE_Op from the same TFE_Context
PiperOrigin-RevId: 287201238 Change-Id: Id1732edd95aa470f2952de2620e3ce33fe4bfcb0
This commit is contained in:
parent
107bd3e40b
commit
400e246b7a
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto create_or_reset =
|
||||
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
|
||||
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->Reset(ctx, name, is_function, types,
|
||||
raw_device_name, inference_ctx);
|
||||
return op_to_reset;
|
||||
} else {
|
||||
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
||||
};
|
||||
|
||||
if (op_to_reset && op_to_reset->ctx != ctx) {
|
||||
status->status = tensorflow::errors::Internal(
|
||||
"Cannot reset a TFE_Op from another TFE_Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
if (!is_function) {
|
||||
const tensorflow::OpDef* op_def;
|
||||
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return create_or_reset(false, new TFE_OpInferenceContext(op_def));
|
||||
}
|
||||
if (!ctx->context->FindFunctionByName(name)) {
|
||||
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
|
||||
} else if (!ctx->context->FindFunctionByName(name)) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"'", name,
|
||||
"' is neither a type of a primitive operation nor a name "
|
||||
@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
"registered in the binary running in this process.");
|
||||
return nullptr;
|
||||
}
|
||||
return create_or_reset(true, nullptr);
|
||||
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->Reset(
|
||||
name, is_function, types, raw_device_name, std::move(inference_ctx));
|
||||
return op_to_reset;
|
||||
}
|
||||
|
||||
TFE_Op* new_op =
|
||||
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
||||
|
@ -125,24 +125,26 @@ struct TFE_OpInferenceContext {
|
||||
struct TFE_Op {
|
||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
TFE_OpInferenceContext* inference_ctx)
|
||||
: operation(ctx->context, op, is_function, t),
|
||||
inference_ctx(inference_ctx) {}
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
|
||||
: ctx(ctx),
|
||||
operation(ctx->context, op, is_function, t),
|
||||
inference_ctx(std::move(inference_ctx)) {}
|
||||
|
||||
void Clear() {
|
||||
operation.Clear();
|
||||
inference_ctx.reset();
|
||||
}
|
||||
|
||||
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
|
||||
tensorflow::Status Reset(const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
const char* raw_device_name,
|
||||
TFE_OpInferenceContext* infer_ctx) {
|
||||
inference_ctx.reset(infer_ctx);
|
||||
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
|
||||
inference_ctx = std::move(infer_ctx);
|
||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
TFE_Context* ctx;
|
||||
tensorflow::EagerOperation operation;
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
};
|
||||
|
@ -72,11 +72,15 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TFE_Op* maybe_op = ReleaseThreadLocalOp();
|
||||
if (maybe_op) {
|
||||
TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
|
||||
return maybe_op;
|
||||
} else {
|
||||
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
||||
nullptr);
|
||||
if (status->status.ok()) {
|
||||
return maybe_op;
|
||||
}
|
||||
// Delete op and create a fresh one
|
||||
delete maybe_op;
|
||||
}
|
||||
|
||||
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void ReturnOp(TFE_Op* object) {
|
||||
|
Loading…
Reference in New Issue
Block a user