[Perf] Skip EagerOperation::SetDeviceName(...) call if input device name didn't change.
PiperOrigin-RevId: 284700133 Change-Id: I7716abe6968b0686df00ea15dec3d85bf16e8cf5
This commit is contained in:
parent
6dfa569272
commit
495e179730
@ -1009,7 +1009,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
return NewOrResetOp(ctx, op_or_function_name, status,
|
||||
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
|
||||
/* op_to_reset= */ nullptr);
|
||||
}
|
||||
|
||||
|
@ -29,9 +29,11 @@ limitations under the License.
|
||||
using tensorflow::string;
|
||||
|
||||
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset) {
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset) {
|
||||
if (op_to_reset) {
|
||||
NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
|
||||
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
||||
op_to_reset);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
|
@ -22,8 +22,16 @@ limitations under the License.
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
|
||||
// is for performance optimization by reusing an existing unused op rather than
|
||||
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
|
||||
// does not set the device name. Otherwise, it attempts to parse
|
||||
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
||||
// than separately calling it because if the existing op has the same
|
||||
// `raw_device_name`, it skips parsing and leaves it as it is.
|
||||
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
|
||||
const char* op_or_function_name,
|
||||
const char* raw_device_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
|
@ -17,7 +17,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset) {
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset) {
|
||||
const char* name = op_or_function_name; // Shorthand
|
||||
const tensorflow::AttrTypeMap* types;
|
||||
bool is_function = false;
|
||||
@ -25,14 +26,17 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
|
||||
bool is_function,
|
||||
TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
|
||||
auto create_or_reset =
|
||||
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
|
||||
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
|
||||
if (op_to_reset) {
|
||||
op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
|
||||
status->status = op_to_reset->Reset(ctx, name, is_function, types,
|
||||
raw_device_name, inference_ctx);
|
||||
return op_to_reset;
|
||||
} else {
|
||||
return new TFE_Op(ctx, name, is_function, types, inference_ctx);
|
||||
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -134,11 +134,13 @@ struct TFE_Op {
|
||||
inference_ctx.reset();
|
||||
}
|
||||
|
||||
void Reset(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
TFE_OpInferenceContext* infer_ctx) {
|
||||
operation.Reset(ctx->context, op, is_function, t, nullptr);
|
||||
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
const char* raw_device_name,
|
||||
TFE_OpInferenceContext* infer_ctx) {
|
||||
inference_ctx.reset(infer_ctx);
|
||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
tensorflow::EagerOperation operation;
|
||||
@ -146,7 +148,8 @@ struct TFE_Op {
|
||||
};
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset = nullptr);
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset = nullptr);
|
||||
|
||||
struct TFE_Profiler {
|
||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
||||
|
@ -16,16 +16,25 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
tensorflow::Status EagerOperation::SetDeviceName(const char* device) {
|
||||
tensorflow::Status EagerOperation::SetDeviceName(const char* device,
|
||||
const bool reset) {
|
||||
if (device != nullptr && strlen(device) > 0) {
|
||||
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
|
||||
return errors::InvalidArgument("Malformed device specification '", device,
|
||||
"' in eager op: ", DebugString());
|
||||
if (device != raw_device_name_) {
|
||||
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
|
||||
return errors::InvalidArgument("Malformed device specification '",
|
||||
device,
|
||||
"' in eager op: ", DebugString());
|
||||
}
|
||||
raw_device_name_ = device;
|
||||
device_name_ =
|
||||
DeviceNameUtils::HasSomeDetails(device_parsed_name_)
|
||||
? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
|
||||
: "";
|
||||
}
|
||||
device_name_ =
|
||||
DeviceNameUtils::HasSomeDetails(device_parsed_name_)
|
||||
? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
|
||||
: "";
|
||||
} else if (reset) {
|
||||
raw_device_name_.clear();
|
||||
device_name_.clear();
|
||||
device_parsed_name_.Clear();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -33,7 +33,9 @@ class EagerOperation {
|
||||
const absl::optional<EagerRemoteFunctionParams>
|
||||
remote_func_params = absl::nullopt)
|
||||
: ctx_(nullptr) {
|
||||
Reset(ctx, op, is_function, t, executor, remote_func_params);
|
||||
tensorflow::Status status =
|
||||
Reset(ctx, op, is_function, t, nullptr, executor, remote_func_params);
|
||||
DCHECK(status.ok());
|
||||
}
|
||||
|
||||
~EagerOperation() {
|
||||
@ -53,10 +55,11 @@ class EagerOperation {
|
||||
inputs_.clear();
|
||||
}
|
||||
|
||||
void Reset(tensorflow::EagerContext* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t, EagerExecutor* executor,
|
||||
const absl::optional<EagerRemoteFunctionParams>
|
||||
remote_func_params = absl::nullopt) {
|
||||
tensorflow::Status Reset(tensorflow::EagerContext* ctx, const char* op,
|
||||
bool is_function, const tensorflow::AttrTypeMap* t,
|
||||
const char* raw_device_name, EagerExecutor* executor,
|
||||
const absl::optional<EagerRemoteFunctionParams>
|
||||
remote_func_params = absl::nullopt) {
|
||||
DCHECK(ctx_ == nullptr) << "Calling Reset without first calling Release";
|
||||
DCHECK(inputs_.empty());
|
||||
ctx_ = ctx;
|
||||
@ -67,8 +70,6 @@ class EagerOperation {
|
||||
}
|
||||
attr_types_ = t;
|
||||
device_ = nullptr;
|
||||
device_name_ = "";
|
||||
device_parsed_name_.Clear();
|
||||
use_xla_ = false;
|
||||
is_function_ = is_function;
|
||||
cancellation_manager_ = nullptr;
|
||||
@ -77,6 +78,7 @@ class EagerOperation {
|
||||
#ifdef TENSORFLOW_MEM_DEBUG
|
||||
op_name_ = op;
|
||||
#endif
|
||||
return SetDeviceName(raw_device_name, true);
|
||||
}
|
||||
|
||||
bool is_function() const { return is_function_; }
|
||||
@ -105,6 +107,7 @@ class EagerOperation {
|
||||
tensorflow::Device* Device() const { return device_; }
|
||||
void SetDevice(tensorflow::Device* device) {
|
||||
device_ = device;
|
||||
raw_device_name_.clear();
|
||||
device_name_ = device->name();
|
||||
device_parsed_name_ = device->parsed_name();
|
||||
}
|
||||
@ -113,7 +116,8 @@ class EagerOperation {
|
||||
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
|
||||
return device_parsed_name_;
|
||||
}
|
||||
tensorflow::Status SetDeviceName(const char* device);
|
||||
tensorflow::Status SetDeviceName(const char* device,
|
||||
const bool reset = false);
|
||||
|
||||
// Indicates whether the op is assigned to a device that is local to the
|
||||
// current host.
|
||||
@ -147,6 +151,7 @@ class EagerOperation {
|
||||
const tensorflow::AttrTypeMap* attr_types_;
|
||||
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
|
||||
tensorflow::Device* device_;
|
||||
string raw_device_name_;
|
||||
string device_name_;
|
||||
DeviceNameUtils::ParsedName device_parsed_name_;
|
||||
bool use_xla_ = false;
|
||||
|
@ -68,13 +68,14 @@ TFE_Op* ReleaseThreadLocalOp() {
|
||||
}
|
||||
|
||||
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
TFE_Op* maybe_op = ReleaseThreadLocalOp();
|
||||
if (maybe_op) {
|
||||
TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
|
||||
TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
|
||||
return maybe_op;
|
||||
} else {
|
||||
return TFE_NewOp(ctx, op_or_function_name, status);
|
||||
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -834,14 +835,12 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TFE_OutputTensorHandles* outputs,
|
||||
TF_Status* out_status) {
|
||||
TFE_Op* op = GetOp(ctx, op_name, out_status);
|
||||
TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
|
||||
auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); });
|
||||
if (!out_status->status.ok()) return;
|
||||
TFE_OpSetDevice(op, device_name, out_status);
|
||||
if (out_status->status.ok()) {
|
||||
for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
|
||||
TFE_OpAddInput(op, inputs->at(i), out_status);
|
||||
}
|
||||
|
||||
for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
|
||||
TFE_OpAddInput(op, inputs->at(i), out_status);
|
||||
}
|
||||
if (cancellation_manager && out_status->status.ok()) {
|
||||
TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
|
||||
@ -3506,7 +3505,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status);
|
||||
TFE_Op* op =
|
||||
GetOp(op_exec_info.ctx, op_name, op_exec_info.device_name, status);
|
||||
auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
|
||||
ReturnStatus(status);
|
||||
ReturnOp(op);
|
||||
@ -3573,11 +3573,6 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
|
||||
}
|
||||
}
|
||||
|
||||
TFE_OpSetDevice(op, op_exec_info.device_name, status);
|
||||
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Flat attrs and inputs as required by the record_gradient call. The attrs
|
||||
// here only contain inferred attrs (non-inferred attrs are added directly
|
||||
// from the input args).
|
||||
|
Loading…
Reference in New Issue
Block a user