Unify device names in EagerOperation.
Change DeviceName() so it returns what is set by SetDeviceName(). Update documentation and code accordingly. The GetDeviceName() function became redundant and was removed. PiperOrigin-RevId: 306722028 Change-Id: I680498e36e477d06521a7b2af267c2d7852bfafd
This commit is contained in:
parent
75aa2fa3f6
commit
ed7d6f617d
@ -42,7 +42,28 @@ class AbstractOperationInterface {
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const string& Name() const = 0;
|
||||
|
||||
// Returns the operation's device name.
|
||||
//
|
||||
// The value returned may be different from the one set by SetDeviceName, but
|
||||
// it will be compatible with it: the name will be updated by device placement
|
||||
// logic to refer to the specific device chosen.
|
||||
//
|
||||
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
|
||||
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
|
||||
// chosen for the operation by the device placement logic in the
|
||||
// executor. After that, the value returned by DeviceName will be a full
|
||||
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
|
||||
virtual const string& DeviceName() const = 0;
|
||||
|
||||
// Sets the operation device name.
|
||||
//
|
||||
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
|
||||
// the result will be used as a constraint for device placement. See the
|
||||
// documentation for DeviceName for more details.
|
||||
//
|
||||
// The value will override the previous value - that is, no "merging" of
|
||||
// existing and given constraints will be performed.
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||
|
@ -151,6 +151,18 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "eager_operation_test",
|
||||
srcs = ["eager_operation_test.cc"],
|
||||
deps = [
|
||||
":core",
|
||||
":eager_operation",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "tensor_handle_data",
|
||||
srcs = [
|
||||
|
@ -36,13 +36,6 @@ void EagerOperation::Clear() {
|
||||
ClearInferenceState();
|
||||
}
|
||||
|
||||
const string& EagerOperation::DeviceName() const {
|
||||
VariantDevice variant_device =
|
||||
(Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device();
|
||||
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||
variant_device);
|
||||
}
|
||||
|
||||
Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) {
|
||||
MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||
@ -311,15 +304,7 @@ Status EagerOperation::Reset(
|
||||
executor_ = executor ? executor : &ctx_.Executor();
|
||||
remote_func_params_ = remote_func_params;
|
||||
op_name_ = op;
|
||||
if (device_name != nullptr && strlen(device_name) > 0) {
|
||||
return SetDeviceName(device_name);
|
||||
} else {
|
||||
last_set_device_name_.clear();
|
||||
device_name_.clear();
|
||||
device_parsed_name_.Clear();
|
||||
device_ = kVariantDeviceNull;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
|
||||
@ -389,8 +374,8 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EagerOperation::SetDeviceName(const char* name) {
|
||||
if (name != nullptr && strlen(name) > 0) {
|
||||
Status EagerOperation::SetDeviceName(const char* c_name) {
|
||||
string name(c_name != nullptr ? c_name : "");
|
||||
if (name != last_set_device_name_) {
|
||||
if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
|
||||
return errors::InvalidArgument("Malformed device specification '", name,
|
||||
@ -407,7 +392,6 @@ Status EagerOperation::SetDeviceName(const char* name) {
|
||||
device_ = kVariantDeviceNull;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -48,28 +48,39 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
}
|
||||
|
||||
const string& Name() const override { return attrs_.op_name(); }
|
||||
const string& DeviceName() const override;
|
||||
|
||||
const string& GetDeviceName() const { return device_name_; }
|
||||
const string& DeviceName() const override { return device_name_; }
|
||||
|
||||
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
|
||||
return device_parsed_name_;
|
||||
}
|
||||
|
||||
// Replaces the previous device name with the given one (see
|
||||
// AbstractOperationInterface::SetDeviceName for more details).
|
||||
//
|
||||
// This also resets the internal device pointer, unless the given name refers
|
||||
// to a known custom device, in which case the internal device pointer is
|
||||
// updated to that device.
|
||||
Status SetDeviceName(const char* name) override;
|
||||
|
||||
// Binds the operation to a concrete `device`: stores the pointer in device_
// and copies the device's name and parsed name into device_name_ and
// device_parsed_name_.
// NOTE(review): this text is a rendered diff, so the clear() call below is
// presumably the pre-change line that the sentinel assignment at the end
// replaces - confirm against the real header.
void SetDevice(tensorflow::Device* device) {
|
||||
device_ = device;
|
||||
last_set_device_name_.clear();
|
||||
device_name_ = device->name();
|
||||
device_parsed_name_ = device->parsed_name();
|
||||
// TODO(b/154133594): Due to intricacies of external logic, we can not
|
||||
// set this to device_name_ as it would be natural, because we need the
|
||||
// next call to SetDeviceName to reset the device pointer.
|
||||
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
||||
}
|
||||
|
||||
// Binds the operation to a custom `device`: stores the pointer in device_,
// copies the device's name into device_name_ and re-parses that name into
// device_parsed_name_.
// NOTE(review): this text is a rendered diff, so the clear() call below is
// presumably the pre-change line that the sentinel assignment at the end
// replaces - confirm against the real header.
void SetDevice(tensorflow::CustomDevice* device) {
|
||||
device_ = device;
|
||||
last_set_device_name_.clear();
|
||||
device_name_ = device->name();
|
||||
// The result of ParseFullName is not checked here.
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
|
||||
// TODO(b/154133594): Due to intricacies of external logic, we can not
|
||||
// set this to device_name_ as it would be natural, because we need the
|
||||
// next call to SetDeviceName to reset the device pointer.
|
||||
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
||||
}
|
||||
|
||||
Status AddInput(AbstractTensorHandleInterface* input) override;
|
||||
@ -191,10 +202,20 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
// calls to SetDeviceName.
|
||||
string last_set_device_name_;
|
||||
|
||||
// The operation's device name.
|
||||
// This contains the name passed to SetDeviceName until device_ is set,
|
||||
// at which point it contains the device_ name.
|
||||
string device_name_;
|
||||
|
||||
// The parsed device name.
|
||||
// This will always contain the result of
|
||||
// DeviceNameUtils::ParseFullName(device_name_).
|
||||
DeviceNameUtils::ParsedName device_parsed_name_;
|
||||
|
||||
// The operation's device.
|
||||
// This is set by the execution device placement logic, and should conform
|
||||
// with the contents of device_name_. Once it is set, the device_name_ is
|
||||
// updated accordingly.
|
||||
VariantDevice device_;
|
||||
|
||||
bool use_xla_ = false;
|
||||
|
52
tensorflow/core/common_runtime/eager/eager_operation_test.cc
Normal file
52
tensorflow/core/common_runtime/eager/eager_operation_test.cc
Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Checks what DeviceName() reports after successive SetDeviceName() calls on
// an operation that is never executed or bound to a device in this test.
TEST(EagerOperationTest, DeviceName) {
|
||||
// Device manager holding a single local CPU device.
StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||
auto ctx = new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
|
||||
&device_mgr, false, nullptr, nullptr, nullptr);
|
||||
|
||||
auto op = new EagerOperation(ctx);
|
||||
|
||||
// A name carrying only a device type is expected back with a wildcard id.
TF_ASSERT_OK(op->SetDeviceName("/device:DONTHAVE"));
|
||||
EXPECT_EQ("/device:DONTHAVE:*", op->DeviceName());
|
||||
|
||||
// An empty name is accepted and is expected to be reported as empty.
TF_ASSERT_OK(op->SetDeviceName(""));
|
||||
EXPECT_EQ("", op->DeviceName());
|
||||
|
||||
// A job-only constraint is expected to be returned unchanged.
TF_ASSERT_OK(op->SetDeviceName("/job:localhost"));
|
||||
EXPECT_EQ("/job:localhost", op->DeviceName());
|
||||
|
||||
// A string that is not a valid device specification must yield a non-OK
// status.
EXPECT_NE(Status::OK(), op->SetDeviceName("/not/a/valid/name"));
|
||||
|
||||
// Manual cleanup of the op and of the context reference.
// NOTE(review): a failing TF_ASSERT_OK above presumably returns early and
// skips this cleanup - confirm whether that matters for this test.
delete op;
|
||||
ctx->Unref();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -385,7 +385,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
TF_RETURN_IF_ERROR(executor.status());
|
||||
Device* device = absl::get<Device*>(op->Device());
|
||||
|
||||
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->GetDeviceName());
|
||||
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
|
||||
|
||||
std::vector<Device*> input_dev_ptrs;
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
@ -677,7 +677,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// TODO(fishx): Remove following code when lazy tensor copy is ready.
|
||||
if (op->Device() == kVariantDeviceNull) {
|
||||
tensorflow::Device* device = nullptr;
|
||||
string device_name = op->GetDeviceName();
|
||||
string device_name = op->DeviceName();
|
||||
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
|
||||
op->SetDevice(device);
|
||||
}
|
||||
@ -886,9 +886,9 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
// TODO(b/145922293): Support allowed_devices specified in wildcard
|
||||
// patterns.
|
||||
if (std::find(allowed_devices.begin(), allowed_devices.end(),
|
||||
op->GetDeviceName()) != allowed_devices.end()) {
|
||||
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
|
||||
op->GetDeviceName().c_str(), &resource_device));
|
||||
op->DeviceName()) != allowed_devices.end()) {
|
||||
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(),
|
||||
&resource_device));
|
||||
}
|
||||
}
|
||||
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
|
||||
|
@ -250,7 +250,7 @@ class OpNode {
|
||||
|
||||
// Precalculating a cache key saves about 10% of inference time for very
|
||||
// small models.
|
||||
op_->MutableAttrs()->CacheKey(op_->GetDeviceName());
|
||||
op_->MutableAttrs()->CacheKey(op_->DeviceName());
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user