Introduce a data type parameter to EagerContext::SelectDevice.
This is in preparation for fixing a couple of bugs related to ConvertToEagerTensor, one of them about device placement of string tensors. The final fix will have ConvertToEagerTensorUncached calling SelectDevice, which will then need the data type parameter introduced in this change. Since the method is currently called in only one place, and that place has a DT_INVALID hardcoded parameter value, the change has no effect until the changes in ConvertToEagerTensorUncached come in. I created a context_test anyway to make sure the method is doing what it is supposed to do. This turned out to be valuable enough to justify this otherwise small change. PiperOrigin-RevId: 292934896 Change-Id: I4aa1e91904fb14b3e60f6f1b86ea699dd66c03dd
This commit is contained in:
parent
59145d293a
commit
123db4534b
@ -81,6 +81,17 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
# Unit tests for EagerContext (device selection via EagerContext::SelectDevice).
tf_cc_test(
    name = "context_test",
    srcs = ["context_test.cc"],
    deps = [
        ":context",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        # context_test.cc calls absl::StrContains directly, so depend on the
        # Abseil strings library explicitly instead of relying on a transitive
        # dependency.
        "@com_google_absl//absl/strings",
    ],
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "eager_operation",
|
||||
srcs = [
|
||||
|
@ -168,12 +168,17 @@ std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred,
|
||||
Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
|
||||
const PrioritizedDeviceTypeVector& supported,
|
||||
Device** device) const {
|
||||
const DataType dtype, Device** device) const {
|
||||
std::vector<Device*> selected;
|
||||
const DeviceSet& pflr_devices = *pflr()->device_set();
|
||||
|
||||
// We always place string tensors on the CPU device if we're allowed to.
|
||||
if (dtype == DT_STRING && AllowSoftPlacement()) {
|
||||
preferred = HostCPU()->parsed_name();
|
||||
}
|
||||
|
||||
// If there are no preferred devices, select the first registered device from
|
||||
// the supported device list.
|
||||
if (!DeviceNameUtils::HasSomeDetails(preferred)) {
|
||||
|
@ -167,11 +167,15 @@ class EagerContext : public core::RefCounted {
|
||||
// devices, and the context currently allows soft device placement, a suitable
|
||||
// device not matching `preferred` will be chosen.
|
||||
//
|
||||
// The `dtype` parameter specifies the operation's result data type, if
|
||||
// known. Setting it to DT_INVALID will make this method not use the data type
|
||||
// for its decisions.
|
||||
//
|
||||
// The chosen device is stored in the `device` argument. The argument is not
|
||||
// modified unless this method returns `Status::OK()`.
|
||||
Status SelectDevice(const DeviceNameUtils::ParsedName& preferred,
|
||||
Status SelectDevice(DeviceNameUtils::ParsedName preferred,
|
||||
const PrioritizedDeviceTypeVector& supported,
|
||||
Device** device) const;
|
||||
const DataType dtype, Device** device) const;
|
||||
|
||||
// Sets the implicit copy policy for the current thread.
|
||||
void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
|
||||
|
174
tensorflow/core/common_runtime/eager/context_test.cc
Normal file
174
tensorflow/core/common_runtime/eager/context_test.cc
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Return a fake device.
|
||||
static Device* CreateDevice(const string& type, int n) {
|
||||
class FakeDevice : public Device {
|
||||
public:
|
||||
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||
Status Sync() override { return Status::OK(); }
|
||||
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||
};
|
||||
DeviceAttributes attr;
|
||||
attr.set_name("/job:a/replica:0/task:0/device:" + type + ":" +
|
||||
std::to_string(n));
|
||||
attr.set_device_type(type);
|
||||
return new FakeDevice(attr);
|
||||
}
|
||||
|
||||
class EagerContextTest : public ::testing::Test {
|
||||
public:
|
||||
EagerContextTest() : device_manager_(nullptr), context_(nullptr) {}
|
||||
|
||||
~EagerContextTest() override {
|
||||
delete device_manager_;
|
||||
if (context_) {
|
||||
context_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
EagerContext* context() { return context_; }
|
||||
|
||||
void InitContext(const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy policy) {
|
||||
ASSERT_EQ(context_, nullptr);
|
||||
InitDeviceManager();
|
||||
context_ = new EagerContext(
|
||||
opts, policy,
|
||||
/* default_mirroring_policy */ MIRRORING_NONE,
|
||||
/* async */ false,
|
||||
/* lazy_copy_function_remote_inputs */ false, device_manager_,
|
||||
/* device_mgr_owned */ false, /* rendezvous */ nullptr,
|
||||
/* custom_kernel_creator */ nullptr,
|
||||
/* cluster_flr */ nullptr);
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitDeviceManager() {
|
||||
ASSERT_EQ(device_manager_, nullptr);
|
||||
device_manager_ = new DynamicDeviceMgr();
|
||||
std::vector<std::unique_ptr<Device>> added_devices;
|
||||
added_devices.emplace_back(CreateDevice(DEVICE_CPU, 0));
|
||||
added_devices.emplace_back(CreateDevice(DEVICE_CPU, 1));
|
||||
added_devices.emplace_back(CreateDevice(DEVICE_GPU, 0));
|
||||
added_devices.emplace_back(CreateDevice(DEVICE_GPU, 1));
|
||||
|
||||
TF_CHECK_OK(device_manager_->AddDevices(std::move(added_devices)));
|
||||
}
|
||||
|
||||
DynamicDeviceMgr* device_manager_;
|
||||
EagerContext* context_;
|
||||
};
|
||||
|
||||
// Exercises EagerContext::SelectDevice with soft placement disabled: an
// unsatisfiable request is an error, and a DT_STRING result type does not
// override the normal device choice.
TEST_F(EagerContextTest, SelectDeviceExplicitHardPlacement) {
  SessionOptions options;
  options.config.set_log_device_placement(true);
  options.config.set_allow_soft_placement(false);
  InitContext(options, DEVICE_PLACEMENT_EXPLICIT);

  // Initialized so that the pointer never holds an indeterminate value, even
  // on the paths where SelectDevice fails and leaves it untouched.
  Device* dev = nullptr;
  DeviceNameUtils::ParsedName requested;
  // GPU is listed with the higher priority value, so it is the "best" type.
  const PrioritizedDeviceTypeVector supported{
      std::make_pair(DeviceType(DEVICE_GPU), 20),
      std::make_pair(DeviceType(DEVICE_CPU), 10),
  };

  // No supported devices should result in an error.
  requested.Clear();
  Status status = context()->SelectDevice(
      requested, PrioritizedDeviceTypeVector{}, DT_INVALID, &dev);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "No supported device found"))
      << "unexpected error message " << status.error_message();

  // An invalid requested device should also cause an error.
  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("GPU:99", &requested));
  status = context()->SelectDevice(requested, supported, DT_INVALID, &dev);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
  EXPECT_TRUE(absl::StrContains(status.error_message(),
                                "Could not satisfy device specification"))
      << "unexpected error message " << status.error_message();

  // Should pick the "best" supported device if given no constraints.
  requested.Clear();
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_GPU);

  // Should pick a CPU if asked to.
  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("CPU:1", &requested));
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_CPU);

  // String tensors stay in GPU under hard device placement.
  requested.Clear();
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_STRING, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_GPU);
}
|
||||
|
||||
// Exercises EagerContext::SelectDevice with soft placement enabled: an
// unsatisfiable device request falls back to the best supported device, and a
// DT_STRING result type is placed on the CPU.
TEST_F(EagerContextTest, SelectDeviceExplicitSoftPlacement) {
  SessionOptions options;
  options.config.set_log_device_placement(true);
  options.config.set_allow_soft_placement(true);
  InitContext(options, DEVICE_PLACEMENT_EXPLICIT);

  // Initialized so that the pointer never holds an indeterminate value, even
  // on the path where SelectDevice fails and leaves it untouched.
  Device* dev = nullptr;
  DeviceNameUtils::ParsedName requested;
  // GPU is listed with the higher priority value, so it is the "best" type.
  const PrioritizedDeviceTypeVector supported{
      std::make_pair(DeviceType(DEVICE_GPU), 20),
      std::make_pair(DeviceType(DEVICE_CPU), 10),
  };

  // No supported devices should result in an error.
  requested.Clear();
  Status status = context()->SelectDevice(
      requested, PrioritizedDeviceTypeVector{}, DT_INVALID, &dev);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "No supported device found"))
      << "unexpected error message " << status.error_message();

  // An invalid requested device should be replaced by the "best" one.
  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("GPU:99", &requested));
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_GPU);

  // Should pick the "best" supported device if given no constraints.
  requested.Clear();
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_GPU);

  // Should pick a CPU if asked to.
  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("CPU:1", &requested));
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_CPU);

  // String tensors move to CPU under soft device placement.
  requested.Clear();
  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_STRING, &dev));
  EXPECT_EQ(dev->device_type(), DEVICE_CPU);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -458,8 +458,8 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
"\nAll kernels registered for op ", ndef.op(),
|
||||
" :\n", KernelsRegisteredForOp(ndef.op()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device));
|
||||
TF_RETURN_IF_ERROR(ctx.SelectDevice(op->GetDeviceParsedName(),
|
||||
supported_devs, DT_INVALID, &device));
|
||||
|
||||
DVLOG(1) << "Placer place op [" << op->Name()
|
||||
<< "] on device: " << device->name();
|
||||
|
Loading…
Reference in New Issue
Block a user