diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc index 6448fc56af7..21c75244b5f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc @@ -230,9 +230,9 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimitAndNoPriority) { TEST_F(GPUDeviceTest, SingleVirtualDeviceWithInvalidPriority) { { #if TENSORFLOW_USE_ROCM - // Priority outside the range (0, 2) for AMD GPUs + // Priority outside the range (-1, 1) for AMD GPUs SessionOptions opts = - MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 2}}); + MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-2, 1}}); #else // Priority outside the range (-2, 0) for NVidia GPUs SessionOptions opts = @@ -245,7 +245,7 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithInvalidPriority) { #if TENSORFLOW_USE_ROCM ExpectErrorMessageSubstr( status, - "Priority -1 is outside the range of supported priorities [0,2] for" + "Priority -2 is outside the range of supported priorities [-1,1] for" " virtual device 0 on GPU# 0"); #else ExpectErrorMessageSubstr( @@ -254,8 +254,8 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithInvalidPriority) { } { #if TENSORFLOW_USE_ROCM - // Priority outside the range (0, 2) for AMD GPUs - SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 3}}); + // Priority outside the range (-1, 1) for AMD GPUs + SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 2}}); #else // Priority outside the range (-2, 0) for NVidia GPUs SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 1}}); @@ -267,7 +267,7 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithInvalidPriority) { #if TENSORFLOW_USE_ROCM ExpectErrorMessageSubstr( status, - "Priority 3 is outside the range of supported priorities [0,2] for" + "Priority 2 is outside the range of supported priorities [-1,1] for" " virtual device 0 on GPU# 0"); #else ExpectErrorMessageSubstr( @@ -288,26 +288,17 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimitAndPriority) { } TEST_F(GPUDeviceTest, MultipleVirtualDevices) { -#if TENSORFLOW_USE_ROCM - // Valid range for priority values on AMD GPUs in (0,2) - SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 1}}); -#else + // Valid range for priority values on AMD GPUs in (-1,1) // Valid range for priority values on NVidia GPUs in (-2, 0) SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, -1}}); -#endif std::vector> devices; TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( opts, kDeviceNamePrefix, &devices)); EXPECT_EQ(2, devices.size()); EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit()); EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit()); -#if TENSORFLOW_USE_ROCM - EXPECT_EQ(0, static_cast(devices[0].get())->priority()); - EXPECT_EQ(1, static_cast(devices[1].get())->priority()); -#else EXPECT_EQ(0, static_cast(devices[0].get())->priority()); EXPECT_EQ(-1, static_cast(devices[1].get())->priority()); -#endif ASSERT_EQ(1, devices[0]->attributes().locality().links().link_size()); ASSERT_EQ(1, devices[1]->attributes().locality().links().link_size()); EXPECT_EQ(1, devices[0]->attributes().locality().links().link(0).device_id()); @@ -339,27 +330,18 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevicesWithPriority) { } { // Multile virtual devices with matching priority. -#if TENSORFLOW_USE_ROCM - // Valid range for priority values on AMD GPUs in (0,2) - SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{2, 1}}); -#else + // Valid range for priority values on AMD GPUs in (-1,1) // Valid range for priority values on NVidia GPUs in (-2, 0) SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 0}}); -#endif std::vector> devices; TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( opts, kDeviceNamePrefix, &devices)); EXPECT_EQ(2, devices.size()); EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit()); EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit()); -#if TENSORFLOW_USE_ROCM - EXPECT_EQ(2, static_cast(devices[0].get())->priority()); - EXPECT_EQ(1, static_cast(devices[1].get())->priority()); -#else EXPECT_EQ(-1, static_cast(devices[0].get())->priority()); EXPECT_EQ(0, static_cast(devices[1].get())->priority()); -#endif } }