Unify device names in EagerOperation.
Change DeviceName() so it returns what is set by SetDeviceName(). Update documentation and code accordingly. The GetDeviceName() function became redundant and was removed. PiperOrigin-RevId: 306722028 Change-Id: I680498e36e477d06521a7b2af267c2d7852bfafd
This commit is contained in:
parent
75aa2fa3f6
commit
ed7d6f617d
@ -42,7 +42,28 @@ class AbstractOperationInterface {
|
|||||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||||
|
|
||||||
virtual const string& Name() const = 0;
|
virtual const string& Name() const = 0;
|
||||||
|
|
||||||
|
// Returns the operation's device name.
|
||||||
|
//
|
||||||
|
// The value returned may be different from the one set by SetDeviceName, but
|
||||||
|
// it will be compatible with it: the name will be updated by device placement
|
||||||
|
// logic to refer to the specific device chosen.
|
||||||
|
//
|
||||||
|
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
|
||||||
|
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
|
||||||
|
// chosen for the operation by the device placement logic in the
|
||||||
|
// executor. After that, the value returned by DeviceName will be a full
|
||||||
|
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
|
||||||
virtual const string& DeviceName() const = 0;
|
virtual const string& DeviceName() const = 0;
|
||||||
|
|
||||||
|
// Sets the operation device name.
|
||||||
|
//
|
||||||
|
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
|
||||||
|
// the result will be used as a constraint for device placement. See the
|
||||||
|
// documentation for DeviceName for more details.
|
||||||
|
//
|
||||||
|
// The value will override the previous value - that is, no "merging" of
|
||||||
|
// existing and given constraints will be performed.
|
||||||
virtual Status SetDeviceName(const char* name) = 0;
|
virtual Status SetDeviceName(const char* name) = 0;
|
||||||
|
|
||||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||||
|
@ -151,6 +151,18 @@ tf_cuda_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "eager_operation_test",
|
||||||
|
srcs = ["eager_operation_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":core",
|
||||||
|
":eager_operation",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "tensor_handle_data",
|
name = "tensor_handle_data",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -36,13 +36,6 @@ void EagerOperation::Clear() {
|
|||||||
ClearInferenceState();
|
ClearInferenceState();
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& EagerOperation::DeviceName() const {
|
|
||||||
VariantDevice variant_device =
|
|
||||||
(Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device();
|
|
||||||
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
|
||||||
variant_device);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
|
Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
|
||||||
size_t length) {
|
size_t length) {
|
||||||
MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||||
@ -311,15 +304,7 @@ Status EagerOperation::Reset(
|
|||||||
executor_ = executor ? executor : &ctx_.Executor();
|
executor_ = executor ? executor : &ctx_.Executor();
|
||||||
remote_func_params_ = remote_func_params;
|
remote_func_params_ = remote_func_params;
|
||||||
op_name_ = op;
|
op_name_ = op;
|
||||||
if (device_name != nullptr && strlen(device_name) > 0) {
|
|
||||||
return SetDeviceName(device_name);
|
return SetDeviceName(device_name);
|
||||||
} else {
|
|
||||||
last_set_device_name_.clear();
|
|
||||||
device_name_.clear();
|
|
||||||
device_parsed_name_.Clear();
|
|
||||||
device_ = kVariantDeviceNull;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
|
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
|
||||||
@ -389,8 +374,8 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status EagerOperation::SetDeviceName(const char* name) {
|
Status EagerOperation::SetDeviceName(const char* c_name) {
|
||||||
if (name != nullptr && strlen(name) > 0) {
|
string name(c_name != nullptr ? c_name : "");
|
||||||
if (name != last_set_device_name_) {
|
if (name != last_set_device_name_) {
|
||||||
if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
|
if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
|
||||||
return errors::InvalidArgument("Malformed device specification '", name,
|
return errors::InvalidArgument("Malformed device specification '", name,
|
||||||
@ -407,7 +392,6 @@ Status EagerOperation::SetDeviceName(const char* name) {
|
|||||||
device_ = kVariantDeviceNull;
|
device_ = kVariantDeviceNull;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,28 +48,39 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const string& Name() const override { return attrs_.op_name(); }
|
const string& Name() const override { return attrs_.op_name(); }
|
||||||
const string& DeviceName() const override;
|
|
||||||
|
|
||||||
const string& GetDeviceName() const { return device_name_; }
|
const string& DeviceName() const override { return device_name_; }
|
||||||
|
|
||||||
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
|
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
|
||||||
return device_parsed_name_;
|
return device_parsed_name_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replaces the previous device name with the given one (see
|
||||||
|
// AbstractOperationInterface::SetDeviceName for more details).
|
||||||
|
//
|
||||||
|
// This also resets the internal device pointer, unless the given name refers
|
||||||
|
// to a known custom device, in which case the internal device pointer is
|
||||||
|
// updated to that device.
|
||||||
Status SetDeviceName(const char* name) override;
|
Status SetDeviceName(const char* name) override;
|
||||||
|
|
||||||
void SetDevice(tensorflow::Device* device) {
|
void SetDevice(tensorflow::Device* device) {
|
||||||
device_ = device;
|
device_ = device;
|
||||||
last_set_device_name_.clear();
|
|
||||||
device_name_ = device->name();
|
device_name_ = device->name();
|
||||||
device_parsed_name_ = device->parsed_name();
|
device_parsed_name_ = device->parsed_name();
|
||||||
|
// TODO(b/154133594): Due to intricacies of external logic, we can not
|
||||||
|
// set this to device_name_, as would be natural, because we need the
|
||||||
|
// next call to SetDeviceName to reset the device pointer.
|
||||||
|
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDevice(tensorflow::CustomDevice* device) {
|
void SetDevice(tensorflow::CustomDevice* device) {
|
||||||
device_ = device;
|
device_ = device;
|
||||||
last_set_device_name_.clear();
|
|
||||||
device_name_ = device->name();
|
device_name_ = device->name();
|
||||||
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
|
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
|
||||||
|
// TODO(b/154133594): Due to intricacies of external logic, we can not
|
||||||
|
// set this to device_name_, as would be natural, because we need the
|
||||||
|
// next call to SetDeviceName to reset the device pointer.
|
||||||
|
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AddInput(AbstractTensorHandleInterface* input) override;
|
Status AddInput(AbstractTensorHandleInterface* input) override;
|
||||||
@ -191,10 +202,20 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
// calls to SetDeviceName.
|
// calls to SetDeviceName.
|
||||||
string last_set_device_name_;
|
string last_set_device_name_;
|
||||||
|
|
||||||
|
// The operation's device name.
|
||||||
|
// This contains the name passed to SetDeviceName until device_ is set,
|
||||||
|
// at which point it contains the device_ name.
|
||||||
string device_name_;
|
string device_name_;
|
||||||
|
|
||||||
|
// The parsed device name.
|
||||||
|
// This will always contain the result of
|
||||||
|
// DeviceNameUtils::ParseFullName(device_name_).
|
||||||
DeviceNameUtils::ParsedName device_parsed_name_;
|
DeviceNameUtils::ParsedName device_parsed_name_;
|
||||||
|
|
||||||
|
// The operation's device.
|
||||||
|
// This is set by the execution device placement logic, and should conform
|
||||||
|
// with the contents of device_name_. Once it is set, the device_name_ is
|
||||||
|
// updated accordingly.
|
||||||
VariantDevice device_;
|
VariantDevice device_;
|
||||||
|
|
||||||
bool use_xla_ = false;
|
bool use_xla_ = false;
|
||||||
|
52
tensorflow/core/common_runtime/eager/eager_operation_test.cc
Normal file
52
tensorflow/core/common_runtime/eager/eager_operation_test.cc
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(EagerOperationTest, DeviceName) {
|
||||||
|
StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
|
||||||
|
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||||
|
auto ctx = new EagerContext(
|
||||||
|
SessionOptions(),
|
||||||
|
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||||
|
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
|
||||||
|
&device_mgr, false, nullptr, nullptr, nullptr);
|
||||||
|
|
||||||
|
auto op = new EagerOperation(ctx);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(op->SetDeviceName("/device:DONTHAVE"));
|
||||||
|
EXPECT_EQ("/device:DONTHAVE:*", op->DeviceName());
|
||||||
|
|
||||||
|
TF_ASSERT_OK(op->SetDeviceName(""));
|
||||||
|
EXPECT_EQ("", op->DeviceName());
|
||||||
|
|
||||||
|
TF_ASSERT_OK(op->SetDeviceName("/job:localhost"));
|
||||||
|
EXPECT_EQ("/job:localhost", op->DeviceName());
|
||||||
|
|
||||||
|
EXPECT_NE(Status::OK(), op->SetDeviceName("/not/a/valid/name"));
|
||||||
|
|
||||||
|
delete op;
|
||||||
|
ctx->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -385,7 +385,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
|||||||
TF_RETURN_IF_ERROR(executor.status());
|
TF_RETURN_IF_ERROR(executor.status());
|
||||||
Device* device = absl::get<Device*>(op->Device());
|
Device* device = absl::get<Device*>(op->Device());
|
||||||
|
|
||||||
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->GetDeviceName());
|
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
|
||||||
|
|
||||||
std::vector<Device*> input_dev_ptrs;
|
std::vector<Device*> input_dev_ptrs;
|
||||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||||
@ -677,7 +677,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
|||||||
// TODO(fishx): Remove following code when lazy tensor copy is ready.
|
// TODO(fishx): Remove following code when lazy tensor copy is ready.
|
||||||
if (op->Device() == kVariantDeviceNull) {
|
if (op->Device() == kVariantDeviceNull) {
|
||||||
tensorflow::Device* device = nullptr;
|
tensorflow::Device* device = nullptr;
|
||||||
string device_name = op->GetDeviceName();
|
string device_name = op->DeviceName();
|
||||||
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
|
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
|
||||||
op->SetDevice(device);
|
op->SetDevice(device);
|
||||||
}
|
}
|
||||||
@ -886,9 +886,9 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
|||||||
// TODO(b/145922293): Support allowed_devices specified in wildcard
|
// TODO(b/145922293): Support allowed_devices specified in wildcard
|
||||||
// patterns.
|
// patterns.
|
||||||
if (std::find(allowed_devices.begin(), allowed_devices.end(),
|
if (std::find(allowed_devices.begin(), allowed_devices.end(),
|
||||||
op->GetDeviceName()) != allowed_devices.end()) {
|
op->DeviceName()) != allowed_devices.end()) {
|
||||||
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
|
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(),
|
||||||
op->GetDeviceName().c_str(), &resource_device));
|
&resource_device));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
|
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
|
||||||
|
@ -250,7 +250,7 @@ class OpNode {
|
|||||||
|
|
||||||
// Precalculating a cache key saves about 10% of inference time for very
|
// Precalculating a cache key saves about 10% of inference time for very
|
||||||
// small models.
|
// small models.
|
||||||
op_->MutableAttrs()->CacheKey(op_->GetDeviceName());
|
op_->MutableAttrs()->CacheKey(op_->DeviceName());
|
||||||
|
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user