[Perf] Skip EagerOperation::SetDeviceName(...) call if input device name didn't change.

PiperOrigin-RevId: 284700133
Change-Id: I7716abe6968b0686df00ea15dec3d85bf16e8cf5
This commit is contained in:
A. Unique TensorFlower 2019-12-09 22:01:06 -08:00 committed by TensorFlower Gardener
parent 6dfa569272
commit 495e179730
8 changed files with 71 additions and 45 deletions

View File

@ -1009,7 +1009,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
// Creates a new TFE_Op for `op_or_function_name` by delegating to
// NewOrResetOp with no op to reset. Any failure is reported through `status`.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
return NewOrResetOp(ctx, op_or_function_name, status,
// The nullptr third argument is the raw device name: no device is requested
// at creation time.
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
/* op_to_reset= */ nullptr);
}

View File

@ -29,9 +29,11 @@ limitations under the License.
using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset) {
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
op_to_reset);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");

View File

@ -22,8 +22,16 @@ limitations under the License.
extern "C" {
#endif
// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
// is a performance optimization: it reuses an existing unused op rather than
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. Otherwise, it attempts to parse and set the
// device name. It's effectively `TFE_OpSetDevice`, but it is faster than
// calling that separately because if the existing op has the same
// `raw_device_name`, it skips parsing and just leaves it as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
const char* op_or_function_name,
const char* raw_device_name,
TF_Status* status, TFE_Op* op_to_reset);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,

View File

@ -17,7 +17,8 @@ limitations under the License.
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset) {
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
@ -25,14 +26,17 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
if (!status->status.ok()) {
return nullptr;
}
auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
bool is_function,
TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
auto create_or_reset =
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
if (op_to_reset) {
op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
status->status = op_to_reset->Reset(ctx, name, is_function, types,
raw_device_name, inference_ctx);
return op_to_reset;
} else {
return new TFE_Op(ctx, name, is_function, types, inference_ctx);
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}
};

View File

@ -134,11 +134,13 @@ struct TFE_Op {
inference_ctx.reset();
}
// Re-initializes this TFE_Op so it can be reused for another op or function.
// Hands `infer_ctx` to `inference_ctx.reset(...)` and forwards the remaining
// arguments, including `raw_device_name`, to EagerOperation::Reset. The
// returned Status is whatever EagerOperation::Reset returns.
void Reset(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* infer_ctx) {
operation.Reset(ctx->context, op, is_function, t, nullptr);
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
TFE_OpInferenceContext* infer_ctx) {
inference_ctx.reset(infer_ctx);
// The trailing nullptr is the `executor` argument of EagerOperation::Reset.
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
tensorflow::EagerOperation operation;
@ -146,7 +148,8 @@ struct TFE_Op {
};
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset = nullptr);
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

View File

@ -16,16 +16,25 @@ limitations under the License.
namespace tensorflow {
// Sets this op's device name from the raw string `device`.
//
// A non-null, non-empty `device` is parsed with
// DeviceNameUtils::ParseFullName; a parse failure returns InvalidArgument.
// A null or empty `device` changes nothing unless `reset` is true, in which
// case the cached raw name, the device name and the parsed name are cleared.
// Returns Status::OK() in every non-failure case.
tensorflow::Status EagerOperation::SetDeviceName(const char* device) {
tensorflow::Status EagerOperation::SetDeviceName(const char* device,
const bool reset) {
if (device != nullptr && strlen(device) > 0) {
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
return errors::InvalidArgument("Malformed device specification '", device,
"' in eager op: ", DebugString());
// Skip parsing when `device` has the same contents as the raw name cached
// by the last successful call (`raw_device_name_` is a string, so this
// compares characters, not pointers).
if (device != raw_device_name_) {
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
return errors::InvalidArgument("Malformed device specification '",
device,
"' in eager op: ", DebugString());
}
// Cache the raw name only after a successful parse, so a malformed name is
// never treated as "unchanged" on a later call.
raw_device_name_ = device;
device_name_ =
DeviceNameUtils::HasSomeDetails(device_parsed_name_)
? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
: "";
}
device_name_ =
DeviceNameUtils::HasSomeDetails(device_parsed_name_)
? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
: "";
} else if (reset) {
// No device requested while resetting: drop all device-name state held by
// this op.
raw_device_name_.clear();
device_name_.clear();
device_parsed_name_.Clear();
}
return Status::OK();
}

View File

@ -33,7 +33,9 @@ class EagerOperation {
const absl::optional<EagerRemoteFunctionParams>
remote_func_params = absl::nullopt)
: ctx_(nullptr) {
Reset(ctx, op, is_function, t, executor, remote_func_params);
tensorflow::Status status =
Reset(ctx, op, is_function, t, nullptr, executor, remote_func_params);
DCHECK(status.ok());
}
~EagerOperation() {
@ -53,10 +55,11 @@ class EagerOperation {
inputs_.clear();
}
void Reset(tensorflow::EagerContext* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t, EagerExecutor* executor,
const absl::optional<EagerRemoteFunctionParams>
remote_func_params = absl::nullopt) {
tensorflow::Status Reset(tensorflow::EagerContext* ctx, const char* op,
bool is_function, const tensorflow::AttrTypeMap* t,
const char* raw_device_name, EagerExecutor* executor,
const absl::optional<EagerRemoteFunctionParams>
remote_func_params = absl::nullopt) {
DCHECK(ctx_ == nullptr) << "Calling Reset without first calling Release";
DCHECK(inputs_.empty());
ctx_ = ctx;
@ -67,8 +70,6 @@ class EagerOperation {
}
attr_types_ = t;
device_ = nullptr;
device_name_ = "";
device_parsed_name_.Clear();
use_xla_ = false;
is_function_ = is_function;
cancellation_manager_ = nullptr;
@ -77,6 +78,7 @@ class EagerOperation {
#ifdef TENSORFLOW_MEM_DEBUG
op_name_ = op;
#endif
return SetDeviceName(raw_device_name, true);
}
bool is_function() const { return is_function_; }
@ -105,6 +107,7 @@ class EagerOperation {
tensorflow::Device* Device() const { return device_; }
void SetDevice(tensorflow::Device* device) {
device_ = device;
raw_device_name_.clear();
device_name_ = device->name();
device_parsed_name_ = device->parsed_name();
}
@ -113,7 +116,8 @@ class EagerOperation {
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
return device_parsed_name_;
}
tensorflow::Status SetDeviceName(const char* device);
tensorflow::Status SetDeviceName(const char* device,
const bool reset = false);
// Indicates whether the op is assigned to a device that is local to the
// current host.
@ -147,6 +151,7 @@ class EagerOperation {
const tensorflow::AttrTypeMap* attr_types_;
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
tensorflow::Device* device_;
string raw_device_name_;
string device_name_;
DeviceNameUtils::ParsedName device_parsed_name_;
bool use_xla_ = false;

View File

@ -68,13 +68,14 @@ TFE_Op* ReleaseThreadLocalOp() {
}
// Returns a TFE_Op for `op_or_function_name`. If ReleaseThreadLocalOp() hands
// back an op, that op is reset and reused; otherwise a new op is created.
// `raw_device_name` is forwarded on both paths. Errors are reported through
// `status`.
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
const char* raw_device_name, TF_Status* status) {
TFE_Op* maybe_op = ReleaseThreadLocalOp();
if (maybe_op) {
TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
// The reused op is returned whether or not the reset succeeded; callers
// must inspect `status`.
return maybe_op;
} else {
return TFE_NewOp(ctx, op_or_function_name, status);
return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
nullptr);
}
}
@ -834,14 +835,12 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
TFE_CancellationManager* cancellation_manager,
TFE_OutputTensorHandles* outputs,
TF_Status* out_status) {
TFE_Op* op = GetOp(ctx, op_name, out_status);
TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); });
if (!out_status->status.ok()) return;
TFE_OpSetDevice(op, device_name, out_status);
if (out_status->status.ok()) {
for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
TFE_OpAddInput(op, inputs->at(i), out_status);
}
for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
TFE_OpAddInput(op, inputs->at(i), out_status);
}
if (cancellation_manager && out_status->status.ok()) {
TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
@ -3506,7 +3505,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
return nullptr;
}
TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status);
TFE_Op* op =
GetOp(op_exec_info.ctx, op_name, op_exec_info.device_name, status);
auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
ReturnStatus(status);
ReturnOp(op);
@ -3573,11 +3573,6 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
}
}
TFE_OpSetDevice(op, op_exec_info.device_name, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
return nullptr;
}
// Flat attrs and inputs as required by the record_gradient call. The attrs
// here only contain inferred attrs (non-inferred attrs are added directly
// from the input args).