Enabling //tensorflow/core/framework:variant_op_registry_test for ROCm

This commit is contained in:
Eugene Kuznetsov 2020-01-15 18:30:01 -08:00
parent 057cf24986
commit c1b02f9a64

View File

@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -215,7 +215,7 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
EXPECT_EQ(vv_out->value, 1); // CPU
}
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
@ -239,7 +239,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
EXPECT_EQ(vv_out->value, 2); // GPU
}
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
@ -286,7 +286,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
EXPECT_EQ(vv_out->value, 7); // CPU
}
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
@ -312,7 +312,7 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
EXPECT_EQ(vv_out->value, -7); // GPU
}
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;