Adding ROCm support to cross_op
This commit is contained in:
parent
b25e77ca06
commit
2c7a81f9d5
@ -87,7 +87,7 @@ class CrossOp : public OpKernel {
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNEL);
|
||||
#undef REGISTER_CPU_KERNEL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Forward declarations of the function specializations for GPU (to prevent
|
||||
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
|
||||
namespace functor {
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -31,4 +31,4 @@ TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_GPU_KERNEL);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
Loading…
x
Reference in New Issue
Block a user