Adding ROCm support to cross_op

This commit is contained in:
jerryyin 2019-05-30 14:29:45 +00:00
parent b25e77ca06
commit 2c7a81f9d5
2 changed files with 3 additions and 3 deletions

View File

@ -87,7 +87,7 @@ class CrossOp : public OpKernel {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
namespace functor {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -31,4 +31,4 @@ TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_GPU_KERNEL);
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM