Implement some OptionalVariants functionality.

* ZerosLike variant registry function
* BinaryAdd variant registry function
* Shape variant registry function
* Enables copying nested variants between host/devices
* Refactors common code out of the corresponding TensorList
  functions into shared helpers (util/tensor_ops_util.h). This should
  also enable TensorLists containing OptionalVariants (previously some
  of this code assumed any nested variant was also a TensorList).

PiperOrigin-RevId: 223112633
This commit is contained in:
Skye Wanderman-Milne 2018-11-27 22:25:15 -08:00 committed by TensorFlower Gardener
parent 5f111d5a6c
commit 567c0692de
8 changed files with 440 additions and 161 deletions

View File

@ -928,6 +928,7 @@ tf_cuda_library(
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
"util/tensor_ops_util.h",
"util/tensor_slice_reader.h",
"util/tensor_slice_reader_cache.h",
"util/tensor_slice_writer.h",

View File

@ -621,6 +621,10 @@ tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
gpu_srcs = [
"optional_ops.cu.cc",
"optional_ops.h",
],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
@ -628,6 +632,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)

View File

@ -22,75 +22,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
// An `OptionalVariant` can represent either an "actual value" (a tuple of
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
class OptionalVariant {
public:
// Create an `OptionalVariant` with no actual value.
OptionalVariant() : values_(nullptr) {}
// Create an `OptionalVariant` with the actual value given by the tuple of
// tensors in `values`.
explicit OptionalVariant(std::vector<Tensor> values)
: values_(new std::vector<Tensor>(std::move(values))) {}
OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
// Returns true if `this` represents an actual value.
bool has_value() const { return values_ != nullptr; }
// REQUIRES: `this->has_value()` must be true.
const std::vector<Tensor>& get_values() const {
CHECK(values_) << "Tried to get values from an empty OptionalVariant";
return *values_;
}
// Implementations of the necessary methods for using `OptionalVariant`
// objects in DT_VARIANT tensors.
string TypeName() const { return kOptionalVariantTypeName; }
void Encode(VariantTensorData* data) const {
data->set_metadata(values_ != nullptr);
if (values_ != nullptr) {
for (const auto& t : *values_) {
*(data->add_tensors()) = t;
}
}
}
bool Decode(const VariantTensorData& data) {
if (data.type_name() != TypeName()) {
return false;
}
bool has_value = false;
if (!data.get_metadata(&has_value)) {
return false;
}
if (has_value) {
values_.reset(new std::vector<Tensor>(data.tensors()));
} else {
values_.reset();
}
return true;
}
string DebugString() const {
if (values_) {
return strings::StrCat("OptionalVariant<", "values: (",
str_util::Join(*values_, ", ",
[](string* s, const Tensor& elem) {
*s = elem.DebugString();
}),
")>");
} else {
return strings::StrCat("OptionalVariant<None>");
}
}
private:
std::shared_ptr<const std::vector<Tensor>> values_;
};
class OptionalNoneOp : public OpKernel {
public:
@ -143,6 +74,12 @@ class OptionalGetValueOp : public OpKernel {
// Reads the `output_shapes` and `output_types` attrs and requires that both
// lists have the same number of entries.
explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
  const auto num_types = output_types_.size();
  const auto num_shapes = output_shapes_.size();
  OP_REQUIRES(ctx, num_shapes == num_types,
              errors::InvalidArgument(
                  "output_types and output_shapes must be same length, got:\n",
                  "output_types: ", num_types, "\n",
                  "output_shapes: ", num_shapes));
}
void Compute(OpKernelContext* ctx) override {
@ -162,6 +99,10 @@ class OptionalGetValueOp : public OpKernel {
ctx, optional->has_value(),
errors::InvalidArgument("The given optional does not have a value."));
const auto& components = optional->get_values();
OP_REQUIRES(ctx, components.size() == output_types_.size(),
errors::InvalidArgument(
"The given optional has ", components.size(),
" components, expected ", output_types_.size()));
for (int i = 0; i < components.size(); ++i) {
OP_REQUIRES(
ctx, components[i].dtype() == output_types_[i],
@ -213,15 +154,7 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
if (t.dtype() == DT_VARIANT) {
// TODO(b/116349787): Implement support for nested variants.
return errors::Unimplemented(
"Support for copying nested variants to device has not yet been "
"implemented.");
}
}
for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
to_values.push_back(std::move(tmp));
@ -272,5 +205,20 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
return Status::OK();
}
// Register the CPU instantiations of the OptionalVariant handlers:
// OptionalZerosLike<CPUDevice> for ZEROS_LIKE_VARIANT_UNARY_OP and
// OptionalBinaryAdd<CPUDevice> for ADD_VARIANT_BINARY_OP, both on DEVICE_CPU.
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, OptionalVariant,
OptionalZerosLike<CPUDevice>);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
OptionalVariant,
OptionalBinaryAdd<CPUDevice>);
// Shape function for `OptionalVariant`: always reports the shape built from an
// empty dimension list, regardless of the contents of `x`.
Status OptionalShape(const OptionalVariant& x, TensorShape* s) {
  const TensorShape empty_dims_shape({});
  *s = empty_dims_shape;
  return Status::OK();
}

// Registers `OptionalShape` as the shape function for `OptionalVariant`.
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(OptionalVariant, OptionalShape);
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/framework/variant_op_registry.h"
namespace tensorflow {
namespace data {
// Register the GPU instantiations of the OptionalVariant handlers:
// OptionalZerosLike<GPUDevice> for ZEROS_LIKE_VARIANT_UNARY_OP and
// OptionalBinaryAdd<GPUDevice> for ADD_VARIANT_BINARY_OP, both on DEVICE_GPU.
// These lines are only compiled when GOOGLE_CUDA is set (see enclosing #if).
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, OptionalVariant,
OptionalZerosLike<GPUDevice>);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
OptionalVariant,
OptionalBinaryAdd<GPUDevice>);
} // namespace data
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -19,10 +19,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
namespace data {
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
// Stores a DT_VARIANT value representing an Optional with the given value
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
@ -32,6 +35,122 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
// An `OptionalVariant` can represent either an "actual value" (a tuple of
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
class OptionalVariant {
public:
// Create an `OptionalVariant` with no actual value.
OptionalVariant() : values_(nullptr) {}
// Create an `OptionalVariant` with the actual value given by the tuple of
// tensors in `values`.
explicit OptionalVariant(std::vector<Tensor> values)
: values_(new std::vector<Tensor>(std::move(values))) {}
OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
// Returns true if `this` represents an actual value.
bool has_value() const { return values_ != nullptr; }
// REQUIRES: `this->has_value()` must be true.
const std::vector<Tensor>& get_values() const {
DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
return *values_;
}
// Implementations of the necessary methods for using `OptionalVariant`
// objects in DT_VARIANT tensors.
string TypeName() const { return kOptionalVariantTypeName; }
void Encode(VariantTensorData* data) const {
data->set_metadata(values_ != nullptr);
if (values_ != nullptr) {
for (const auto& t : *values_) {
*(data->add_tensors()) = t;
}
}
}
bool Decode(const VariantTensorData& data) {
if (data.type_name() != TypeName()) {
return false;
}
bool has_value = false;
if (!data.get_metadata(&has_value)) {
return false;
}
if (has_value) {
values_.reset(new std::vector<Tensor>(data.tensors()));
} else {
values_.reset();
}
return true;
}
string DebugString() const {
if (values_) {
return strings::StrCat("OptionalVariant<", "values: (",
str_util::Join(*values_, ", ",
[](string* s, const Tensor& elem) {
*s = elem.DebugString();
}),
")>");
} else {
return strings::StrCat("OptionalVariant<None>");
}
}
private:
std::shared_ptr<const std::vector<Tensor>> values_;
};
// Computes the zeros_like of an `OptionalVariant`.
//
// If `x` has no value, `*y` becomes a copy of `x` (i.e. also "none").
// Otherwise `*y` holds one zero tensor per component of `x`, each produced by
// `ZerosLikeTensor<Device>`. Returns the first error reported by
// `ZerosLikeTensor`, in which case `*y` is left untouched.
template <typename Device>
Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
                         OptionalVariant* y) {
  if (!x.has_value()) {
    *y = x;
    return Status::OK();
  }
  const std::vector<Tensor>& components = x.get_values();
  std::vector<Tensor> zero_tensors;
  // The output has exactly one tensor per input component.
  zero_tensors.reserve(components.size());
  for (const Tensor& tensor : components) {
    Tensor zero_t;
    TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
    zero_tensors.push_back(std::move(zero_t));
  }
  // The constructor takes its argument by value; move to avoid copying the
  // whole vector.
  *y = OptionalVariant(std::move(zero_tensors));
  return Status::OK();
}
// Adds two `OptionalVariant`s component-wise.
//
// Both operands must agree on whether they hold a value; otherwise an
// InvalidArgument error is returned. Two "none" operands produce "none".
// When both hold values they must have the same number of components, and
// `*out` holds `BinaryAddTensors<Device>(a[i], b[i])` for each component `i`.
// On any error `*out` is left untouched.
template <typename Device>
Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
                         const OptionalVariant& b, OptionalVariant* out) {
  // TODO(skyewm): should adding a value to a non-value be a no-op instead?
  if (a.has_value() != b.has_value()) {
    return errors::InvalidArgument(
        "Cannot add optionals because one has a value and the other doesn't.");
  }
  if (!a.has_value()) {
    *out = a;
    return Status::OK();
  }
  const std::vector<Tensor>& a_values = a.get_values();
  const std::vector<Tensor>& b_values = b.get_values();
  if (a_values.size() != b_values.size()) {
    return errors::InvalidArgument(
        "Cannot add optionals because they have different numbers of "
        "components (",
        a_values.size(), " vs. ", b_values.size(), ").");
  }
  std::vector<Tensor> out_tensors;
  // The output has exactly one tensor per input component.
  out_tensors.reserve(a_values.size());
  // Unsigned index to match the unsigned type of `size()`.
  for (size_t i = 0; i < a_values.size(); ++i) {
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(
        BinaryAddTensors<Device>(ctx, a_values[i], b_values[i], &out_tensor));
    out_tensors.push_back(std::move(out_tensor));
  }
  // The constructor takes its argument by value; move to avoid copying the
  // whole vector.
  *out = OptionalVariant(std::move(out_tensors));
  return Status::OK();
}
} // namespace data
} // namespace tensorflow

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
@ -357,51 +358,10 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
for (int i = 0; i < a.tensors.size(); ++i) {
const Tensor& a_tensor = a.tensors[i];
const Tensor& b_tensor = b.tensors[i];
if (a_tensor.dtype() == DT_INVALID) {
out->tensors.push_back(b_tensor);
continue;
}
if (b_tensor.dtype() == DT_INVALID) {
out->tensors.push_back(a_tensor);
continue;
}
if (a_tensor.shape() != b_tensor.shape()) {
// TODO(apassos) support broadcasting additions here?
return errors::InvalidArgument(
"Trying to add two tensors with incompatible element shapes. "
"One is ",
a_tensor.shape().DebugString(), " and the other is ",
b_tensor.shape().DebugString(), " in position ", i);
}
Tensor out_tensor;
AllocatorAttributes attr;
if (a_tensor.dtype() == DT_VARIANT) {
attr.set_on_host(true);
}
TF_RETURN_IF_ERROR(c->allocate_temp(a_tensor.dtype(), a_tensor.shape(),
&out_tensor, attr));
TF_RETURN_IF_ERROR(
BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
out->tensors.push_back(out_tensor);
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
a_tensor.flat<dtype>() + b_tensor.flat<dtype>(); \
break;
TF_CALL_NUMBER_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
case DataTypeToEnum<Variant>::value: {
Variant* v_out = &(out_tensor.scalar<Variant>()());
TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
c, ADD_VARIANT_BINARY_OP, a_tensor.scalar<Variant>()(),
b_tensor.scalar<Variant>()(), v_out));
break;
}
default:
return errors::InvalidArgument("Trying to add unsupported dtype ",
out_tensor.dtype());
}
}
return Status::OK();
}
@ -414,46 +374,7 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
y->tensors.reserve(x.tensors.size());
for (const Tensor& t : x.tensors) {
Tensor out_tensor;
AllocatorAttributes attr;
if (t.dtype() == DT_VARIANT) {
attr.set_on_host(true);
}
TF_RETURN_IF_ERROR(
c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
out_tensor.flat<dtype>().constant(dtype(0)); \
break;
TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
case DT_INVALID: {
// Uninitialized tensor in the TensorList.
out_tensor = Tensor(DT_INVALID);
break;
}
case DataTypeToEnum<Variant>::value: {
const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
if (inner_x == nullptr) {
return errors::InvalidArgument("Input handle is not a list. Saw: '",
t.scalar<Variant>()().DebugString(),
"'");
}
TensorList inner_y;
TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
out_tensor.scalar<Variant>()() = std::move(inner_y);
break;
}
default:
return errors::InvalidArgument(
"Trying to compute zeros_like for unsupported dtype ",
DataTypeString(out_tensor.dtype()));
}
TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
y->tensors.emplace_back(out_tensor);
}
return Status::OK();

View File

@ -0,0 +1,128 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
#define TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Short aliases for the Eigen device types that are passed as the `Device`
// template argument of the tensor helpers in this header.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// Allocates `*out` as a temporary with the dtype and shape of `x` and fills it
// with zeros.
//
// * POD dtypes (per TF_CALL_POD_TYPES) are zero-filled on the Eigen device of
//   type `Device` obtained from `ctx`.
// * DT_VARIANT tensors are allocated with the on-host attribute set and the
//   result is produced by dispatching ZEROS_LIKE_VARIANT_UNARY_OP through
//   `UnaryOpVariant<Device>`.
// * DT_INVALID yields a `Tensor(DT_INVALID)`.
// * Any other dtype results in an InvalidArgument error.
template <typename Device>
Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) {
  AllocatorAttributes alloc_attrs;
  if (x.dtype() == DT_VARIANT) {
    alloc_attrs.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(x.dtype(), x.shape(), out, alloc_attrs));

  switch (out->dtype()) {
#define ZERO_FILL_CASE(T)                                       \
  case DataTypeToEnum<T>::value:                                \
    /* TODO(skyewm): use SetZeroFunctor like in ZerosLikeOp? */ \
    out->flat<T>().device(ctx->eigen_device<Device>()) =        \
        out->flat<T>().constant(T(0));                          \
    break;

    TF_CALL_POD_TYPES(ZERO_FILL_CASE)
#undef ZERO_FILL_CASE

    case DT_INVALID: {
      *out = Tensor(DT_INVALID);
      break;
    }
    case DataTypeToEnum<Variant>::value: {
      Variant* zeroed_variant = out->scalar<Variant>().data();
      // Delegate to the zeros_like function registered for the variant's type.
      return UnaryOpVariant<Device>(ctx, ZEROS_LIKE_VARIANT_UNARY_OP,
                                    x.scalar<Variant>()(), zeroed_variant);
    }
    default:
      return errors::InvalidArgument(
          "Trying to compute zeros_like for unsupported dtype ",
          DataTypeString(out->dtype()));
  }
  return Status::OK();
}
// Computes `*out = a + b` element-wise.
//
// * If either operand has dtype DT_INVALID, `*out` is set to the other operand
//   and no addition is performed.
// * The operands must otherwise have identical dtypes and shapes (no
//   broadcasting); a mismatch returns an InvalidArgument error.
// * Numeric dtypes (per TF_CALL_NUMBER_TYPES) are added on the Eigen device of
//   type `Device` obtained from `ctx`.
// * DT_VARIANT tensors are allocated with the on-host attribute set and added
//   by dispatching ADD_VARIANT_BINARY_OP through `BinaryOpVariants<Device>`.
// * Any other dtype returns an InvalidArgument error.
template <typename Device>
Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, const Tensor& b,
                        Tensor* out) {
  if (a.dtype() == DT_INVALID) {
    *out = b;
    return Status::OK();
  }
  if (b.dtype() == DT_INVALID) {
    *out = a;
    return Status::OK();
  }
  if (a.dtype() != b.dtype()) {
    return errors::InvalidArgument(
        "Trying to add two tensors with incompatible element types. ",
        "One is ", DataTypeString(a.dtype()), " and the other is ",
        DataTypeString(b.dtype()));
  }
  if (a.shape() != b.shape()) {
    // TODO(apassos) support broadcasting additions here?
    return errors::InvalidArgument(
        "Trying to add two tensors with incompatible element shapes. ",
        "One is ", a.shape().DebugString(), " and the other is ",
        b.shape().DebugString());
  }

  AllocatorAttributes attr;
  if (a.dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(ctx->allocate_temp(a.dtype(), a.shape(), out, attr));

  switch (out->dtype()) {
#define DTYPE_CASE(dtype)                                    \
  case DataTypeToEnum<dtype>::value:                         \
    out->flat<dtype>().device(ctx->eigen_device<Device>()) = \
        a.flat<dtype>() + b.flat<dtype>();                   \
    break;

    TF_CALL_NUMBER_TYPES(DTYPE_CASE)
#undef DTYPE_CASE

    case DataTypeToEnum<Variant>::value: {
      Variant* out_variant = out->scalar<Variant>().data();
      TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
          ctx, ADD_VARIANT_BINARY_OP, a.scalar<Variant>()(),
          b.scalar<Variant>()(), out_variant));
      break;
    }
    default:
      // Report the dtype by name (not its raw enum value), matching the
      // type-mismatch error above.
      return errors::InvalidArgument("Trying to add unsupported dtype ",
                                     DataTypeString(out->dtype()));
  }
  return Status::OK();
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -87,6 +88,90 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
# Checks that `math_ops.add_n` works on the variant tensors of two Optionals:
# two value-carrying Optionals add component-wise, and two "none" Optionals
# produce an Optional without a value. Runs on CPU and, if available, GPU.
def testAddN(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
devices.append("/gpu:0")
for device in devices:
with ops.device(device):
# With value
opt1 = optional_ops.Optional.from_value((1.0, 2.0))
opt2 = optional_ops.Optional.from_value((3.0, 4.0))
add_tensor = math_ops.add_n([opt1._variant_tensor,
opt2._variant_tensor])
# Re-wrap the summed variant tensor so its components can be read back.
add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure)
self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))
# Without value
opt_none1 = optional_ops.Optional.none_from_structure(
opt1.value_structure)
opt_none2 = optional_ops.Optional.none_from_structure(
opt2.value_structure)
add_tensor = math_ops.add_n([opt_none1._variant_tensor,
opt_none2._variant_tensor])
add_opt = optional_ops._OptionalImpl(add_tensor,
opt_none1.value_structure)
self.assertFalse(self.evaluate(add_opt.has_value()))
# Checks `math_ops.add_n` on Optionals whose second component is itself the
# variant tensor of another Optional: the outer float components add to 11.0
# and the nested Optionals add to [4, 6.0]. Runs on CPU and, if available, GPU.
def testNestedAddN(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
devices.append("/gpu:0")
for device in devices:
with ops.device(device):
# Inner Optionals that get nested inside opt3 / opt4 below.
opt1 = optional_ops.Optional.from_value([1, 2.0])
opt2 = optional_ops.Optional.from_value([3, 4.0])
opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))
add_tensor = math_ops.add_n([opt3._variant_tensor,
opt4._variant_tensor])
add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure)
self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)
# Re-wrap the nested variant component to inspect the inner sum.
inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
opt1.value_structure)
# NOTE(review): unlike the other assertions in these tests, this value is
# not passed through self.evaluate() first -- confirm this is intended.
self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
# Checks that `array_ops.zeros_like` works on an Optional's variant tensor:
# a value-carrying Optional yields zeros for every component, and a "none"
# Optional yields an Optional without a value. Runs on CPU and, if available,
# GPU.
def testZerosLike(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
devices.append("/gpu:0")
for device in devices:
with ops.device(device):
# With value
opt = optional_ops.Optional.from_value((1.0, 2.0))
zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
# Re-wrap the result so its components can be read back.
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
opt.value_structure)
self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
(0.0, 0.0))
# Without value
opt_none = optional_ops.Optional.none_from_structure(
opt.value_structure)
zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
opt_none.value_structure)
self.assertFalse(self.evaluate(zeros_opt.has_value()))
# Checks `array_ops.zeros_like` on an Optional whose value is the variant
# tensor of another Optional: the inner Optional's value must come back as 0.0.
# Runs on CPU and, if available, GPU.
def testNestedZerosLike(self):
devices = ["/cpu:0"]
if test_util.is_gpu_available():
devices.append("/gpu:0")
for device in devices:
with ops.device(device):
opt1 = optional_ops.Optional.from_value(1.0)
# opt2 wraps opt1's variant tensor, giving one level of nesting.
opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)
zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
opt2.value_structure)
# Unwrap the outer result, then interpret its value as the inner Optional.
inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
opt1.value_structure)
self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
def testCopyToGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -116,6 +201,41 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
# Checks that an Optional containing other Optionals' variant tensors (one
# with a value, one "none") plus a plain float can be built under "/cpu:0" and
# read back under "/gpu:0" via `array_ops.identity`. Skipped when no GPU is
# available.
def testNestedCopyToGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
with ops.device("/cpu:0"):
optional_with_value = optional_ops.Optional.from_value(
(constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure(
structure.TensorStructure(dtypes.float32, []))
# Outer Optional: (Optional with value, Optional none, float).
nested_optional = optional_ops.Optional.from_value(
(optional_with_value._variant_tensor, optional_none._variant_tensor,
1.0))
with ops.device("/gpu:0"):
gpu_nested_optional = optional_ops._OptionalImpl(
array_ops.identity(nested_optional._variant_tensor),
nested_optional.value_structure)
gpu_nested_optional_has_value = gpu_nested_optional.has_value()
gpu_nested_optional_values = gpu_nested_optional.get_value()
self.assertTrue(self.evaluate(gpu_nested_optional_has_value))
# Re-wrap the nested variant components so they can be inspected.
inner_with_value = optional_ops._OptionalImpl(
gpu_nested_optional_values[0], optional_with_value.value_structure)
inner_none = optional_ops._OptionalImpl(
gpu_nested_optional_values[1], optional_none.value_structure)
self.assertEqual((37.0, b"Foo", 42),
self.evaluate(inner_with_value.get_value()))
self.assertFalse(self.evaluate(inner_none.has_value()))
self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
def _assertElementValueEqual(self, expected, actual):
if isinstance(expected, dict):
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))