Enabling //tensorflow/core/framework:variant_op_registry_test for ROCm
This commit is contained in:
parent
057cf24986
commit
c1b02f9a64
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -215,7 +215,7 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
|
||||
EXPECT_EQ(vv_out->value, 1); // CPU
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
|
||||
class Blah {};
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||
@ -239,7 +239,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
|
||||
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
||||
EXPECT_EQ(vv_out->value, 2); // GPU
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
|
||||
UnaryVariantOpRegistry registry;
|
||||
@ -286,7 +286,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
|
||||
EXPECT_EQ(vv_out->value, 7); // CPU
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
|
||||
class Blah {};
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
|
||||
@ -312,7 +312,7 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
|
||||
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
||||
EXPECT_EQ(vv_out->value, -7); // GPU
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
TEST(VariantOpAddRegistryTest, TestDuplicate) {
|
||||
UnaryVariantOpRegistry registry;
|
||||
|
Loading…
Reference in New Issue
Block a user