Adding ROCm support for the "reverse_sequence" op
This commit is contained in:
parent
3ef646727a
commit
46958a12ca
@ -17,9 +17,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/reverse_sequence_op.h"
|
#include "tensorflow/core/kernels/reverse_sequence_op.h"
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ class ReverseSequenceOp : public OpKernel {
|
|||||||
TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);
|
TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);
|
||||||
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN);
|
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN);
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
// Forward declarations of the functor specializations for GPU.
|
// Forward declarations of the functor specializations for GPU.
|
||||||
namespace functor {
|
namespace functor {
|
||||||
@ -222,6 +222,6 @@ TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN);
|
|||||||
|
|
||||||
#undef REGISTER_REVERSE_SEQUENCE_GPU
|
#undef REGISTER_REVERSE_SEQUENCE_GPU
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -43,4 +43,4 @@ TF_CALL_bool(DEFINE_GPU_SPECS);
|
|||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
x
Reference in New Issue
Block a user