Merge pull request #42615 from drebain:drebain_patch
PiperOrigin-RevId: 328985012 Change-Id: I0b4c461f07e8a838df05bfa329efdb1a8f1293f1
This commit is contained in:
commit
16895e59b8
@ -48,6 +48,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
#include "tensorflow/core/util/transform_output_iterator.h"
|
#include "tensorflow/core/util/transform_output_iterator.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||||
|
using stream_executor::cuda::ScopedActivateExecutorContext;
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
#include "tensorflow/core/platform/rocm.h"
|
||||||
|
using stream_executor::rocm::ScopedActivateExecutorContext;
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
@ -302,6 +310,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
|
|||||||
TensorReference partition_ref(partition_count);
|
TensorReference partition_ref(partition_count);
|
||||||
auto wrapped_callback = [this, c, &data, &partitions, indices_out,
|
auto wrapped_callback = [this, c, &data, &partitions, indices_out,
|
||||||
partition_ref, cpu_tensor, done]() {
|
partition_ref, cpu_tensor, done]() {
|
||||||
|
auto stream = c->op_device_context()->stream();
|
||||||
|
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
||||||
|
|
||||||
OpOutputList outputs;
|
OpOutputList outputs;
|
||||||
this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
|
this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
|
||||||
if (!c->status().ok()) {
|
if (!c->status().ok()) {
|
||||||
|
@ -2048,6 +2048,9 @@ cuda_py_test(
|
|||||||
name = "dynamic_partition_op_test",
|
name = "dynamic_partition_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["dynamic_partition_op_test.py"],
|
srcs = ["dynamic_partition_op_test.py"],
|
||||||
|
tags = [
|
||||||
|
"multi_and_single_gpu",
|
||||||
|
],
|
||||||
tfrt_enabled = True,
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
@ -23,8 +23,10 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
@ -346,6 +348,19 @@ class DynamicPartitionTest(test.TestCase):
|
|||||||
res = self.evaluate(partitioned)
|
res = self.evaluate(partitioned)
|
||||||
self.assertEqual(res[-1].shape[0], 192)
|
self.assertEqual(res[-1].shape[0], 192)
|
||||||
|
|
||||||
|
# see https://github.com/tensorflow/tensorflow/issues/42500
|
||||||
|
def testMultiGPU(self):
|
||||||
|
device_list = config.list_logical_devices("GPU")
|
||||||
|
results = []
|
||||||
|
for device in device_list:
|
||||||
|
with ops.device(device.name):
|
||||||
|
data = constant_op.constant(np.zeros((1000,)))
|
||||||
|
partitions = constant_op.constant(np.arange(1000, dtype=np.int32) % 10)
|
||||||
|
result = data_flow_ops.dynamic_partition(data, partitions, 10)
|
||||||
|
results.append(self.evaluate(result))
|
||||||
|
if device_list:
|
||||||
|
self.assertAllEqual(results, np.zeros((len(device_list), 10, 100)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user