[tf.data] Add support for copying Optional variants to/from GPU.
PiperOrigin-RevId: 207490563
parent c42013f103, commit 02ae1e2e78

@@ -907,6 +907,42 @@ class CopyToDeviceTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testIteratorGetNextAsOptionalOnGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    host_dataset = dataset_ops.Dataset.range(3)
    device_dataset = host_dataset.apply(
        prefetching_ops.copy_to_device("/gpu:0"))
    with ops.device("/gpu:0"):
      iterator = device_dataset.make_initializable_iterator()
      next_elem = iterator_ops.get_next_as_optional(iterator)
      elem_has_value_t = next_elem.has_value()
      elem_value_t = next_elem.get_value()

    with self.test_session() as sess:
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_value_t)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      sess.run(iterator.initializer)
      for i in range(3):
        elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self.assertEqual(i, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(sess.run(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(elem_value_t)


class MultiDeviceIteratorTest(test.TestCase):

@@ -340,4 +340,30 @@ Status CopyTensor::Register(DeviceType sender_device_type,
  return Status::OK();
}

namespace {

// The following registrations enable a DT_VARIANT tensor element that contains
// a wrapped `tensorflow::Tensor` to be copied between devices.
static Status WrappedTensorDeviceCopy(
    const Tensor& from, Tensor* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (DMAHelper::CanUseDMA(&from)) {
    TF_RETURN_IF_ERROR(copy(from, to));
  } else {
    *to = from;
  }

  return Status::OK();
}

#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION)         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)

REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

}  // namespace

}  // namespace tensorflow

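The hook above generalizes beyond wrapped Tensors: any type stored in a
DT_VARIANT can be made copyable across device boundaries by registering one
copy function per direction. A minimal sketch for a hypothetical single-tensor
payload follows; `MyPayload`, `MyPayloadDeviceCopy`, and the "MyPayload" type
name are illustrative assumptions, not part of this commit.

// Hypothetical variant payload. To be storable in a Variant it would also
// need TypeName()/Encode()/Decode(), elided here.
struct MyPayload {
  Tensor data;
};

// `copy` is the direction-appropriate (possibly asynchronous) tensor
// transfer function supplied by the runtime.
static Status MyPayloadDeviceCopy(
    const MyPayload& from, MyPayload* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (DMAHelper::CanUseDMA(&from.data)) {
    TF_RETURN_IF_ERROR(copy(from.data, &to->data));
  } else {
    // Non-DMA-able contents (e.g. DT_STRING) stay on the host and are
    // shared by reference, mirroring WrappedTensorDeviceCopy above.
    *to = from;
  }
  return Status::OK();
}

#define REGISTER_MY_PAYLOAD_COPY(DIRECTION)             \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      MyPayload, DIRECTION, "MyPayload", MyPayloadDeviceCopy)

REGISTER_MY_PAYLOAD_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_MY_PAYLOAD_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_MY_PAYLOAD_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
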
@@ -57,6 +57,10 @@ namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");

namespace {

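As the NOTE says, the copy registrations live in common_runtime while the
decode registration lives here. For the hypothetical `MyPayload` sketched
after the previous hunk (an assumption for illustration, not part of this
commit), the matching decode registration would be the analogous one-liner:

// Lets a DT_VARIANT holding a MyPayload be reconstructed by Tensor's
// FromProto; requires MyPayload to implement Encode()/Decode().
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyPayload, "MyPayload");
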
@@ -1321,9 +1321,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
                        IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
                        IteratorGetNextSyncOp);
// TODO(b/111349762): Add registration for other devices.
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_CPU),
                        IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_GPU),
                        IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")

@@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/optional_ops.h"

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"

namespace tensorflow {
namespace {

@@ -23,9 +25,6 @@ const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";

// An `OptionalVariant` can represent either an "actual value" (a tuple of
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
//
// TODO(b/111349762): Add registrations for copying `OptionalVariant` between
// devices.
class OptionalVariant {
 public:
  // Create an `OptionalVariant` with no actual value.

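To make the "value or none" semantics concrete, here is a small sketch of how
an `OptionalVariant` is built and stored, using only the interface visible in
this file (the construct-from-vector form appears in `OptionalDeviceCopy` and
`WriteOptionalWithValueToOutput` below); it is an illustration, not code from
this commit:

// "None": a default-constructed OptionalVariant has no value.
OptionalVariant none;
CHECK(!none.has_value());

// "Actual value": wraps a tuple of component tensors.
std::vector<Tensor> components = {Tensor(DT_INT64, TensorShape({}))};
OptionalVariant some(std::move(components));
CHECK(some.has_value());

// Either form can be stored in a DT_VARIANT scalar.
Tensor variant_t(DT_VARIANT, TensorShape({}));
variant_t.scalar<Variant>()() = some;
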
@@ -189,16 +188,59 @@ class OptionalGetValueOp : public OpKernel {
  std::vector<PartialTensorShape> output_shapes_;
};

// TODO(b/111349762): Add registrations for other devices.
REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU),
                        OptionalNoneOp);
REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU),
                        OptionalNoneOp);
REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_CPU),
                        OptionalFromValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_GPU),
                        OptionalFromValueOp);

REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU),
                        OptionalHasValueOp);
REGISTER_KERNEL_BUILDER(
    Name("OptionalHasValue").Device(DEVICE_GPU).HostMemory("has_value"),
    OptionalHasValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU),
                        OptionalGetValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU),
                        OptionalGetValueOp);

static Status OptionalDeviceCopy(
    const OptionalVariant& from, OptionalVariant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (from.has_value()) {
    const std::vector<Tensor>& from_values = from.get_values();
    std::vector<Tensor> to_values;
    to_values.reserve(from_values.size());
    for (const Tensor& t : from_values) {
      if (DMAHelper::CanUseDMA(&t)) {
        Tensor tmp(t.dtype());
        TF_RETURN_IF_ERROR(copy(t, &tmp));
        to_values.push_back(std::move(tmp));
      } else {
        to_values.push_back(t);
      }
    }
    *to = OptionalVariant(std::move(to_values));
  } else {
    *to = from;
  }
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)                   \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(     \
      OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
      OptionalDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
                                       kOptionalVariantTypeName);

}  // namespace

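At runtime these registrations are looked up by direction and type name
whenever a DT_VARIANT tensor crosses a device boundary. A simplified sketch of
the dispatch, assuming the `VariantDeviceCopy` helper declared in
variant_op_registry.h and a trivial synchronous copier (the real caller in
common_runtime supplies a stream-based asynchronous one):

// Copy one Variant element host->device via the copy function registered
// for its type (here OptionalDeviceCopy, keyed by kOptionalVariantTypeName).
Variant from = OptionalVariant();
Variant to;
auto copier = [](const Tensor& src, Tensor* dst) -> Status {
  *dst = src;  // sketch only; the real copier issues a device transfer
  return Status::OK();
};
TF_RETURN_IF_ERROR(VariantDeviceCopy(
    VariantDeviceCopyDirection::HOST_TO_DEVICE, from, &to, copier));
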
@@ -206,8 +248,10 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
                                      std::vector<Tensor> value) {
  OptionalVariant v(std::move(value));
  Tensor* variant_t;
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_index, TensorShape({}), &variant_t));
  AllocatorAttributes cpu_alloc;
  cpu_alloc.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
                                          &variant_t, cpu_alloc));
  variant_t->scalar<Variant>()() = v;
  return Status::OK();
}

@@ -215,8 +259,10 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
  OptionalVariant v;
  Tensor* variant_t;
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_index, TensorShape({}), &variant_t));
  AllocatorAttributes cpu_alloc;
  cpu_alloc.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
                                          &variant_t, cpu_alloc));
  variant_t->scalar<Variant>()() = v;
  return Status::OK();
}

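In both writers the commit replaces the default allocation (the first
`allocate_output` call shown, which the diff removes) with an explicitly
host-resident one: a `Variant` is an ordinary C++ object, so the scalar that
holds it must live in host memory even when the kernel is placed on a GPU. A
sketch of the reusable pattern for any kernel emitting a DT_VARIANT output,
assuming the usual `OpKernelContext* ctx`:

// Force the DT_VARIANT scalar into host memory; its Variant payload is a
// host-side C++ object and cannot live in device memory.
AllocatorAttributes cpu_alloc;
cpu_alloc.set_on_host(true);
Tensor* variant_t;
TF_RETURN_IF_ERROR(ctx->allocate_output(/*index=*/0, TensorShape({}),
                                        &variant_t, cpu_alloc));
variant_t->scalar<Variant>()() = OptionalVariant();  // e.g. "none"
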
@@ -386,7 +386,7 @@ tf_py_test(
    ],
)

tf_py_test(
cuda_py_test(
    name = "optional_ops_test",
    size = "small",
    srcs = ["optional_ops_test.py"],

@@ -395,6 +395,7 @@ tf_py_test(
        "//tensorflow/python/data/ops:iterator_ops",
        "//tensorflow/python/data/ops:optional_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",

@@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test

@@ -115,6 +116,38 @@ class OptionalTest(test.TestCase):
      optional_ops.Optional.none_from_structure(
          dict_output_shapes, tuple_output_types, tuple_output_classes)

  @test_util.run_in_graph_and_eager_modes
  def testCopyToGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          tensor_shape.scalar(), dtypes.float32, ops.Tensor)

    with ops.device("/gpu:0"):
      gpu_optional_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.output_shapes, optional_with_value.output_types,
          optional_with_value.output_classes)
      gpu_optional_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.output_shapes, optional_none.output_types,
          optional_none.output_classes)

      gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
      gpu_optional_with_value_values = gpu_optional_with_value.get_value()

      gpu_optional_none_has_value = gpu_optional_none.has_value()

    self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  def testIteratorGetNextAsOptional(self):
    ds = dataset_ops.Dataset.range(3)
    iterator = ds.make_initializable_iterator()