Unify device names in EagerOperation.

Change DeviceName() so it returns what is set by
SetDeviceName(). Update documentation and code accordingly. The
GetDeviceName() function became redundant and was removed.

PiperOrigin-RevId: 306722028
Change-Id: I680498e36e477d06521a7b2af267c2d7852bfafd
This commit is contained in:
Cesar Crusius 2020-04-15 14:33:06 -07:00 committed by TensorFlower Gardener
parent 75aa2fa3f6
commit ed7d6f617d
7 changed files with 133 additions and 43 deletions

View File

@ -42,7 +42,28 @@ class AbstractOperationInterface {
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
// Returns the operation's device name.
//
// The value returned may be different from the one set by SetDeviceName, but
// it will be compatible with it: the name will be updated by device placement
// logic to refer to the specific device chosen.
//
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
// chosen for the operation by the device placement logic in the
// executor. After that, the value returned by DeviceName will be a full
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
virtual const string& DeviceName() const = 0;
// Sets the operation device name.
//
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
// the result will be used as a constraint for device placement. See the
// documentation for DeviceName for more details.
//
// The value will override the previous value - that is, no "merging" of
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;

View File

@ -151,6 +151,18 @@ tf_cuda_library(
}),
)
tf_cc_test(
name = "eager_operation_test",
srcs = ["eager_operation_test.cc"],
deps = [
":core",
":eager_operation",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_library(
name = "tensor_handle_data",
srcs = [

View File

@ -36,13 +36,6 @@ void EagerOperation::Clear() {
ClearInferenceState();
}
const string& EagerOperation::DeviceName() const {
VariantDevice variant_device =
(Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);
}
Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
size_t length) {
MutableAttrs()->Set(attr_name, StringPiece(data, length));
@ -311,15 +304,7 @@ Status EagerOperation::Reset(
executor_ = executor ? executor : &ctx_.Executor();
remote_func_params_ = remote_func_params;
op_name_ = op;
if (device_name != nullptr && strlen(device_name) > 0) {
return SetDeviceName(device_name);
} else {
last_set_device_name_.clear();
device_name_.clear();
device_parsed_name_.Clear();
device_ = kVariantDeviceNull;
return Status::OK();
}
}
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
@ -389,8 +374,8 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
return Status::OK();
}
Status EagerOperation::SetDeviceName(const char* name) {
if (name != nullptr && strlen(name) > 0) {
Status EagerOperation::SetDeviceName(const char* c_name) {
string name(c_name != nullptr ? c_name : "");
if (name != last_set_device_name_) {
if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) {
return errors::InvalidArgument("Malformed device specification '", name,
@ -407,7 +392,6 @@ Status EagerOperation::SetDeviceName(const char* name) {
device_ = kVariantDeviceNull;
}
}
}
return Status::OK();
}

View File

@ -48,28 +48,39 @@ class EagerOperation : public AbstractOperationInterface {
}
const string& Name() const override { return attrs_.op_name(); }
const string& DeviceName() const override;
const string& GetDeviceName() const { return device_name_; }
const string& DeviceName() const override { return device_name_; }
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
return device_parsed_name_;
}
// Replaces the previous device name with the given one (see
// AbstractOperationInterface::SetDeviceName for more details).
//
// This also resets the internal device pointer, unless the given name refers
// to a known custom device, in which case the internal device pointer is
// updated to that device.
Status SetDeviceName(const char* name) override;
void SetDevice(tensorflow::Device* device) {
device_ = device;
last_set_device_name_.clear();
device_name_ = device->name();
device_parsed_name_ = device->parsed_name();
// TODO(b/154133594): Due to intricacies of external logic, we can not
// set this do device_name_ as it would be natural, because we need the
// next call to SetDeviceName to reset the device pointer.
last_set_device_name_ = "\177"; // DEL (an invalid value)
}
void SetDevice(tensorflow::CustomDevice* device) {
device_ = device;
last_set_device_name_.clear();
device_name_ = device->name();
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
// TODO(b/154133594): Due to intricacies of external logic, we can not
// set this do device_name_ as it would be natural, because we need the
// next call to SetDeviceName to reset the device pointer.
last_set_device_name_ = "\177"; // DEL (an invalid value)
}
Status AddInput(AbstractTensorHandleInterface* input) override;
@ -191,10 +202,20 @@ class EagerOperation : public AbstractOperationInterface {
// calls to SetDeviceName.
string last_set_device_name_;
// The operation's device name.
// This contains the named passed to SetDeviceName until device_ is set,
// at which point it contains the device_ name.
string device_name_;
// The parsed device name.
// This will always contain the result of
// DeviceNameUtils::ParseFullName(device_name_).
DeviceNameUtils::ParsedName device_parsed_name_;
// The operation's device.
// This is set by the execution device placement logic, and should conform
// with the contents of device_name_. Once it is set, the device_name_ is
// updated accordingly.
VariantDevice device_;
bool use_xla_ = false;

View File

@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(EagerOperationTest, DeviceName) {
StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
auto ctx = new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
&device_mgr, false, nullptr, nullptr, nullptr);
auto op = new EagerOperation(ctx);
TF_ASSERT_OK(op->SetDeviceName("/device:DONTHAVE"));
EXPECT_EQ("/device:DONTHAVE:*", op->DeviceName());
TF_ASSERT_OK(op->SetDeviceName(""));
EXPECT_EQ("", op->DeviceName());
TF_ASSERT_OK(op->SetDeviceName("/job:localhost"));
EXPECT_EQ("/job:localhost", op->DeviceName());
EXPECT_NE(Status::OK(), op->SetDeviceName("/not/a/valid/name"));
delete op;
ctx->Unref();
}
} // namespace
} // namespace tensorflow

View File

@ -385,7 +385,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
TF_RETURN_IF_ERROR(executor.status());
Device* device = absl::get<Device*>(op->Device());
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->GetDeviceName());
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
std::vector<Device*> input_dev_ptrs;
std::unordered_map<int, DtypeAndPartialTensorShape>
@ -677,7 +677,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// TODO(fishx): Remove following code when lazy tensor copy is ready.
if (op->Device() == kVariantDeviceNull) {
tensorflow::Device* device = nullptr;
string device_name = op->GetDeviceName();
string device_name = op->DeviceName();
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
op->SetDevice(device);
}
@ -886,9 +886,9 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
// TODO(b/145922293): Support allowed_devices specified in wildcard
// patterns.
if (std::find(allowed_devices.begin(), allowed_devices.end(),
op->GetDeviceName()) != allowed_devices.end()) {
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
op->GetDeviceName().c_str(), &resource_device));
op->DeviceName()) != allowed_devices.end()) {
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(),
&resource_device));
}
}
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")

View File

@ -250,7 +250,7 @@ class OpNode {
// Precalculating a cache key saves about 10% of inference time for very
// small models.
op_->MutableAttrs()->CacheKey(op_->GetDeviceName());
op_->MutableAttrs()->CacheKey(op_->DeviceName());
return tensorflow::Status::OK();
}