diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
index 4d31edbb1a9..77c4b7a7299 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -46,25 +46,6 @@ namespace functor {
 using random::PhiloxRandom;
 using random::SingleSampleAdapter;
 
-// Sample a truncated normal random variable, with mean, stddev, minval, and
-// maxval parameters for each batch. Uses two rejection sampling algorithms
-// described in http://rd.springer.com/article/10.1007/BF00143942.
-//
-// Either minval may be -infinity, or maxval may be +infinity. If the interval
-// (minval, maxval) is empty, the result is NaN. Large intervals which include
-// both tails may have reduced accuracy.
-template <typename Device, typename T>
-struct TruncatedNormalFunctor {
-  void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
-                  int64 samples_per_batch, int64 num_elements,
-                  typename TTypes<T>::ConstFlat means,
-                  typename TTypes<T>::ConstFlat stddevs,
-                  typename TTypes<T>::ConstFlat minvals,
-                  typename TTypes<T>::ConstFlat maxvals,
-                  const random::PhiloxRandom& gen,
-                  typename TTypes<T>::Flat output);
-};
-
 template <typename T>
 struct TruncatedNormalFunctor<CPUDevice, T> {
   static const int kMaxIterations = 100;
@@ -96,8 +77,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
 
       // Vectorized intermediate calculations for uniform rejection sampling.
       // We always generate at most 4 samples.
-      tensorflow::random::Array<T, 4> z;
-      tensorflow::random::Array<T, 4> g;
+      Eigen::array<T, 4> z;
+      Eigen::array<T, 4> g;
 
       for (int64 b = start_batch; b < limit_batch; ++b) {
         // We are passed a flat array for each of the parameter tensors.
@@ -145,13 +126,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
         if (diff < cutoff) {
           // Sample from a uniform distribution on [normMin, normMax].
 
-          T plusFactor;
-          if (normMin < T(0)) {
-            // normMax > 0 because it is flipped otherwise.
-            plusFactor = T(0);
-          } else {
-            plusFactor = normMin * normMin;
-          }
+          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
 
           while (sample < limit_sample) {
             const auto rand = dist(&gen_copy);
@@ -395,4 +370,21 @@ TF_CALL_double(REGISTER);
 
 #undef REGISTER
 
+#if GOOGLE_CUDA
+
+#define REGISTER(TYPE)                                         \
+  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
+                              .Device(DEVICE_GPU)              \
+                              .HostMemory("shape")             \
+                              .TypeConstraint<TYPE>("dtype"),  \
+                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+
+#undef REGISTER
+
+#endif  // GOOGLE_CUDA
+
 }  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.h b/tensorflow/core/kernels/parameterized_truncated_normal_op.h
index a46bb1c9fa6..cc801eb8109 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.h
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.h
@@ -16,14 +16,35 @@ limitations under the License.
 #ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
 #define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
 
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
 namespace tensorflow {
 
 class OpKernelContext;
 
 namespace functor {
 
+// Sample a truncated normal random variable, with mean, stddev, minval, and
+// maxval parameters for each batch. Uses two rejection sampling algorithms
+// described in http://rd.springer.com/article/10.1007/BF00143942.
+//
+// Either minval may be -infinity, or maxval may be +infinity. If the interval
+// (minval, maxval) is empty, the result is NaN. Large intervals which include
+// both tails may have reduced accuracy.
 template <typename Device, typename T>
-struct TruncatedNormalFunctor;
+struct TruncatedNormalFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
+                  int64 samples_per_batch, int64 num_elements,
+                  typename TTypes<T>::ConstFlat means,
+                  typename TTypes<T>::ConstFlat stddevs,
+                  typename TTypes<T>::ConstFlat minvals,
+                  typename TTypes<T>::ConstFlat maxvals,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<T>::Flat output);
+
+  static const int kMaxIterations = 100;
+};
 
 }  // namespace functor
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
new file mode 100644
index 00000000000..42d47440690
--- /dev/null
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
@@ -0,0 +1,214 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
+
+#include <assert.h>
+#include <stdio.h>
+#include <cmath>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+#define UNROLL _Pragma("unroll")
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename T>
+__global__ void __launch_bounds__(1024)
+    TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
+                          int64 samples_per_batch, int64 num_elements,
+                          const T* means, bool single_mean, const T* stddevs,
+                          bool single_stddev, const T* minvals,
+                          bool single_minval, const T* maxvals,
+                          bool single_maxval, int64 kMaxIterations) {
+  const int32 max_samples_per_item = 2 * kMaxIterations;
+  // Initial offset as given by CUDA_1D_KERNEL_LOOP.
+  const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
+  gen.Skip(max_samples_per_item * initial_offset);
+  typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
+  Uniform dist;
+  const int kDistSize = Uniform::kResultElementCount;
+  const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();
+
+  // We skip the total number of threads to get to the next element. To produce
+  // deterministic results between devices, each element in the output array
+  // skips max_samples_per_item in the generator. Then after generating this
+  // item, we need to skip the samples for one element for every thread to get
+  // to the next element that we actually process.
+  const int32 samples_between_processed_elements =
+      max_samples_per_item * (gridDim.x * blockDim.x);
+
+  CUDA_1D_KERNEL_LOOP(offset, num_elements) {
+    // Track how many more samples we need to skip before we process the next
+    // element.
+    int32 remaining_samples = samples_between_processed_elements;
+
+    const int64 batch_id = offset / samples_per_batch;
+    T mean = means[single_mean ? 0 : batch_id];
+    const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
+    T minval = minvals[single_minval ? 0 : batch_id];
+    T maxval = maxvals[single_maxval ? 0 : batch_id];
+
+    // Flip the distribution if we can make the lower bound positive.
+    T stddev;
+    if (Eigen::numext::isinf(minval) || maxval < mean) {
+      // Reverse all calculations. normMin and normMax will be flipped.
+      // std::swap is a host function (not available in CUDA).
+      T temp = minval;
+      minval = maxval;
+      maxval = temp;
+      stddev = -input_stddev;
+    } else {
+      stddev = input_stddev;
+    }
+
+    // Calculate normalized samples, then scale them.
+    const T normMin = (minval - mean) / stddev;
+    const T normMax = (maxval - mean) / stddev;
+
+    // Determine the method to use.
+    const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
+    const T cutoff =
+        T(2) *
+        Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
+        (normMin + sqrtFactor);
+    const T diff = normMax - normMin;
+
+    // Validate the normalized min and max, because the originals may have been
+    // flipped already.
+    if (!(input_stddev > T(0) && normMin < normMax &&
+          (Eigen::numext::isfinite(normMin) ||
+           Eigen::numext::isfinite(normMax)))) {
+      data[offset] = quietNaN;
+    } else if (diff < cutoff) {
+      // Sample from a uniform distribution on [normMin, normMax].
+
+      // Vectorized intermediate calculations for uniform rejection sampling.
+      // We always generate at most 4 samples.
+      Eigen::array<T, 4> z;
+      Eigen::array<T, 4> g;
+
+      const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
+
+      int numIterations = 0;
+      while (numIterations < kMaxIterations) {
+        const auto rand = dist(&gen);
+        remaining_samples -= gen.kResultElementCount;
+        UNROLL for (int i = 0; i < kDistSize; i++) {
+          z[i] = rand[i] * diff + normMin;
+        }
+        UNROLL for (int i = 0; i < kDistSize; i++) {
+          g[i] = (plusFactor - z[i] * z[i]) / 2.0;
+        }
+
+        const auto u = dist(&gen);
+        remaining_samples -= gen.kResultElementCount;
+        UNROLL for (int i = 0; i < kDistSize; i++) {
+          if (u[i] <= Eigen::numext::exp(g[i]) ||
+              numIterations + 1 >= kMaxIterations) {
+            // Accept the sample z.
+            // If we run out of iterations, just use the current uniform
+            // sample. Emperically, the probability of accepting each sample
+            // is at least 50% for typical inputs, so we will always accept
+            // by 100 iterations.
+            // This introduces a slight inaccuracy when at least one bound
+            // is large, minval is negative and maxval is positive.
+            data[offset] = z[i] * stddev + mean;
+            // Break out of the nested loop by updating numIterations.
+            numIterations = kMaxIterations;
+            break;
+          } else {
+            numIterations++;
+          }
+        }
+      }
+    } else {
+      // Sample from an exponential distribution with alpha maximizing
+      // acceptance probability, offset by normMin from the origin.
+      // Accept only if less than normMax.
+      const T alpha =
+          (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
+      int numIterations = 0;
+      while (numIterations < kMaxIterations) {
+        auto rand = dist(&gen);
+        remaining_samples -= gen.kResultElementCount;
+        UNROLL for (int i = 0; i < kDistSize; i += 2) {
+          const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
+          const T x = normMin < alpha ? alpha - z : normMin - alpha;
+          const T g = Eigen::numext::exp(-x * x / 2.0);
+          const T u = rand[i + 1];
+          if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
+            data[offset] = z * stddev + mean;
+            // Break out of the nested loop by updating numIterations.
+            numIterations = kMaxIterations;
+            break;
+          } else {
+            numIterations++;
+          }
+        }
+      }
+    }
+
+    gen.Skip(remaining_samples);
+  }
+}
+
+// Partial specialization for GPU
+template <typename T>
+struct TruncatedNormalFunctor<GPUDevice, T> {
+  static const int kMaxIterations = 100;
+
+  void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
+                  int64 samples_per_batch, int64 num_elements,
+                  typename TTypes<T>::ConstFlat means,
+                  typename TTypes<T>::ConstFlat stddevs,
+                  typename TTypes<T>::ConstFlat minvals,
+                  typename TTypes<T>::ConstFlat maxvals,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<T>::Flat output) {
+    const auto config = GetCudaLaunchConfig(num_elements, d);
+
+    TruncatedNormalKernel<
+        T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+        gen, output.data(), num_batches, samples_per_batch, num_elements,
+        means.data(), means.dimension(0) == 1, stddevs.data(),
+        stddevs.dimension(0) == 1, minvals.data(), minvals.dimension(0) == 1,
+        maxvals.data(), maxvals.dimension(0) == 1, kMaxIterations);
+  };
+};
+
+// Explicit instantiation of the GPU distributions functors
+template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
+template struct TruncatedNormalFunctor<GPUDevice, float>;
+template struct TruncatedNormalFunctor<GPUDevice, double>;
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc
index 13d1187f926..07f2f75ca5a 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc
@@ -131,5 +131,8 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
 BM_PTruncatedNormalDev(cpu, 1000, 1000);
 BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
 BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
+BM_PTruncatedNormalDev(gpu, 1000, 1000);
+BM_PTruncatedNormalDev_2SD(gpu, 10000, 100);
+BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100);
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 24a1d8e7c97..0100e6b3268 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -220,7 +220,7 @@ cuda_py_test(
     additional_deps = ["//tensorflow:tensorflow_py"],
 )
 
-tf_py_test(
+cuda_py_test(
     name = "parameterized_truncated_normal_op_test",
     size = "small",
     srcs = ["parameterized_truncated_normal_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
index 8d41029c0b5..1c09949598a 100644
--- a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
+++ b/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
@@ -97,10 +97,10 @@ def z_test(real, expected, i, num_samples):
 
 
 class ParameterizedTruncatedNormalTest(tf.test.TestCase):
-  use_gpu = False
+  _use_gpu = False
   z_limit = 6.0
 
-  # Stop at moment 20 to avoid numerical errors in the theoretical moments.
+  # Stop at moment 10 to avoid numerical errors in the theoretical moments.
   max_moment = 10
 
   def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
@@ -109,9 +109,11 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
       # Give up early if we are unable to import it.
       import scipy.stats  # pylint: disable=g-import-not-at-top,unused-variable
       tf.set_random_seed(seed)
-      with self.test_session(use_gpu=self.use_gpu):
-        samples = random_ops.parameterized_truncated_normal(
-            shape, mean, stddev, minval, maxval).eval()
+      with self.test_session(use_gpu=self._use_gpu):
+        samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
+                                                            minval,
+                                                            maxval).eval()
+        assert (~np.isnan(samples)).all()
       moments = calculate_moments(samples, self.max_moment)
       expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
       num_samples = functools.reduce(lambda x, y: x * y, shape, 1)
@@ -131,9 +133,11 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
     try:
       import scipy.stats  # pylint: disable=g-import-not-at-top
       tf.set_random_seed(seed)
-      with self.test_session(use_gpu=self.use_gpu):
-        samples = random_ops.parameterized_truncated_normal(
-            shape, mean, stddev, minval, maxval).eval()
+      with self.test_session(use_gpu=self._use_gpu):
+        samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
+                                                            minval,
+                                                            maxval).eval()
+      assert (~np.isnan(samples)).all()
       minval = max(mean - stddev * 10, minval)
       maxval = min(mean + stddev * 10, maxval)
       dist = scipy.stats.norm(loc=mean, scale=stddev)
@@ -173,8 +177,12 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
     self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
 
 
+class ParameterizedTruncatedNormalGpuTest(ParameterizedTruncatedNormalTest):
+  _use_gpu = True
+
+
 # Benchmarking code
-def parameterized_vs_naive(shape, num_iters):
+def parameterized_vs_naive(shape, num_iters, use_gpu=False):
   np.random.seed(1618)  # Make it reproducible.
 
   # No CSE/CF.
@@ -183,17 +191,29 @@ def parameterized_vs_naive(shape, num_iters):
       graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
 
   with tf.Session(config=config) as sess:
-    param_op = tf.group(random_ops.parameterized_truncated_normal(shape))
-    naive_op = tf.group(random_ops.truncated_normal(shape))
+    with tf.device("/cpu:0" if not use_gpu else None):
+      param_op = tf.group(random_ops.parameterized_truncated_normal(shape))
+      naive_op = tf.group(random_ops.truncated_normal(shape))
 
+    # Burn-in to avoid session setup costs in the timing.
+    sess.run(param_op)
+    sess.run(param_op)
     param_dt = timeit.timeit(lambda: sess.run(param_op), number=num_iters)
+    sess.run(naive_op)
+    sess.run(naive_op)
     naive_dt = timeit.timeit(lambda: sess.run(naive_op), number=num_iters)
     return param_dt, naive_dt
 
 
 class TruncatedNormalBenchmark(tf.test.Benchmark):
 
-  def benchmarkParameterizedOpVsNaiveOp(self):
+  def benchmarkParameterizedOpVsNaiveOpCpu(self):
+    self._benchmarkParameterizedOpVsNaiveOp(False)
+
+  def benchmarkParameterizedOpVsNaiveOpGpu(self):
+    self._benchmarkParameterizedOpVsNaiveOp(True)
+
+  def _benchmarkParameterizedOpVsNaiveOp(self, use_gpu):
     num_iters = 50
     print(("Composition of new ParameterizedTruncatedNormalOp vs. "
            "naive TruncatedNormalOp [%d iters]") % num_iters)
@@ -201,16 +221,16 @@ class TruncatedNormalBenchmark(tf.test.Benchmark):
 
     for shape in [[10000, 100], [1000, 1000], [1000000], [100, 100, 100],
                   [20, 20, 20, 20]]:
-      p_dt, n_dt = parameterized_vs_naive(shape, num_iters)
+      p_dt, n_dt = parameterized_vs_naive(shape, num_iters, use_gpu)
       print("%s\t%.3f\t%.3f\t%.2f" % (shape, p_dt, n_dt, p_dt / n_dt))
 
       shape_str = "-".join(map(str, shape))
-      self.report_benchmark(name="parameterized_shape" + shape_str,
-                            iters=num_iters,
-                            wall_time=p_dt)
-      self.report_benchmark(name="naive_shape" + shape_str,
-                            iters=num_iters,
-                            wall_time=n_dt)
+      self.report_benchmark(
+          name="parameterized_shape" + shape_str,
+          iters=num_iters,
+          wall_time=p_dt)
+      self.report_benchmark(
+          name="naive_shape" + shape_str, iters=num_iters, wall_time=n_dt)
 
 
 if __name__ == "__main__":