Clean up Device pointer usage in execute.cc

- Avoid unnecessary Device lookup by name
- Merge device == nullptr logic for TensorHandle

PiperOrigin-RevId: 256296919
Gaurav Jain authored 2019-07-02 21:30:37 -07:00; committed by TensorFlower Gardener
parent 720e198425
commit e0ab930a7e
9 changed files with 88 additions and 90 deletions
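
The heart of this change is a calling-convention shift: EagerCopyToDevice and its helpers now accept a resolved Device* instead of a device name, so the name-to-pointer lookup happens once at the API boundary rather than inside every copy. Below is a minimal standalone sketch of that pattern; Device, EagerContext, and both copy functions are illustrative stand-ins, not the real TensorFlow types.

#include <iostream>
#include <map>
#include <string>

// Stand-ins for tensorflow::Device and tensorflow::EagerContext.
struct Device {
  std::string name;
};

struct EagerContext {
  std::map<std::string, Device*> devices;
  // Name-based lookup, paid on every call in the old convention.
  bool FindDeviceFromName(const std::string& name, Device** out) {
    auto it = devices.find(name);
    if (it == devices.end()) return false;
    *out = it->second;
    return true;
  }
};

// Old convention: the callee re-resolves the device from its name.
void CopyToDeviceByName(EagerContext* ctx, const std::string& device_name) {
  Device* d = nullptr;
  if (ctx->FindDeviceFromName(device_name, &d)) {
    std::cout << "copy to " << d->name << " (resolved by name)\n";
  }
}

// New convention: the caller resolves Device* once and passes it through.
void CopyToDevice(Device* device) {
  std::cout << "copy to " << device->name << " (pointer passed directly)\n";
}

int main() {
  Device cpu{"/device:CPU:0"};
  EagerContext ctx;
  ctx.devices[cpu.name] = &cpu;
  CopyToDeviceByName(&ctx, cpu.name);  // lookup on every call
  CopyToDevice(&cpu);                  // no per-call lookup
  return 0;
}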

View File

@@ -568,8 +568,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
   const tensorflow::Tensor* t = nullptr;
   tensorflow::TensorHandle* h_cpu = nullptr;
   status->status = EagerCopyToDevice(
-      handle, handle->Context(), handle->Context()->HostCPU()->name().c_str(),
-      false, &h_cpu);
+      handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu);
   if (!status->status.ok()) {
     return nullptr;
   }
@@ -913,8 +912,13 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                const char* device_name,
                                                TF_Status* status) {
   tensorflow::TensorHandle* handle = nullptr;
+  tensorflow::Device* device;
+  status->status = ctx->context->FindDeviceFromName(device_name, &device);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
   status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
-                                                 device_name, false, &handle);
+                                                 device, false, &handle);
   if (status->status.ok()) {
     return new TFE_TensorHandle(handle);
   }
@@ -984,10 +988,11 @@ TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
                                                      TF_Status* status) {
   // TensorHandles created by PyFuncOp lack context and therefore could
   // not be copied.
-  if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
+  tensorflow::EagerContext* ctx = h->handle->Context();
+  if (!h->handle->OnHostCPU() && ctx != nullptr) {
     tensorflow::TensorHandle* handle = nullptr;
     status->status = tensorflow::EagerCopyToDevice(
-        h->handle, h->handle->Context(), "CPU:0", false, &handle);
+        h->handle, ctx, ctx->HostCPU(), false, &handle);
     if (status->status.ok()) {
       return new TFE_TensorHandle(handle);
     } else {

View File

@@ -253,17 +253,17 @@ inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
 }  // namespace

-tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) {
+tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
   if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
     cached_cache_key_ = BuildCacheKeyForDevice(device);
-    device_for_cached_cache_key_ = device;
+    device_for_cached_cache_key_ = string(device);
   }

   return *cached_cache_key_;
 }

 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
-    const string& device) const {
+    const StringPiece device) const {
   tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
   f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
   if (node_def_ != nullptr) {
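
Taking StringPiece (a non-owning view, analogous to absl::string_view) instead of const string& lets a caller holding a const char* avoid constructing a temporary std::string; an owned copy is made only when the key is cached, as in device_for_cached_cache_key_ = string(device) above. A minimal sketch of the idea, using std::string_view as a stand-in for StringPiece and std::hash in place of the Fingerprint128 routine:

#include <iostream>
#include <string>
#include <string_view>

// View-based parameter: no temporary std::string at the call site.
size_t CacheKeyFor(std::string_view device) {
  return std::hash<std::string_view>{}(device);
}

int main() {
  const char* raw = "/device:GPU:0";
  std::string owned(raw);
  std::cout << CacheKeyFor(raw) << "\n";    // no copy made
  std::cout << CacheKeyFor(owned) << "\n";  // no copy made
  // Only caching forces an owned copy, mirroring string(device) above.
  std::string cached = std::string(raw);
  std::cout << cached << "\n";
  return 0;
}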

View File

@@ -117,7 +117,7 @@ class AttrBuilder {
     return GetNodeAttr(node_def_, attr_name, value);
   }

-  tensorflow::Fprint128 CacheKey(const string& device);
+  tensorflow::Fprint128 CacheKey(const StringPiece device);

   void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); }
   const NodeDef& BuildNodeDef();
@@ -126,7 +126,7 @@ class AttrBuilder {
   template <class T>
   using AttrVec = tensorflow::gtl::InlinedVector<std::pair<string, T>, 2>;

-  tensorflow::Fprint128 BuildCacheKeyForDevice(const string& device) const;
+  tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;

   void MayBeInitializeNodeDef();
   // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as

View File

@@ -78,9 +78,6 @@ void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
 int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
                          Device* device) {
   // Find the current device's index.
-  if (device == nullptr) {
-    device = ctx->HostCPU();
-  }
   for (int i = 0; i < ctx->devices()->size(); ++i) {
     if (ctx->devices()->at(i) == device ||
         ctx->devices()->at(i)->name() == device->name()) {
@@ -91,25 +88,29 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
   return 0;
 }

+const char* kUnspecifiedDeviceName = "<unspecified>";
+
+const char* DeviceNameOrUnspecified(Device* device) {
+  return (device == nullptr) ? kUnspecifiedDeviceName : device->name().c_str();
+}
+
 // This function expects *handle to point to an existing tensor handle. The
 // function will update the *handle to be pointed to the existing input tensor
 // handle or else the newly copied tensor handle. The existing handle will have
 // a Ref added, vs the new handle has a Ref due to being newly constructed.
 //
-// `op_device_name` is passed in explicitly because `op->device()` might be
+// `op_device` is passed in explicitly because `op->device()` might be
 // unset and we might have selected some specific device to run this op on.
-Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
-                                      const string& op_device_name, int i,
-                                      const Device* expected_input_device,
+Status MaybeCopyInputToExpectedDevice(EagerOperation* op, Device* op_device,
+                                      int i, Device* expected_input_device,
                                       RunMetadata* run_metadata,
                                       TensorHandle** result) {
   tensorflow::TensorHandle* handle = op->Inputs()[i];
   EagerContext* ctx = op->EagerContext();
-  Device* handle_device = handle->device();
-  const Device* actual_device =
-      handle_device == nullptr ? ctx->HostCPU() : handle_device;
+  Device* handle_device = handle->DeviceOrHostCPU(ctx);
+  const string& op_device_name = DeviceNameOrUnspecified(op_device);

-  if (expected_input_device == actual_device) {
+  if (expected_input_device == handle_device) {
     // No copy was done, so the result is just the original handle with a Ref
     handle->Ref();
     *result = handle;
@@ -132,7 +133,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
             " cannot compute ",
             op->Name(), " as input #", i, " was expected to be on ",
             expected_input_device->name(), " but is actually on ",
-            actual_device->name(), " (operation running on ", op_device_name, ")",
+            handle_device->name(), " (operation running on ", op_device_name, ")",
             " Tensors can be copied explicitly using:"
             " `with tf.device(device_name): x = tf.identity(x)`"
             " or transparently copied by using"
@@ -141,7 +142,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
     case DEVICE_PLACEMENT_WARN:
       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
                    << " was expected to be on " << expected_input_device->name()
-                   << " but is actually on " << actual_device->name()
+                   << " but is actually on " << handle_device->name()
                    << " (operation running on " << op_device_name
                    << "). This triggers a copy which can be a performance "
                       "bottleneck.";
@@ -153,9 +154,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
       // trigger a copy.
       auto pre_time_nanos = Env::Default()->NowNanos();
       TensorHandle* result_handle = nullptr;
-      Status status =
-          EagerCopyToDevice(handle, ctx, expected_input_device->name().c_str(),
-                            ctx->MirrorTensors(), &result_handle);
+      Status status = EagerCopyToDevice(handle, ctx, expected_input_device,
+                                        ctx->MirrorTensors(), &result_handle);
       if (run_metadata != nullptr) {
         auto* step_stats = run_metadata->mutable_step_stats();
         MaybeInitializeStepStats(step_stats, ctx);
@@ -177,7 +177,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
       if (!status.ok()) {
         if (result_handle != nullptr) result_handle->Unref();
         return errors::Internal("Failed copying input tensor from ",
-                                actual_device->name(), " to ",
+                                handle_device->name(), " to ",
                                 expected_input_device->name(), " in order to run ",
                                 op->Name(), ": ", status.error_message());
       }
@@ -190,20 +190,19 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
 // `op_device_name` the name of the device on which the op will run, if any.
 // For functions running using function library runtime, the device can be
 // unspecified.
-Status ValidateInputTypeAndPlacement(EagerContext* ctx,
-                                     const string& op_device_name,
-                                     EagerOperation* op,
-                                     const KernelAndDevice* kernel,
-                                     RunMetadata* run_metadata) {
+Status ValidateInputTypeAndPlacement(
+    EagerContext* ctx, EagerOperation* op,
+    const core::RefCountPtr<KernelAndDevice>& kernel,
+    RunMetadata* run_metadata) {
   if (kernel->num_inputs() != op->Inputs().size()) {
     return errors::InvalidArgument("expected ", kernel->num_inputs(),
                                    " inputs, got ", op->Inputs().size());
   }
   for (int i = 0; i < op->Inputs().size(); ++i) {
-    const Device* expected_device = kernel->InputDevice(i);
+    Device* expected_device = kernel->InputDevice(i);
     TensorHandle* handle = nullptr;
     TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
-        op, op_device_name, i, expected_device, run_metadata, &handle));
+        op, kernel->device(), i, expected_device, run_metadata, &handle));
     op->UpdateInput(i, handle);
     // Unref handle since it has a ref as an input now
     handle->Unref();
@@ -266,7 +265,7 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
   return {x, tensorflow::FingerprintCat64(a.high64, x)};
 }

-bool IsMultiDevice(const FunctionDef* fdef, const string& op_device) {
+bool IsMultiDevice(const FunctionDef* fdef) {
   if (fdef == nullptr) {
     // Primitive op.
     return false;
@@ -405,18 +404,15 @@ Status EagerLocalExecute(EagerOperation* op,
   profiler::TraceMe activity(
       [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
       profiler::TraceMeLevel::kInfo);
-  const string unspecified_device_name("<unspecified>");
   EagerContext* ctx = op->EagerContext();
   TF_RETURN_IF_ERROR(ctx->GetStatus());
   Device* device = op->Device();
-  const string& maybe_unspecified_device_name =
-      device == nullptr ? unspecified_device_name : device->name();
   Fprint128 cache_key =
-      op->MutableAttrs()->CacheKey(maybe_unspecified_device_name);
-  bool is_multi_device_function = IsMultiDevice(
-      ctx->FindFunctionDef(op->Name()), maybe_unspecified_device_name);
+      op->MutableAttrs()->CacheKey(DeviceNameOrUnspecified(device));
+  bool is_multi_device_function =
+      IsMultiDevice(ctx->FindFunctionDef(op->Name()));

   std::vector<Device*> input_dev_ptrs;
   // `input_tensor_shapes` contains (potentially a subset of) non DT_RESOURCE
@@ -440,7 +436,7 @@ Status EagerLocalExecute(EagerOperation* op,
     if (input->IsRemote()) {
       TensorHandle* handle = nullptr;
       TF_RETURN_IF_ERROR(EagerCopyToDevice(
-          input, ctx, device == nullptr ? "" : device->name().c_str(),
+          input, ctx, device == nullptr ? ctx->HostCPU() : device,
           ctx->MirrorTensors(), &handle));
       op->UpdateInput(i, handle);
       // Unref handle since it has a ref as an input now
@@ -449,10 +445,11 @@ Status EagerLocalExecute(EagerOperation* op,
     }

     // Get device for this input, and add it to 'cache_key'.
-    Device* device;
-    TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &device));
-    input_dev_ptrs.push_back(device);
-    cache_key = FingerprintCat128(cache_key, Fingerprint128(device->name()));
+    Device* input_device;
+    TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
+    input_dev_ptrs.push_back(input_device);
+    cache_key =
+        FingerprintCat128(cache_key, Fingerprint128(input_device->name()));

     // If input is normal tensor, get its shape and add it to 'cache_key';
     // If input is a ResourceHandle, get its resource handle dtypes and shapes
@@ -493,7 +490,7 @@ Status EagerLocalExecute(EagerOperation* op,
   core::RefCountPtr<KernelAndDevice> kernel = ctx->GetCachedKernel(cache_key);
   if (kernel == nullptr) {
     VLOG(2) << "Creating new kernel for " << op->Name() << " on device "
-            << maybe_unspecified_device_name;
+            << DeviceNameOrUnspecified(op->Device());
     bool compile_with_xla;
     TF_RETURN_IF_ERROR(
         ShouldCompileWithXLA(op, device, ctx, &compile_with_xla));
@@ -512,11 +509,9 @@ Status EagerLocalExecute(EagerOperation* op,
     if (!run_function_with_flr && device == nullptr) {
       TF_RETURN_IF_ERROR(SelectDevice(ndef, ctx, &device));
     }
-    const string& device_name =
-        device == nullptr ? unspecified_device_name : device->name();
     if (ctx->LogDevicePlacement() || VLOG_IS_ON(1)) {
       string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
-                                   device_name);
+                                   DeviceNameOrUnspecified(device));
       if (!logging::LogToListeners(msg)) {
         LOG(INFO) << msg;
       }
@@ -576,11 +571,8 @@ Status EagerLocalExecute(EagerOperation* op,
                                    *num_retvals);
   }
   *num_retvals = output_dtypes_size;
-  const string& device_name = kernel->device() == nullptr
-                                  ? unspecified_device_name
-                                  : kernel->device()->name();
   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(
-      ctx, device_name, op, kernel.get(),
+      ctx, op, kernel,
       ctx->ShouldStoreStepStats() ? ctx->RunMetadataProto() : nullptr));

   std::unique_ptr<NodeExecStats> maybe_stats;
@@ -776,9 +768,9 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
       // correctly determined after the kernel is selected/instantiated, since
       // the op might have its inputs on host memory.
       TensorHandle* handle = nullptr;
-      TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
-          op, op->Device()->name(), i, remote_cpu_device,
-          /* run_metadata= */ nullptr, &handle));
+      TF_RETURN_IF_ERROR(
+          MaybeCopyInputToExpectedDevice(op, op->Device(), i, remote_cpu_device,
+                                         /* run_metadata= */ nullptr, &handle));
       op->UpdateInput(i, handle);
       input = handle;
       input_device = remote_cpu_device;
@@ -918,8 +910,7 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
       // ineligible for CPU pinning.
       break;
     } else if (all_inputs_eligible_for_cpu_pinning) {
-      Device* input_device = tensor_handle->device();
-      input_device = input_device == nullptr ? ctx->HostCPU() : input_device;
+      Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
      VLOG(2) << "for op " << op->Name() << " input " << i << " "
               << DataTypeString(tensor_handle->dtype)
               << " input device = " << input_device->name()
@@ -1336,33 +1327,25 @@ string GetUniqueWireID() {
 }  // namespace

-Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
-                         const char* device_name, bool mirror,
-                         TensorHandle** result) {
+Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
+                         bool mirror, TensorHandle** result) {
   profiler::TraceMe activity("EagerCopyToDevice",
                              profiler::TraceMeLevel::kInfo);
-  Device* send_device = h->device();
-
-  if (send_device == nullptr) {
-    send_device = ctx->HostCPU();
-  }
+  Device* send_device = h->DeviceOrHostCPU(ctx);

   bool sender_is_local = ctx->IsLocal(send_device);

-  Device* recv_device;
-  TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name, &recv_device));
-
-  bool recver_is_local = ctx->IsLocal(recv_device);
+  bool recver_is_local = ctx->IsLocal(device);

   if (sender_is_local && recver_is_local) {
-    return LocalEagerCopyToDevice(h, ctx, recv_device, result);
+    return LocalEagerCopyToDevice(h, ctx, device, result);
   } else {
 #if defined(IS_MOBILE_PLATFORM)
     return errors::Unimplemented(
         "Eager's remote execution is not available on mobile devices.");
 #else   // !IS_MOBILE_PLATFORM
     if (mirror) {
-      if (h->HasRemoteMirror(recv_device)) {
+      if (h->HasRemoteMirror(device)) {
         h->Ref();
         *result = h;
         return Status::OK();
@@ -1370,13 +1353,12 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     }

     if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
-      return EagerRemoteSendTensor(ctx, h, recv_device, mirror, result);
+      return EagerRemoteSendTensor(ctx, h, device, mirror, result);
     } else {
       string wire_id = GetUniqueWireID();
-      TF_RETURN_IF_ERROR(
-          ExecuteSend(ctx, send_device, h, wire_id, recv_device));
-      return ExecuteRecv(ctx, recv_device, h->dtype, wire_id, send_device,
+      TF_RETURN_IF_ERROR(ExecuteSend(ctx, send_device, h, wire_id, device));
+      return ExecuteRecv(ctx, device, h->dtype, wire_id, send_device,
                          mirror ? h : nullptr, result);
     }
 #endif  // !IS_MOBILE_PLATFORM
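
The DeviceNameOrUnspecified helper introduced above replaces three hand-rolled device == nullptr ? "<unspecified>" : device->name() ternaries that EagerLocalExecute previously carried. A standalone sketch of its behavior follows; the Device struct here is a stub with a name field, not the real tensorflow::Device:

#include <iostream>
#include <string>

struct Device {  // stub for tensorflow::Device
  std::string name;
};

const char* kUnspecifiedDeviceName = "<unspecified>";

// Mirrors the helper added in this commit: a single place that maps a null
// device pointer to a printable placeholder name.
const char* DeviceNameOrUnspecified(const Device* device) {
  return (device == nullptr) ? kUnspecifiedDeviceName : device->name.c_str();
}

int main() {
  Device gpu{"/device:GPU:0"};
  std::cout << DeviceNameOrUnspecified(&gpu) << "\n";     // /device:GPU:0
  std::cout << DeviceNameOrUnspecified(nullptr) << "\n";  // <unspecified>
  return 0;
}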

View File

@@ -55,9 +55,8 @@ Status EagerKernelExecute(EagerContext* ctx,
 // the mirror flag, EagerCopyToDevice will attempt to add a mirror to the
 // original handle and update *result to point to h. Since this is not
 // guaranteed, callers should always use the value in *result.
-Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
-                         const char* device_name, bool mirror,
-                         TensorHandle** result);
+Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
+                         bool mirror, TensorHandle** result);

 }  // namespace tensorflow

View File

@@ -77,6 +77,7 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
 Status TensorHandle::CreateLocalHandle(const class Tensor& t,
                                        TensorHandle** h) {
+  // TODO(b/136608821): Move away from nullptr
   return CreateLocalHandle(t, nullptr, nullptr, nullptr, h);
 }
@@ -271,6 +272,10 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
   return tensor_handle_data_->TensorValue(t);
 }

+Device* TensorHandle::DeviceOrHostCPU(EagerContext* ctx) const {
+  return (device_ == nullptr) ? ctx->HostCPU() : device_;
+}
+
 Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
   TF_RETURN_IF_ERROR(WaitReady());
   return tensor_handle_data_->Shape(shape);
@@ -375,7 +380,7 @@ void TensorHandle::Poison(Status status) {
 Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                                   tensorflow::Tensor* output) {
-  tensorflow::Device* srcd = (device_ == nullptr) ? ctx->HostCPU() : device_;
+  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
   bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
   const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
   const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
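
TensorHandle::DeviceOrHostCPU centralizes the convention that device_ == nullptr means "local host CPU", so callers no longer repeat the null check. A self-contained sketch of the same logic, with stub types standing in for the TensorFlow classes:

#include <cassert>

struct Device {};  // stub for tensorflow::Device

struct EagerContext {  // stub for tensorflow::EagerContext
  Device* host_cpu;
  Device* HostCPU() const { return host_cpu; }
};

struct TensorHandle {
  Device* device_ = nullptr;  // nullptr encodes "local host CPU"
  Device* DeviceOrHostCPU(EagerContext* ctx) const {
    return (device_ == nullptr) ? ctx->HostCPU() : device_;
  }
};

int main() {
  Device cpu, gpu;
  EagerContext ctx{&cpu};
  TensorHandle on_cpu;        // device_ left null: host CPU
  TensorHandle on_gpu{&gpu};  // explicit device
  assert(on_cpu.DeviceOrHostCPU(&ctx) == &cpu);
  assert(on_gpu.DeviceOrHostCPU(&ctx) == &gpu);
  return 0;
}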

View File

@@ -125,9 +125,11 @@ class TensorHandle : public core::RefCounted {
   Status TensorValue(tensorflow::TensorValue* t);

-  tensorflow::Device* device() const { return device_; }
-  tensorflow::Device* op_device() const { return op_device_; }
-  tensorflow::Device* resource_device() const { return resource_device_; }
+  Device* device() const { return device_; }
+  Device* op_device() const { return op_device_; }
+  Device* resource_device() const { return resource_device_; }
+
+  Device* DeviceOrHostCPU(EagerContext* ctx) const;

   Status Shape(tensorflow::TensorShape* shape);
@@ -169,12 +171,14 @@ class TensorHandle : public core::RefCounted {
                       tensorflow::Tensor* output);

   // Warning: can return nullptr for CPU tensors.
+  // TODO(b/136608821): Move away from nullptr
   EagerContext* Context() { return ctx_; }

   // dtype for the handle. It must be the same as t.dtype() once the handle is
   // ready.
   const DataType dtype;

+  // TODO(b/136608821): Move away from nullptr
   bool OnHostCPU() const {
     return device_ == nullptr ||
            (ctx_ != nullptr && ctx_->HostCPU() == device_);
@@ -197,7 +201,7 @@ class TensorHandle : public core::RefCounted {
   // done and the handle is "ready".
   Status WaitReady();

-  // TODO(ashankar): device_ == nullptr iff local CPU
+  // TODO(b/136608821): device_ == nullptr iff local CPU
   // This was expedient, but perhaps worth revisiting ('device_' should always
   // be a valid pointer?)
   // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are

View File

@@ -331,9 +331,12 @@ Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
     TensorHandle* tensor_handle = nullptr;
     TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
     TensorHandle* copied_handle = nullptr;
-    TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
-                                         request->device_name().c_str(), false,
-                                         &copied_handle));
+    EagerContext* ctx = context->Context();
+    Device* device;
+    TF_RETURN_IF_ERROR(
+        ctx->FindDeviceFromName(request->device_name().c_str(), &device));
+    TF_RETURN_IF_ERROR(
+        EagerCopyToDevice(tensor_handle, ctx, device, false, &copied_handle));
     tensors.push_back(copied_handle);
     tensor_handle->Unref();
   }

View File

@@ -77,7 +77,7 @@ inline uint64 FingerprintCat64(const uint64 fp1, const uint64 fp2) {
 // This is a portable fingerprint interface for strings that will never change.
 // However, it is not suitable for cryptography.
-inline uint64 Fingerprint64(StringPiece s) {
+inline uint64 Fingerprint64(const StringPiece s) {
 #ifdef USE_OSS_FARMHASH
   return ::util::Fingerprint64(s.data(), s.size());
 #else
@@ -91,7 +91,7 @@ inline uint64 Fingerprint64(StringPiece s) {
 }

 // 128-bit variant of Fingerprint64 above (same properties and caveats apply).
-inline Fprint128 Fingerprint128(StringPiece s) {
+inline Fprint128 Fingerprint128(const StringPiece s) {
 #ifdef USE_OSS_FARMHASH
   const auto fingerprint = ::util::Fingerprint128(s.data(), s.size());
   return {::util::Uint128Low64(fingerprint),
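
The only change in this header is a top-level const on the by-value StringPiece parameters. That constrains the implementation (the body cannot reassign s) but is invisible to callers, since top-level const on a value parameter is not part of a function's signature in C++. A small demonstration, with std::string_view standing in for StringPiece:

#include <iostream>
#include <string_view>

void Show(std::string_view s);         // declared without const
void Show(const std::string_view s) {  // defined with const: same function
  // s = "reassigned";                 // would not compile: s is const here
  std::cout << s << "\n";
}

int main() {
  Show("fingerprint me");  // call sites are unaffected
  return 0;
}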