diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index d9fe36fa8f3..c32c1a81ea9 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -81,6 +81,17 @@ tf_cuda_library(
     }),
 )
 
+tf_cc_test(
+    name = "context_test",
+    srcs = ["context_test.cc"],
+    deps = [
+        ":context",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cuda_library(
     name = "eager_operation",
     srcs = [
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5e6308aac11..95032babff2 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -168,12 +168,17 @@ std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
 }
 }  // namespace
 
-Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred,
+Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
                                   const PrioritizedDeviceTypeVector& supported,
-                                  Device** device) const {
+                                  const DataType dtype, Device** device) const {
   std::vector<Device*> selected;
   const DeviceSet& pflr_devices = *pflr()->device_set();
 
+  // We always place string tensors on the CPU device if we're allowed to.
+  if (dtype == DT_STRING && AllowSoftPlacement()) {
+    preferred = HostCPU()->parsed_name();
+  }
+
   // If there are no preferred devices, select the first registered device from
   // the supported device list.
   if (!DeviceNameUtils::HasSomeDetails(preferred)) {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 44f287d0ad8..98c54035234 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -167,11 +167,15 @@ class EagerContext : public core::RefCounted {
   // devices, and the context currently allows soft device placement, a suitable
   // device not matching `preferred` will be chosen.
   //
+  // The `dtype` parameter specifies the operation's result data type, if
+  // known. Passing DT_INVALID makes this method ignore the data type when
+  // choosing a device.
+  //
   // The chosen device is stored in the `device` argument. The argument is not
   // modified unless this method returns `Status::OK()`.
-  Status SelectDevice(const DeviceNameUtils::ParsedName& preferred,
+  Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                       const PrioritizedDeviceTypeVector& supported,
-                      Device** device) const;
+                      const DataType dtype, Device** device) const;
 
   // Sets the implicit copy policy for the current thread.
   void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc
new file mode 100644
index 00000000000..9581f149e4a
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context_test.cc
@@ -0,0 +1,174 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// Return a fake device.
+static Device* CreateDevice(const string& type, int n) {
+  class FakeDevice : public Device {
+   public:
+    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+    Status Sync() override { return Status::OK(); }
+    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+  };
+  DeviceAttributes attr;
+  attr.set_name("/job:a/replica:0/task:0/device:" + type + ":" +
+                std::to_string(n));
+  attr.set_device_type(type);
+  return new FakeDevice(attr);
+}
+
+class EagerContextTest : public ::testing::Test {
+ public:
+  EagerContextTest() : device_manager_(nullptr), context_(nullptr) {}
+
+  ~EagerContextTest() override {
+    delete device_manager_;
+    if (context_) {
+      context_->Unref();
+    }
+  }
+
+  EagerContext* context() { return context_; }
+
+  void InitContext(const SessionOptions& opts,
+                   ContextDevicePlacementPolicy policy) {
+    ASSERT_EQ(context_, nullptr);
+    InitDeviceManager();
+    context_ = new EagerContext(
+        opts, policy,
+        /* default_mirroring_policy */ MIRRORING_NONE,
+        /* async */ false,
+        /* lazy_copy_function_remote_inputs */ false, device_manager_,
+        /* device_mgr_owned */ false, /* rendezvous */ nullptr,
+        /* custom_kernel_creator */ nullptr,
+        /* cluster_flr */ nullptr);
+  }
+
+ protected:
+  void InitDeviceManager() {
+    ASSERT_EQ(device_manager_, nullptr);
+    device_manager_ = new DynamicDeviceMgr();
+    std::vector<std::unique_ptr<Device>> added_devices;
+    added_devices.emplace_back(CreateDevice(DEVICE_CPU, 0));
+    added_devices.emplace_back(CreateDevice(DEVICE_CPU, 1));
+    added_devices.emplace_back(CreateDevice(DEVICE_GPU, 0));
+    added_devices.emplace_back(CreateDevice(DEVICE_GPU, 1));
+
+    TF_CHECK_OK(device_manager_->AddDevices(std::move(added_devices)));
+  }
+
+  DynamicDeviceMgr* device_manager_;
+  EagerContext* context_;
+};
+
+TEST_F(EagerContextTest, SelectDeviceExplicitHardPlacement) {
+  SessionOptions options;
+  options.config.set_log_device_placement(true);
+  options.config.set_allow_soft_placement(false);
+  InitContext(options, DEVICE_PLACEMENT_EXPLICIT);
+
+  Device* dev;
+  DeviceNameUtils::ParsedName requested;
+  const PrioritizedDeviceTypeVector supported{
+      std::make_pair(DeviceType(DEVICE_GPU), 20),
+      std::make_pair(DeviceType(DEVICE_CPU), 10),
+  };
+
+  // No supported devices should result in an error.
+  requested.Clear();
+  Status status = context()->SelectDevice(
+      requested, PrioritizedDeviceTypeVector{}, DT_INVALID, &dev);
+  EXPECT_TRUE(errors::IsInvalidArgument(status));
+  EXPECT_TRUE(
+      absl::StrContains(status.error_message(), "No supported device found"))
+      << "unexpected error message " << status.error_message();
+
+  // An invalid requested device should also cause an error.
+  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("GPU:99", &requested));
+  status = context()->SelectDevice(requested, supported, DT_INVALID, &dev);
+  EXPECT_TRUE(errors::IsInvalidArgument(status));
+  EXPECT_TRUE(absl::StrContains(status.error_message(),
+                                "Could not satisfy device specification"))
+      << "unexpected error message " << status.error_message();
+
+  // Should pick the "best" supported device if given no constraints.
+  requested.Clear();
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_GPU);
+
+  // Should pick a CPU if asked to.
+  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("CPU:1", &requested));
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_CPU);
+
+  // String tensors stay on the GPU under hard device placement.
+  requested.Clear();
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_STRING, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_GPU);
+}
+
+TEST_F(EagerContextTest, SelectDeviceExplicitSoftPlacement) {
+  SessionOptions options;
+  options.config.set_log_device_placement(true);
+  options.config.set_allow_soft_placement(true);
+  InitContext(options, DEVICE_PLACEMENT_EXPLICIT);
+
+  Device* dev;
+  DeviceNameUtils::ParsedName requested;
+  const PrioritizedDeviceTypeVector supported{
+      std::make_pair(DeviceType(DEVICE_GPU), 20),
+      std::make_pair(DeviceType(DEVICE_CPU), 10),
+  };
+
+  // No supported devices should result in an error.
+  requested.Clear();
+  Status status = context()->SelectDevice(
+      requested, PrioritizedDeviceTypeVector{}, DT_INVALID, &dev);
+  EXPECT_TRUE(errors::IsInvalidArgument(status));
+  EXPECT_TRUE(
+      absl::StrContains(status.error_message(), "No supported device found"))
+      << "unexpected error message " << status.error_message();
+
+  // An invalid requested device should be replaced by the "best" one.
+  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("GPU:99", &requested));
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_GPU);
+
+  // Should pick the "best" supported device if given no constraints.
+  requested.Clear();
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_GPU);
+
+  // Should pick a CPU if asked to.
+  ASSERT_TRUE(DeviceNameUtils::ParseLocalName("CPU:1", &requested));
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_INVALID, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_CPU);
+
+  // String tensors move to the CPU under soft device placement.
+  requested.Clear();
+  TF_ASSERT_OK(context()->SelectDevice(requested, supported, DT_STRING, &dev));
+  EXPECT_EQ(dev->device_type(), DEVICE_CPU);
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index c81945f7ef0..e6861a41447 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -458,8 +458,8 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
         "\nAll kernels registered for op ", ndef.op(), " :\n",
         KernelsRegisteredForOp(ndef.op()));
   }
-  TF_RETURN_IF_ERROR(
-      ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device));
+  TF_RETURN_IF_ERROR(ctx.SelectDevice(op->GetDeviceParsedName(),
+                                      supported_devs, DT_INVALID, &device));
 
   DVLOG(1) << "Placer place op [" << op->Name()
            << "] on device: " << device->name();
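
The `dtype` argument introduced by this patch only affects placement in one case: when the operation produces a DT_STRING result and soft placement is enabled, SelectDevice overrides the preferred device with the host CPU before its normal matching runs; otherwise the requested placement is honored or rejected exactly as before (the execute.cc call site still passes DT_INVALID). The following is a minimal, self-contained C++ sketch of that decision rule; PlacementContext, Dtype, and ChoosePreferredDevice are simplified stand-in names for illustration, not TensorFlow types or APIs.

#include <iostream>
#include <string>

// Simplified stand-ins; the real code works with DeviceNameUtils::ParsedName,
// DataType, and EagerContext state.
enum class Dtype { kInvalid, kFloat, kString };

struct PlacementContext {
  bool allow_soft_placement;
  std::string host_cpu;
};

// Mirrors the early override added to EagerContext::SelectDevice: string
// results are redirected to the host CPU, but only under soft placement.
std::string ChoosePreferredDevice(const PlacementContext& ctx, Dtype dtype,
                                  std::string preferred) {
  if (dtype == Dtype::kString && ctx.allow_soft_placement) {
    return ctx.host_cpu;
  }
  return preferred;  // Hard placement and non-string dtypes keep the request.
}

int main() {
  const PlacementContext soft{/*allow_soft_placement=*/true, "/device:CPU:0"};
  const PlacementContext hard{/*allow_soft_placement=*/false, "/device:CPU:0"};
  // Prints /device:CPU:0: strings are rerouted when soft placement is on.
  std::cout << ChoosePreferredDevice(soft, Dtype::kString, "/device:GPU:0")
            << "\n";
  // Prints /device:GPU:0: hard placement leaves the request untouched.
  std::cout << ChoosePreferredDevice(hard, Dtype::kString, "/device:GPU:0")
            << "\n";
  return 0;
}

This is the same split the two new tests exercise: SelectDeviceExplicitHardPlacement expects a DT_STRING result to stay on the GPU, while SelectDeviceExplicitSoftPlacement expects it to land on the CPU.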