From ed7d6f617dd0e8362354b5b618a5d03514ecb337 Mon Sep 17 00:00:00 2001 From: Cesar Crusius Date: Wed, 15 Apr 2020 14:33:06 -0700 Subject: [PATCH] Unify device names in EagerOperation. Change DeviceName() so it returns what is set by SetDeviceName(). Update documentation and code accordingly. The GetDeviceName() function became redundant and was removed. PiperOrigin-RevId: 306722028 Change-Id: I680498e36e477d06521a7b2af267c2d7852bfafd --- tensorflow/c/eager/operation_interface.h | 21 ++++++++ tensorflow/core/common_runtime/eager/BUILD | 12 +++++ .../common_runtime/eager/eager_operation.cc | 50 ++++++------------ .../common_runtime/eager/eager_operation.h | 29 +++++++++-- .../eager/eager_operation_test.cc | 52 +++++++++++++++++++ .../core/common_runtime/eager/execute.cc | 10 ++-- tensorflow/lite/delegates/flex/kernel.cc | 2 +- 7 files changed, 133 insertions(+), 43 deletions(-) create mode 100644 tensorflow/core/common_runtime/eager/eager_operation_test.cc diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h index 4651d45ec04..844ba6c14bd 100644 --- a/tensorflow/c/eager/operation_interface.h +++ b/tensorflow/c/eager/operation_interface.h @@ -42,7 +42,28 @@ class AbstractOperationInterface { virtual Status Reset(const char* op, const char* raw_device_name) = 0; virtual const string& Name() const = 0; + + // Returns the operation's device name. + // + // The value returned may be different from the one set by SetDeviceName, but + // it will be compatible with it: the name will be updated by device placement + // logic to refer to the specific device chosen. + // + // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value + // returned by DeviceName should be "/device:GPU:*" until a particular GPU is + // chosen for the operation by the device placement logic in the + // executor. After that, the value returned by DeviceName will be a full + // device name such as "/job:localhost/replica:0/task:0/device:GPU:1". 
virtual const string& DeviceName() const = 0; + + // Sets the operation device name. + // + // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and + // the result will be used as a constraint for device placement. See the + // documentation for DeviceName for more details. + // + // The value will override the previous value - that is, no "merging" of + // existing and given constraints will be performed. virtual Status SetDeviceName(const char* name) = 0; virtual Status AddInput(AbstractTensorHandleInterface* input) = 0; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 328f29b2cb4..c57e1edb283 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -151,6 +151,18 @@ tf_cuda_library( }), ) +tf_cc_test( + name = "eager_operation_test", + srcs = ["eager_operation_test.cc"], + deps = [ + ":core", + ":eager_operation", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_library( name = "tensor_handle_data", srcs = [ diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index e3a455e7e0d..ec8cc412658 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -36,13 +36,6 @@ void EagerOperation::Clear() { ClearInferenceState(); } -const string& EagerOperation::DeviceName() const { - VariantDevice variant_device = - (Device() == kVariantDeviceNull) ? EagerContext().HostCPU() : Device(); - return absl::visit([](auto* d) -> const string& { return d->name(); }, - variant_device); -} - Status EagerOperation::SetAttrString(const char* attr_name, const char* data, size_t length) { MutableAttrs()->Set(attr_name, StringPiece(data, length)); @@ -311,15 +304,7 @@ Status EagerOperation::Reset( executor_ = executor ? 
executor : &ctx_.Executor(); remote_func_params_ = remote_func_params; op_name_ = op; - if (device_name != nullptr && strlen(device_name) > 0) { - return SetDeviceName(device_name); - } else { - last_set_device_name_.clear(); - device_name_.clear(); - device_parsed_name_.Clear(); - device_ = kVariantDeviceNull; - return Status::OK(); - } + return SetDeviceName(device_name); } Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) { @@ -389,23 +374,22 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { return Status::OK(); } -Status EagerOperation::SetDeviceName(const char* name) { - if (name != nullptr && strlen(name) > 0) { - if (name != last_set_device_name_) { - if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) { - return errors::InvalidArgument("Malformed device specification '", name, - "' in eager op: ", DebugString()); - } - last_set_device_name_ = name; - device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_); - CustomDevice* custom_device; - if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device).ok()) { - device_ = custom_device; - } else { - // Device placement for physical devices happens lazily in - // EagerExecute/EagerRemoteExecute, and can depend on the inputs. - device_ = kVariantDeviceNull; - } +Status EagerOperation::SetDeviceName(const char* c_name) { + string name(c_name != nullptr ? 
c_name : ""); + if (name != last_set_device_name_) { + if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) { + return errors::InvalidArgument("Malformed device specification '", name, + "' in eager op: ", DebugString()); + } + last_set_device_name_ = name; + device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_); + CustomDevice* custom_device; + if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device).ok()) { + device_ = custom_device; + } else { + // Device placement for physical devices happens lazily in + // EagerExecute/EagerRemoteExecute, and can depend on the inputs. + device_ = kVariantDeviceNull; } } return Status::OK(); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 44dce9dc057..5a2a437dac8 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -48,28 +48,39 @@ class EagerOperation : public AbstractOperationInterface { } const string& Name() const override { return attrs_.op_name(); } - const string& DeviceName() const override; - const string& GetDeviceName() const { return device_name_; } + const string& DeviceName() const override { return device_name_; } const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { return device_parsed_name_; } + // Replaces the previous device name with the given one (see + // AbstractOperationInterface::SetDeviceName for more details). + // + // This also resets the internal device pointer, unless the given name refers + // to a known custom device, in which case the internal device pointer is + // updated to that device. 
Status SetDeviceName(const char* name) override; void SetDevice(tensorflow::Device* device) { device_ = device; - last_set_device_name_.clear(); device_name_ = device->name(); device_parsed_name_ = device->parsed_name(); + // TODO(b/154133594): Due to intricacies of external logic, we can not + // set this to device_name_ as it would be natural, because we need the + // next call to SetDeviceName to reset the device pointer. + last_set_device_name_ = "\177"; // DEL (an invalid value) } void SetDevice(tensorflow::CustomDevice* device) { device_ = device; - last_set_device_name_.clear(); device_name_ = device->name(); DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_); + // TODO(b/154133594): Due to intricacies of external logic, we can not + // set this to device_name_ as it would be natural, because we need the + // next call to SetDeviceName to reset the device pointer. + last_set_device_name_ = "\177"; // DEL (an invalid value) } Status AddInput(AbstractTensorHandleInterface* input) override; @@ -191,10 +202,20 @@ class EagerOperation : public AbstractOperationInterface { // calls to SetDeviceName. string last_set_device_name_; + // The operation's device name. + // This contains the name passed to SetDeviceName until device_ is set, + // at which point it contains the device_ name. string device_name_; + // The parsed device name. + // This will always contain the result of + // DeviceNameUtils::ParseFullName(device_name_). DeviceNameUtils::ParsedName device_parsed_name_; + // The operation's device. + // This is set by the execution device placement logic, and should conform + // with the contents of device_name_. Once it is set, the device_name_ is + // updated accordingly. 
VariantDevice device_; bool use_xla_ = false; diff --git a/tensorflow/core/common_runtime/eager/eager_operation_test.cc b/tensorflow/core/common_runtime/eager/eager_operation_test.cc new file mode 100644 index 00000000000..352c7f03365 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/eager_operation_test.cc @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/eager/eager_operation.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(EagerOperationTest, DeviceName) { + StaticDeviceMgr device_mgr(DeviceFactory::NewDevice( + "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0")); + auto ctx = new EagerContext( + SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false, + &device_mgr, false, nullptr, nullptr, nullptr); + + auto op = new EagerOperation(ctx); + + TF_ASSERT_OK(op->SetDeviceName("/device:DONTHAVE")); + EXPECT_EQ("/device:DONTHAVE:*", op->DeviceName()); + + TF_ASSERT_OK(op->SetDeviceName("")); + EXPECT_EQ("", op->DeviceName()); + + TF_ASSERT_OK(op->SetDeviceName("/job:localhost")); + EXPECT_EQ("/job:localhost", op->DeviceName()); + + 
EXPECT_NE(Status::OK(), op->SetDeviceName("/not/a/valid/name")); + + delete op; + ctx->Unref(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index cb743fe81ba..33f4e52e095 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -385,7 +385,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, TF_RETURN_IF_ERROR(executor.status()); Device* device = absl::get(op->Device()); - Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->GetDeviceName()); + Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName()); std::vector input_dev_ptrs; std::unordered_map @@ -677,7 +677,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, // TODO(fishx): Remove following code when lazy tensor copy is ready. if (op->Device() == kVariantDeviceNull) { tensorflow::Device* device = nullptr; - string device_name = op->GetDeviceName(); + string device_name = op->DeviceName(); TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device)); op->SetDevice(device); } @@ -886,9 +886,9 @@ Status MaybeUpdateOpDevice(EagerOperation* op) { // TODO(b/145922293): Support allowed_devices specified in wildcard // patterns. if (std::find(allowed_devices.begin(), allowed_devices.end(), - op->GetDeviceName()) != allowed_devices.end()) { - TF_RETURN_IF_ERROR(ctx.FindDeviceFromName( - op->GetDeviceName().c_str(), &resource_device)); + op->DeviceName()) != allowed_devices.end()) { + TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(op->DeviceName().c_str(), + &resource_device)); } } DVLOG(1) << (resource_device != op_device ? 
"Changing " : "Setting ") diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 04d9eca597e..d1c21086703 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -250,7 +250,7 @@ class OpNode { // Precalculating a cache key saves about 10% of inference time for very // small models. - op_->MutableAttrs()->CacheKey(op_->GetDeviceName()); + op_->MutableAttrs()->CacheKey(op_->DeviceName()); return tensorflow::Status::OK(); }