Merge pull request #2506 from siddharth-agrawal/matmul_float64_gpu
Enable GPU for Matmul float64
This commit is contained in:
commit
175e9f73b3
@ -210,7 +210,7 @@ REGISTER_CPU(complex64);
|
|||||||
REGISTER_CPU(complex128);
|
REGISTER_CPU(complex128);
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
REGISTER_GPU(float);
|
REGISTER_GPU(float);
|
||||||
// REGISTER_GPU(double);
|
REGISTER_GPU(double);
|
||||||
#if CUDA_VERSION >= 7050
|
#if CUDA_VERSION >= 7050
|
||||||
REGISTER_GPU(Eigen::half);
|
REGISTER_GPU(Eigen::half);
|
||||||
#endif
|
#endif
|
||||||
|
@ -95,6 +95,7 @@ class MatMulTest(tf.test.TestCase):
|
|||||||
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
|
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
|
||||||
y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
|
y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
|
||||||
self._testCpuMatmul(x, y)
|
self._testCpuMatmul(x, y)
|
||||||
|
self._testGpuMatmul(x, y)
|
||||||
|
|
||||||
def testHalfBasic(self):
|
def testHalfBasic(self):
|
||||||
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
|
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
|
||||||
@ -135,6 +136,7 @@ class MatMulTest(tf.test.TestCase):
|
|||||||
x = self._randMatrix(n, k, np.float64)
|
x = self._randMatrix(n, k, np.float64)
|
||||||
y = self._randMatrix(k, m, np.float64)
|
y = self._randMatrix(k, m, np.float64)
|
||||||
self._testCpuMatmul(x, y)
|
self._testCpuMatmul(x, y)
|
||||||
|
self._testGpuMatmul(x, y)
|
||||||
|
|
||||||
def testHalfRandom(self):
|
def testHalfRandom(self):
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
@ -185,6 +187,7 @@ class MatMulTest(tf.test.TestCase):
|
|||||||
x = self._randMatrix(k, n, np.float64)
|
x = self._randMatrix(k, n, np.float64)
|
||||||
y = self._randMatrix(m, k, np.float64)
|
y = self._randMatrix(m, k, np.float64)
|
||||||
self._testCpuMatmul(x, y, True, True)
|
self._testCpuMatmul(x, y, True, True)
|
||||||
|
self._testGpuMatmul(x, y, True, True)
|
||||||
|
|
||||||
def testHalfRandomTransposeBoth(self):
|
def testHalfRandomTransposeBoth(self):
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
|
Loading…
Reference in New Issue
Block a user