Clean up Device pointer usage in execute.cc
- Avoid unnecessary Device lookup by name - Merge device == nullptr logic for TensorHandle PiperOrigin-RevId: 256296919
This commit is contained in:
parent
720e198425
commit
e0ab930a7e
@ -568,8 +568,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||
status->status = EagerCopyToDevice(
|
||||
handle, handle->Context(), handle->Context()->HostCPU()->name().c_str(),
|
||||
false, &h_cpu);
|
||||
handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -913,8 +912,13 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Device* device;
|
||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
device_name, false, &handle);
|
||||
device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
}
|
||||
@ -984,10 +988,11 @@ TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
// TensorHandles created by PyFuncOp lack context and therefore could
|
||||
// not be copied.
|
||||
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
|
||||
tensorflow::EagerContext* ctx = h->handle->Context();
|
||||
if (!h->handle->OnHostCPU() && ctx != nullptr) {
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
h->handle, h->handle->Context(), "CPU:0", false, &handle);
|
||||
h->handle, ctx, ctx->HostCPU(), false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
} else {
|
||||
|
@ -253,17 +253,17 @@ inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
|
||||
|
||||
} // namespace
|
||||
|
||||
tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) {
|
||||
tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
|
||||
if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
|
||||
cached_cache_key_ = BuildCacheKeyForDevice(device);
|
||||
device_for_cached_cache_key_ = device;
|
||||
device_for_cached_cache_key_ = string(device);
|
||||
}
|
||||
|
||||
return *cached_cache_key_;
|
||||
}
|
||||
|
||||
tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
|
||||
const string& device) const {
|
||||
const StringPiece device) const {
|
||||
tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
|
||||
f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
|
||||
if (node_def_ != nullptr) {
|
||||
|
@ -117,7 +117,7 @@ class AttrBuilder {
|
||||
return GetNodeAttr(node_def_, attr_name, value);
|
||||
}
|
||||
|
||||
tensorflow::Fprint128 CacheKey(const string& device);
|
||||
tensorflow::Fprint128 CacheKey(const StringPiece device);
|
||||
|
||||
void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); }
|
||||
const NodeDef& BuildNodeDef();
|
||||
@ -126,7 +126,7 @@ class AttrBuilder {
|
||||
template <class T>
|
||||
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<string, T>, 2>;
|
||||
|
||||
tensorflow::Fprint128 BuildCacheKeyForDevice(const string& device) const;
|
||||
tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
|
||||
|
||||
void MayBeInitializeNodeDef();
|
||||
// Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
|
||||
|
@ -78,9 +78,6 @@ void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
|
||||
int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
||||
Device* device) {
|
||||
// Find the current device's index.
|
||||
if (device == nullptr) {
|
||||
device = ctx->HostCPU();
|
||||
}
|
||||
for (int i = 0; i < ctx->devices()->size(); ++i) {
|
||||
if (ctx->devices()->at(i) == device ||
|
||||
ctx->devices()->at(i)->name() == device->name()) {
|
||||
@ -91,25 +88,29 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char* kUnspecifiedDeviceName = "<unspecified>";
|
||||
|
||||
const char* DeviceNameOrUnspecified(Device* device) {
|
||||
return (device == nullptr) ? kUnspecifiedDeviceName : device->name().c_str();
|
||||
}
|
||||
|
||||
// This function expects *handle to point to an existing tensor handle. The
|
||||
// function will update the *handle to be pointed to the existing input tensor
|
||||
// handle or else the newly copied tensor handle. The existing handle will have
|
||||
// a Ref added, vs the new handle has a Ref due to being newly constructed.
|
||||
//
|
||||
// `op_device_name` is passed in explicitly because `op->device()` might be
|
||||
// `op_device` is passed in explicitly because `op->device()` might be
|
||||
// unset and we might have selected some specific device to run this op on.
|
||||
Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
const string& op_device_name, int i,
|
||||
const Device* expected_input_device,
|
||||
Status MaybeCopyInputToExpectedDevice(EagerOperation* op, Device* op_device,
|
||||
int i, Device* expected_input_device,
|
||||
RunMetadata* run_metadata,
|
||||
TensorHandle** result) {
|
||||
tensorflow::TensorHandle* handle = op->Inputs()[i];
|
||||
EagerContext* ctx = op->EagerContext();
|
||||
Device* handle_device = handle->device();
|
||||
const Device* actual_device =
|
||||
handle_device == nullptr ? ctx->HostCPU() : handle_device;
|
||||
Device* handle_device = handle->DeviceOrHostCPU(ctx);
|
||||
const string& op_device_name = DeviceNameOrUnspecified(op_device);
|
||||
|
||||
if (expected_input_device == actual_device) {
|
||||
if (expected_input_device == handle_device) {
|
||||
// No copy was done, so the result is just the original handle with a Ref
|
||||
handle->Ref();
|
||||
*result = handle;
|
||||
@ -132,7 +133,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
" cannot compute ",
|
||||
op->Name(), " as input #", i, " was expected to be on ",
|
||||
expected_input_device->name(), " but is actually on ",
|
||||
actual_device->name(), " (operation running on ", op_device_name, ")",
|
||||
handle_device->name(), " (operation running on ", op_device_name, ")",
|
||||
" Tensors can be copied explicitly using:"
|
||||
" `with tf.device(device_name): x = tf.identity(x)`"
|
||||
" or transparently copied by using"
|
||||
@ -141,7 +142,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
case DEVICE_PLACEMENT_WARN:
|
||||
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
|
||||
<< " was expected to be on " << expected_input_device->name()
|
||||
<< " but is actually on " << actual_device->name()
|
||||
<< " but is actually on " << handle_device->name()
|
||||
<< " (operation running on " << op_device_name
|
||||
<< "). This triggers a copy which can be a performance "
|
||||
"bottleneck.";
|
||||
@ -153,9 +154,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
// trigger a copy.
|
||||
auto pre_time_nanos = Env::Default()->NowNanos();
|
||||
TensorHandle* result_handle = nullptr;
|
||||
Status status =
|
||||
EagerCopyToDevice(handle, ctx, expected_input_device->name().c_str(),
|
||||
ctx->MirrorTensors(), &result_handle);
|
||||
Status status = EagerCopyToDevice(handle, ctx, expected_input_device,
|
||||
ctx->MirrorTensors(), &result_handle);
|
||||
if (run_metadata != nullptr) {
|
||||
auto* step_stats = run_metadata->mutable_step_stats();
|
||||
MaybeInitializeStepStats(step_stats, ctx);
|
||||
@ -177,7 +177,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
if (!status.ok()) {
|
||||
if (result_handle != nullptr) result_handle->Unref();
|
||||
return errors::Internal("Failed copying input tensor from ",
|
||||
actual_device->name(), " to ",
|
||||
handle_device->name(), " to ",
|
||||
expected_input_device->name(), " in order to run ",
|
||||
op->Name(), ": ", status.error_message());
|
||||
}
|
||||
@ -190,20 +190,19 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
// `op_device_name` the name of the device on which the op will run, if any.
|
||||
// For functions running using function library runtime, the device can be
|
||||
// unspecified.
|
||||
Status ValidateInputTypeAndPlacement(EagerContext* ctx,
|
||||
const string& op_device_name,
|
||||
EagerOperation* op,
|
||||
const KernelAndDevice* kernel,
|
||||
RunMetadata* run_metadata) {
|
||||
Status ValidateInputTypeAndPlacement(
|
||||
EagerContext* ctx, EagerOperation* op,
|
||||
const core::RefCountPtr<KernelAndDevice>& kernel,
|
||||
RunMetadata* run_metadata) {
|
||||
if (kernel->num_inputs() != op->Inputs().size()) {
|
||||
return errors::InvalidArgument("expected ", kernel->num_inputs(),
|
||||
" inputs, got ", op->Inputs().size());
|
||||
}
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
const Device* expected_device = kernel->InputDevice(i);
|
||||
Device* expected_device = kernel->InputDevice(i);
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||
op, op_device_name, i, expected_device, run_metadata, &handle));
|
||||
op, kernel->device(), i, expected_device, run_metadata, &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
handle->Unref();
|
||||
@ -266,7 +265,7 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
|
||||
return {x, tensorflow::FingerprintCat64(a.high64, x)};
|
||||
}
|
||||
|
||||
bool IsMultiDevice(const FunctionDef* fdef, const string& op_device) {
|
||||
bool IsMultiDevice(const FunctionDef* fdef) {
|
||||
if (fdef == nullptr) {
|
||||
// Primitive op.
|
||||
return false;
|
||||
@ -405,18 +404,15 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
profiler::TraceMe activity(
|
||||
[&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
const string unspecified_device_name("<unspecified>");
|
||||
EagerContext* ctx = op->EagerContext();
|
||||
TF_RETURN_IF_ERROR(ctx->GetStatus());
|
||||
Device* device = op->Device();
|
||||
|
||||
const string& maybe_unspecified_device_name =
|
||||
device == nullptr ? unspecified_device_name : device->name();
|
||||
Fprint128 cache_key =
|
||||
op->MutableAttrs()->CacheKey(maybe_unspecified_device_name);
|
||||
op->MutableAttrs()->CacheKey(DeviceNameOrUnspecified(device));
|
||||
|
||||
bool is_multi_device_function = IsMultiDevice(
|
||||
ctx->FindFunctionDef(op->Name()), maybe_unspecified_device_name);
|
||||
bool is_multi_device_function =
|
||||
IsMultiDevice(ctx->FindFunctionDef(op->Name()));
|
||||
|
||||
std::vector<Device*> input_dev_ptrs;
|
||||
// `input_tensor_shapes` contains (potentially a subset of) non DT_RESOURCE
|
||||
@ -440,7 +436,7 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
if (input->IsRemote()) {
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(
|
||||
input, ctx, device == nullptr ? "" : device->name().c_str(),
|
||||
input, ctx, device == nullptr ? ctx->HostCPU() : device,
|
||||
ctx->MirrorTensors(), &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
@ -449,10 +445,11 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
}
|
||||
|
||||
// Get device for this input, and add it to 'cache_key'.
|
||||
Device* device;
|
||||
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &device));
|
||||
input_dev_ptrs.push_back(device);
|
||||
cache_key = FingerprintCat128(cache_key, Fingerprint128(device->name()));
|
||||
Device* input_device;
|
||||
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
|
||||
input_dev_ptrs.push_back(input_device);
|
||||
cache_key =
|
||||
FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
|
||||
|
||||
// If input is normal tensor, get its shape and add it to 'cache_key';
|
||||
// If input is a ResourceHandle, get its resource handle dtypes and shapes
|
||||
@ -493,7 +490,7 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
core::RefCountPtr<KernelAndDevice> kernel = ctx->GetCachedKernel(cache_key);
|
||||
if (kernel == nullptr) {
|
||||
VLOG(2) << "Creating new kernel for " << op->Name() << " on device "
|
||||
<< maybe_unspecified_device_name;
|
||||
<< DeviceNameOrUnspecified(op->Device());
|
||||
bool compile_with_xla;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShouldCompileWithXLA(op, device, ctx, &compile_with_xla));
|
||||
@ -512,11 +509,9 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
if (!run_function_with_flr && device == nullptr) {
|
||||
TF_RETURN_IF_ERROR(SelectDevice(ndef, ctx, &device));
|
||||
}
|
||||
const string& device_name =
|
||||
device == nullptr ? unspecified_device_name : device->name();
|
||||
if (ctx->LogDevicePlacement() || VLOG_IS_ON(1)) {
|
||||
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
|
||||
device_name);
|
||||
DeviceNameOrUnspecified(device));
|
||||
if (!logging::LogToListeners(msg)) {
|
||||
LOG(INFO) << msg;
|
||||
}
|
||||
@ -576,11 +571,8 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
*num_retvals);
|
||||
}
|
||||
*num_retvals = output_dtypes_size;
|
||||
const string& device_name = kernel->device() == nullptr
|
||||
? unspecified_device_name
|
||||
: kernel->device()->name();
|
||||
TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(
|
||||
ctx, device_name, op, kernel.get(),
|
||||
ctx, op, kernel,
|
||||
ctx->ShouldStoreStepStats() ? ctx->RunMetadataProto() : nullptr));
|
||||
|
||||
std::unique_ptr<NodeExecStats> maybe_stats;
|
||||
@ -776,9 +768,9 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// correctly determined after the kernel is selected/instantiated, since
|
||||
// the op might have its inputs on host memory.
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||
op, op->Device()->name(), i, remote_cpu_device,
|
||||
/* run_metadata= */ nullptr, &handle));
|
||||
TF_RETURN_IF_ERROR(
|
||||
MaybeCopyInputToExpectedDevice(op, op->Device(), i, remote_cpu_device,
|
||||
/* run_metadata= */ nullptr, &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
input = handle;
|
||||
input_device = remote_cpu_device;
|
||||
@ -918,8 +910,7 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
// ineligible for CPU pinning.
|
||||
break;
|
||||
} else if (all_inputs_eligible_for_cpu_pinning) {
|
||||
Device* input_device = tensor_handle->device();
|
||||
input_device = input_device == nullptr ? ctx->HostCPU() : input_device;
|
||||
Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
|
||||
VLOG(2) << "for op " << op->Name() << " input " << i << " "
|
||||
<< DataTypeString(tensor_handle->dtype)
|
||||
<< " input device = " << input_device->name()
|
||||
@ -1336,33 +1327,25 @@ string GetUniqueWireID() {
|
||||
|
||||
} // namespace
|
||||
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
const char* device_name, bool mirror,
|
||||
TensorHandle** result) {
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
|
||||
bool mirror, TensorHandle** result) {
|
||||
profiler::TraceMe activity("EagerCopyToDevice",
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
Device* send_device = h->device();
|
||||
|
||||
if (send_device == nullptr) {
|
||||
send_device = ctx->HostCPU();
|
||||
}
|
||||
Device* send_device = h->DeviceOrHostCPU(ctx);
|
||||
|
||||
bool sender_is_local = ctx->IsLocal(send_device);
|
||||
|
||||
Device* recv_device;
|
||||
TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name, &recv_device));
|
||||
|
||||
bool recver_is_local = ctx->IsLocal(recv_device);
|
||||
bool recver_is_local = ctx->IsLocal(device);
|
||||
|
||||
if (sender_is_local && recver_is_local) {
|
||||
return LocalEagerCopyToDevice(h, ctx, recv_device, result);
|
||||
return LocalEagerCopyToDevice(h, ctx, device, result);
|
||||
} else {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
return errors::Unimplemented(
|
||||
"Eager's remote execution is not available on mobile devices.");
|
||||
#else // !IS_MOBILE_PLATFORM
|
||||
if (mirror) {
|
||||
if (h->HasRemoteMirror(recv_device)) {
|
||||
if (h->HasRemoteMirror(device)) {
|
||||
h->Ref();
|
||||
*result = h;
|
||||
return Status::OK();
|
||||
@ -1370,13 +1353,12 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
}
|
||||
|
||||
if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
|
||||
return EagerRemoteSendTensor(ctx, h, recv_device, mirror, result);
|
||||
return EagerRemoteSendTensor(ctx, h, device, mirror, result);
|
||||
} else {
|
||||
string wire_id = GetUniqueWireID();
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExecuteSend(ctx, send_device, h, wire_id, recv_device));
|
||||
TF_RETURN_IF_ERROR(ExecuteSend(ctx, send_device, h, wire_id, device));
|
||||
|
||||
return ExecuteRecv(ctx, recv_device, h->dtype, wire_id, send_device,
|
||||
return ExecuteRecv(ctx, device, h->dtype, wire_id, send_device,
|
||||
mirror ? h : nullptr, result);
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
@ -55,9 +55,8 @@ Status EagerKernelExecute(EagerContext* ctx,
|
||||
// the mirror flag, EagerCopyToDevice will attempt to add a mirror to the
|
||||
// original handle and update *result to point to h. Since this is not
|
||||
// guaranteed, callers should always use the value in *result.
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
const char* device_name, bool mirror,
|
||||
TensorHandle** result);
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
|
||||
bool mirror, TensorHandle** result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -77,6 +77,7 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
|
||||
TensorHandle** h) {
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
return CreateLocalHandle(t, nullptr, nullptr, nullptr, h);
|
||||
}
|
||||
|
||||
@ -271,6 +272,10 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
|
||||
return tensor_handle_data_->TensorValue(t);
|
||||
}
|
||||
|
||||
Device* TensorHandle::DeviceOrHostCPU(EagerContext* ctx) const {
|
||||
return (device_ == nullptr) ? ctx->HostCPU() : device_;
|
||||
}
|
||||
|
||||
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
return tensor_handle_data_->Shape(shape);
|
||||
@ -375,7 +380,7 @@ void TensorHandle::Poison(Status status) {
|
||||
|
||||
Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
||||
tensorflow::Tensor* output) {
|
||||
tensorflow::Device* srcd = (device_ == nullptr) ? ctx->HostCPU() : device_;
|
||||
tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
|
||||
bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
|
||||
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
|
||||
const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
|
||||
|
@ -125,9 +125,11 @@ class TensorHandle : public core::RefCounted {
|
||||
|
||||
Status TensorValue(tensorflow::TensorValue* t);
|
||||
|
||||
tensorflow::Device* device() const { return device_; }
|
||||
tensorflow::Device* op_device() const { return op_device_; }
|
||||
tensorflow::Device* resource_device() const { return resource_device_; }
|
||||
Device* device() const { return device_; }
|
||||
Device* op_device() const { return op_device_; }
|
||||
Device* resource_device() const { return resource_device_; }
|
||||
|
||||
Device* DeviceOrHostCPU(EagerContext* ctx) const;
|
||||
|
||||
Status Shape(tensorflow::TensorShape* shape);
|
||||
|
||||
@ -169,12 +171,14 @@ class TensorHandle : public core::RefCounted {
|
||||
tensorflow::Tensor* output);
|
||||
|
||||
// Warning: can return nullptr for CPU tensors.
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
EagerContext* Context() { return ctx_; }
|
||||
|
||||
// dtype for the handle. It must be the same as t.dtype() once the handle is
|
||||
// ready.
|
||||
const DataType dtype;
|
||||
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
bool OnHostCPU() const {
|
||||
return device_ == nullptr ||
|
||||
(ctx_ != nullptr && ctx_->HostCPU() == device_);
|
||||
@ -197,7 +201,7 @@ class TensorHandle : public core::RefCounted {
|
||||
// done and the handle is "ready".
|
||||
Status WaitReady();
|
||||
|
||||
// TODO(ashankar): device_ == nullptr iff local CPU
|
||||
// TODO(b/136608821): device_ == nullptr iff local CPU
|
||||
// This was expedient, but perhaps worth revisiting ('device_' should always
|
||||
// be a valid pointer?)
|
||||
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
|
||||
|
@ -331,9 +331,12 @@ Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
|
||||
TensorHandle* tensor_handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
|
||||
TensorHandle* copied_handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
|
||||
request->device_name().c_str(), false,
|
||||
&copied_handle));
|
||||
EagerContext* ctx = context->Context();
|
||||
Device* device;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->FindDeviceFromName(request->device_name().c_str(), &device));
|
||||
TF_RETURN_IF_ERROR(
|
||||
EagerCopyToDevice(tensor_handle, ctx, device, false, &copied_handle));
|
||||
tensors.push_back(copied_handle);
|
||||
tensor_handle->Unref();
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ inline uint64 FingerprintCat64(const uint64 fp1, const uint64 fp2) {
|
||||
|
||||
// This is a portable fingerprint interface for strings that will never change.
|
||||
// However, it is not suitable for cryptography.
|
||||
inline uint64 Fingerprint64(StringPiece s) {
|
||||
inline uint64 Fingerprint64(const StringPiece s) {
|
||||
#ifdef USE_OSS_FARMHASH
|
||||
return ::util::Fingerprint64(s.data(), s.size());
|
||||
#else
|
||||
@ -91,7 +91,7 @@ inline uint64 Fingerprint64(StringPiece s) {
|
||||
}
|
||||
|
||||
// 128-bit variant of Fingerprint64 above (same properties and caveats apply).
|
||||
inline Fprint128 Fingerprint128(StringPiece s) {
|
||||
inline Fprint128 Fingerprint128(const StringPiece s) {
|
||||
#ifdef USE_OSS_FARMHASH
|
||||
const auto fingerprint = ::util::Fingerprint128(s.data(), s.size());
|
||||
return {::util::Uint128Low64(fingerprint),
|
||||
|
Loading…
x
Reference in New Issue
Block a user