Implement some OptionalVariants functionality.

* ZerosLike variant registry function
* BinaryAdd variant registry function
* Shape variant registry function
* Enables copying nested variants between host/devices
* Refactors some common code from the corresponding TensorList functions. This
  should also enable TensorLists containing OptionalVariants (previously some
  of this code assumed any nested variant was also a TensorList).

PiperOrigin-RevId: 223112633
parent 5f111d5a6c
commit 567c0692de
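The diff below registers OptionalVariant with the variant ZerosLike, BinaryAdd, and Shape registries and routes nested variants through the generic variant copy machinery. As a rough illustration of what this enables at the Python level, here is a minimal sketch that mirrors the tests added below; the tensorflow.python.data.ops import path for optional_ops is an assumption for this revision, and the internal _variant_tensor / _OptionalImpl helpers are used purely for illustration, as the tests do.

# Hypothetical usage sketch (not part of the diff).
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

opt1 = optional_ops.Optional.from_value((1.0, 2.0))
opt2 = optional_ops.Optional.from_value((3.0, 4.0))

# add_n over the backing DT_VARIANT tensors adds the optionals component-wise.
added = math_ops.add_n([opt1._variant_tensor, opt2._variant_tensor])
add_opt = optional_ops._OptionalImpl(added, opt1.value_structure)
# add_opt.get_value() -> (4.0, 6.0)

# zeros_like on the backing tensor yields an optional holding zeros.
zeros = array_ops.zeros_like(opt1._variant_tensor)
zeros_opt = optional_ops._OptionalImpl(zeros, opt1.value_structure)
# zeros_opt.get_value() -> (0.0, 0.0)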
@@ -928,6 +928,7 @@ tf_cuda_library(
        "util/stream_executor_util.h",
        "util/strided_slice_op.h",
        "util/tensor_format.h",
        "util/tensor_ops_util.h",
        "util/tensor_slice_reader.h",
        "util/tensor_slice_reader_cache.h",
        "util/tensor_slice_writer.h",
@@ -621,6 +621,10 @@ tf_kernel_library(
    name = "optional_ops",
    srcs = ["optional_ops.cc"],
    hdrs = ["optional_ops.h"],
    gpu_srcs = [
        "optional_ops.cu.cc",
        "optional_ops.h",
    ],
    deps = [
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
@@ -628,6 +632,7 @@ tf_kernel_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//third_party/eigen3",
    ],
)
@@ -22,75 +22,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";

// An `OptionalVariant` can represent either an "actual value" (a tuple of
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
class OptionalVariant {
 public:
  // Create an `OptionalVariant` with no actual value.
  OptionalVariant() : values_(nullptr) {}

  // Create an `OptionalVariant` with the actual value given by the tuple of
  // tensors in `values`.
  explicit OptionalVariant(std::vector<Tensor> values)
      : values_(new std::vector<Tensor>(std::move(values))) {}

  OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}

  // Returns true if `this` represents an actual value.
  bool has_value() const { return values_ != nullptr; }

  // REQUIRES: `this->has_value()` must be true.
  const std::vector<Tensor>& get_values() const {
    CHECK(values_) << "Tried to get values from an empty OptionalVariant";
    return *values_;
  }

  // Implementations of the necessary methods for using `OptionalVariant`
  // objects in DT_VARIANT tensors.
  string TypeName() const { return kOptionalVariantTypeName; }
  void Encode(VariantTensorData* data) const {
    data->set_metadata(values_ != nullptr);
    if (values_ != nullptr) {
      for (const auto& t : *values_) {
        *(data->add_tensors()) = t;
      }
    }
  }

  bool Decode(const VariantTensorData& data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    bool has_value = false;
    if (!data.get_metadata(&has_value)) {
      return false;
    }
    if (has_value) {
      values_.reset(new std::vector<Tensor>(data.tensors()));
    } else {
      values_.reset();
    }
    return true;
  }

  string DebugString() const {
    if (values_) {
      return strings::StrCat("OptionalVariant<", "values: (",
                             str_util::Join(*values_, ", ",
                                            [](string* s, const Tensor& elem) {
                                              *s = elem.DebugString();
                                            }),
                             ")>");
    } else {
      return strings::StrCat("OptionalVariant<None>");
    }
  }

 private:
  std::shared_ptr<const std::vector<Tensor>> values_;
};

class OptionalNoneOp : public OpKernel {
 public:
@@ -143,6 +74,12 @@ class OptionalGetValueOp : public OpKernel {
  explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES(
        ctx, output_shapes_.size() == output_types_.size(),
        errors::InvalidArgument(
            "output_types and output_shapes must be same length, got:\n",
            "output_types: ", output_types_.size(), "\n",
            "output_shapes: ", output_shapes_.size()));
  }

  void Compute(OpKernelContext* ctx) override {
@@ -162,6 +99,10 @@ class OptionalGetValueOp : public OpKernel {
        ctx, optional->has_value(),
        errors::InvalidArgument("The given optional does not have a value."));
    const auto& components = optional->get_values();
    OP_REQUIRES(ctx, components.size() == output_types_.size(),
                errors::InvalidArgument(
                    "The given optional has ", components.size(),
                    " components, expected ", output_types_.size()));
    for (int i = 0; i < components.size(); ++i) {
      OP_REQUIRES(
          ctx, components[i].dtype() == output_types_[i],
@@ -213,15 +154,7 @@ static Status OptionalDeviceCopy(
  std::vector<Tensor> to_values;
  to_values.reserve(from_values.size());
  for (const Tensor& t : from_values) {
    if (t.dtype() == DT_VARIANT) {
      // TODO(b/116349787): Implement support for nested variants.
      return errors::Unimplemented(
          "Support for copying nested variants to device has not yet been "
          "implemented.");
    }
  }
  for (const Tensor& t : from_values) {
    if (DMAHelper::CanUseDMA(&t)) {
    if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
      Tensor tmp(t.dtype());
      TF_RETURN_IF_ERROR(copy(t, &tmp));
      to_values.push_back(std::move(tmp));
@@ -272,5 +205,20 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
  return Status::OK();
}

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, OptionalVariant,
                                         OptionalZerosLike<CPUDevice>);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          OptionalVariant,
                                          OptionalBinaryAdd<CPUDevice>);

Status OptionalShape(const OptionalVariant& x, TensorShape* s) {
  *s = TensorShape({});
  return Status::OK();
}

REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(OptionalVariant, OptionalShape);

}  // namespace data
}  // namespace tensorflow
tensorflow/core/kernels/data/optional_ops.cu.cc (new file, 37 lines)
@@ -0,0 +1,37 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/data/optional_ops.h"

#include "tensorflow/core/framework/variant_op_registry.h"

namespace tensorflow {
namespace data {

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_GPU, OptionalVariant,
                                         OptionalZerosLike<GPUDevice>);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
                                          OptionalVariant,
                                          OptionalBinaryAdd<GPUDevice>);

}  // namespace data
}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@@ -19,10 +19,13 @@ limitations under the License.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace data {

const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";

// Stores a DT_VARIANT value representing an Optional with the given value
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
@@ -32,6 +35,122 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);

// An `OptionalVariant` can represent either an "actual value" (a tuple of
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
class OptionalVariant {
 public:
  // Create an `OptionalVariant` with no actual value.
  OptionalVariant() : values_(nullptr) {}

  // Create an `OptionalVariant` with the actual value given by the tuple of
  // tensors in `values`.
  explicit OptionalVariant(std::vector<Tensor> values)
      : values_(new std::vector<Tensor>(std::move(values))) {}

  OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}

  // Returns true if `this` represents an actual value.
  bool has_value() const { return values_ != nullptr; }

  // REQUIRES: `this->has_value()` must be true.
  const std::vector<Tensor>& get_values() const {
    DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
    return *values_;
  }

  // Implementations of the necessary methods for using `OptionalVariant`
  // objects in DT_VARIANT tensors.
  string TypeName() const { return kOptionalVariantTypeName; }
  void Encode(VariantTensorData* data) const {
    data->set_metadata(values_ != nullptr);
    if (values_ != nullptr) {
      for (const auto& t : *values_) {
        *(data->add_tensors()) = t;
      }
    }
  }

  bool Decode(const VariantTensorData& data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    bool has_value = false;
    if (!data.get_metadata(&has_value)) {
      return false;
    }
    if (has_value) {
      values_.reset(new std::vector<Tensor>(data.tensors()));
    } else {
      values_.reset();
    }
    return true;
  }

  string DebugString() const {
    if (values_) {
      return strings::StrCat("OptionalVariant<", "values: (",
                             str_util::Join(*values_, ", ",
                                            [](string* s, const Tensor& elem) {
                                              *s = elem.DebugString();
                                            }),
                             ")>");
    } else {
      return strings::StrCat("OptionalVariant<None>");
    }
  }

 private:
  std::shared_ptr<const std::vector<Tensor>> values_;
};

template <typename Device>
Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
                         OptionalVariant* y) {
  if (!x.has_value()) {
    *y = x;
    return Status::OK();
  }
  std::vector<Tensor> zero_tensors;
  for (const Tensor& tensor : x.get_values()) {
    Tensor zero_t;
    TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
    zero_tensors.push_back(std::move(zero_t));
  }
  *y = OptionalVariant(zero_tensors);
  return Status::OK();
}

template <typename Device>
Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
                         const OptionalVariant& b, OptionalVariant* out) {
  // TODO(skyewm): should adding a value to a non-value be a no-op instead?
  if (a.has_value() != b.has_value()) {
    return errors::InvalidArgument(
        "Cannot add optionals because one has a value and the other doesn't.");
  }
  if (!a.has_value()) {
    *out = a;
    return Status::OK();
  }
  if (a.get_values().size() != b.get_values().size()) {
    return errors::InvalidArgument(
        "Cannot add optionals because they have different numbers of "
        "components (",
        a.get_values().size(), " vs. ", b.get_values().size(), ").");
  }
  std::vector<Tensor> out_tensors;
  for (int i = 0; i < a.get_values().size(); ++i) {
    const Tensor& a_tensor = a.get_values()[i];
    const Tensor& b_tensor = b.get_values()[i];
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(
        BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor));
    out_tensors.push_back(std::move(out_tensor));
  }
  *out = OptionalVariant(out_tensors);
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {
@@ -357,51 +358,10 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
  for (int i = 0; i < a.tensors.size(); ++i) {
    const Tensor& a_tensor = a.tensors[i];
    const Tensor& b_tensor = b.tensors[i];
    if (a_tensor.dtype() == DT_INVALID) {
      out->tensors.push_back(b_tensor);
      continue;
    }
    if (b_tensor.dtype() == DT_INVALID) {
      out->tensors.push_back(a_tensor);
      continue;
    }
    if (a_tensor.shape() != b_tensor.shape()) {
      // TODO(apassos) support broadcasting additions here?
      return errors::InvalidArgument(
          "Trying to add two tensors with incompatible element shapes. "
          "One is ",
          a_tensor.shape().DebugString(), " and the other is ",
          b_tensor.shape().DebugString(), " in position ", i);
    }
    Tensor out_tensor;
    AllocatorAttributes attr;
    if (a_tensor.dtype() == DT_VARIANT) {
      attr.set_on_host(true);
    }
    TF_RETURN_IF_ERROR(c->allocate_temp(a_tensor.dtype(), a_tensor.shape(),
                                        &out_tensor, attr));
    TF_RETURN_IF_ERROR(
        BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
    out->tensors.push_back(out_tensor);
    switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype)                                        \
  case DataTypeToEnum<dtype>::value:                             \
    out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
        a_tensor.flat<dtype>() + b_tensor.flat<dtype>();         \
    break;

      TF_CALL_NUMBER_TYPES(DTYPE_CASE)

#undef DTYPE_CASE
      case DataTypeToEnum<Variant>::value: {
        Variant* v_out = &(out_tensor.scalar<Variant>()());
        TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
            c, ADD_VARIANT_BINARY_OP, a_tensor.scalar<Variant>()(),
            b_tensor.scalar<Variant>()(), v_out));
        break;
      }
      default:
        return errors::InvalidArgument("Trying to add unsupported dtype ",
                                       out_tensor.dtype());
    }
  }
  return Status::OK();
}
@@ -414,46 +374,7 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
  y->tensors.reserve(x.tensors.size());
  for (const Tensor& t : x.tensors) {
    Tensor out_tensor;
    AllocatorAttributes attr;
    if (t.dtype() == DT_VARIANT) {
      attr.set_on_host(true);
    }
    TF_RETURN_IF_ERROR(
        c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
    switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype)                                         \
  case DataTypeToEnum<dtype>::value:                              \
    out_tensor.flat<dtype>().device(c->eigen_device<Device>()) =  \
        out_tensor.flat<dtype>().constant(dtype(0));              \
    break;

      TF_CALL_POD_TYPES(DTYPE_CASE)

#undef DTYPE_CASE

      case DT_INVALID: {
        // Uninitialized tensor in the TensorList.
        out_tensor = Tensor(DT_INVALID);
        break;
      }
      case DataTypeToEnum<Variant>::value: {
        const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
        if (inner_x == nullptr) {
          return errors::InvalidArgument("Input handle is not a list. Saw: '",
                                         t.scalar<Variant>()().DebugString(),
                                         "'");
        }
        TensorList inner_y;
        TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
        out_tensor.scalar<Variant>()() = std::move(inner_y);
        break;
      }

      default:
        return errors::InvalidArgument(
            "Trying to compute zeros_like for unsupported dtype ",
            DataTypeString(out_tensor.dtype()));
    }
    TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
    y->tensors.emplace_back(out_tensor);
  }
  return Status::OK();
tensorflow/core/util/tensor_ops_util.h (new file, 128 lines)
@@ -0,0 +1,128 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
#define TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device>
Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) {
  AllocatorAttributes attr;
  if (x.dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(ctx->allocate_temp(x.dtype(), x.shape(), out, attr));

  switch (out->dtype()) {
#define DTYPE_CASE(dtype)                                       \
  case DataTypeToEnum<dtype>::value:                            \
    /* TODO(skyewm): use SetZeroFunctor like in ZerosLikeOp? */ \
    out->flat<dtype>().device(ctx->eigen_device<Device>()) =    \
        out->flat<dtype>().constant(dtype(0));                  \
    break;

    TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE

    case DT_INVALID: {
      *out = Tensor(DT_INVALID);
      break;
    }
    case DataTypeToEnum<Variant>::value: {
      Variant* out_variant = out->scalar<Variant>().data();
      TF_RETURN_IF_ERROR(
          UnaryOpVariant<Device>(ctx, ZEROS_LIKE_VARIANT_UNARY_OP,
                                 x.scalar<Variant>()(), out_variant));
      break;
    }
    default:
      return errors::InvalidArgument(
          "Trying to compute zeros_like for unsupported dtype ",
          DataTypeString(out->dtype()));
  }
  return Status::OK();
}

template <typename Device>
Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, const Tensor& b,
                        Tensor* out) {
  if (a.dtype() == DT_INVALID) {
    *out = b;
    return Status::OK();
  }
  if (b.dtype() == DT_INVALID) {
    *out = a;
    return Status::OK();
  }
  if (a.dtype() != b.dtype()) {
    return errors::InvalidArgument(
        "Trying to add two tensors with incompatible element types. ",
        "One is ", DataTypeString(a.dtype()), " and the other is ",
        DataTypeString(b.dtype()));
  }
  if (a.shape() != b.shape()) {
    // TODO(apassos) support broadcasting additions here?
    return errors::InvalidArgument(
        "Trying to add two tensors with incompatible element shapes. ",
        "One is ", a.shape().DebugString(), " and the other is ",
        b.shape().DebugString());
  }

  AllocatorAttributes attr;
  if (a.dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(ctx->allocate_temp(a.dtype(), a.shape(), out, attr));

  switch (out->dtype()) {
#define DTYPE_CASE(dtype)                                    \
  case DataTypeToEnum<dtype>::value:                         \
    out->flat<dtype>().device(ctx->eigen_device<Device>()) = \
        a.flat<dtype>() + b.flat<dtype>();                   \
    break;

    TF_CALL_NUMBER_TYPES(DTYPE_CASE)
#undef DTYPE_CASE

    case DataTypeToEnum<Variant>::value: {
      Variant* out_variant = out->scalar<Variant>().data();
      TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
          ctx, ADD_VARIANT_BINARY_OP, a.scalar<Variant>()(),
          b.scalar<Variant>()(), out_variant));
      break;
    }
    default:
      return errors::InvalidArgument("Trying to add unsupported dtype ",
                                     out->dtype());
  }
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
@@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

@@ -87,6 +88,90 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  def testAddN(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        # With value
        opt1 = optional_ops.Optional.from_value((1.0, 2.0))
        opt2 = optional_ops.Optional.from_value((3.0, 4.0))

        add_tensor = math_ops.add_n([opt1._variant_tensor,
                                     opt2._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure)
        self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))

        # Without value
        opt_none1 = optional_ops.Optional.none_from_structure(
            opt1.value_structure)
        opt_none2 = optional_ops.Optional.none_from_structure(
            opt2.value_structure)
        add_tensor = math_ops.add_n([opt_none1._variant_tensor,
                                     opt_none2._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor,
                                             opt_none1.value_structure)
        self.assertFalse(self.evaluate(add_opt.has_value()))

  def testNestedAddN(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        opt1 = optional_ops.Optional.from_value([1, 2.0])
        opt2 = optional_ops.Optional.from_value([3, 4.0])
        opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
        opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))

        add_tensor = math_ops.add_n([opt3._variant_tensor,
                                     opt4._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure)
        self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)

        inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
                                                   opt1.value_structure)
        self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])

  def testZerosLike(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        # With value
        opt = optional_ops.Optional.from_value((1.0, 2.0))
        zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt.value_structure)
        self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
                            (0.0, 0.0))

        # Without value
        opt_none = optional_ops.Optional.none_from_structure(
            opt.value_structure)
        zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt_none.value_structure)
        self.assertFalse(self.evaluate(zeros_opt.has_value()))

  def testNestedZerosLike(self):
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        opt1 = optional_ops.Optional.from_value(1.0)
        opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)

        zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt2.value_structure)
        inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
                                                     opt1.value_structure)
        self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)

  def testCopyToGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
@@ -116,6 +201,41 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  def testNestedCopyToGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          structure.TensorStructure(dtypes.float32, []))
      nested_optional = optional_ops.Optional.from_value(
          (optional_with_value._variant_tensor, optional_none._variant_tensor,
           1.0))

    with ops.device("/gpu:0"):
      gpu_nested_optional = optional_ops._OptionalImpl(
          array_ops.identity(nested_optional._variant_tensor),
          nested_optional.value_structure)

      gpu_nested_optional_has_value = gpu_nested_optional.has_value()
      gpu_nested_optional_values = gpu_nested_optional.get_value()

    self.assertTrue(self.evaluate(gpu_nested_optional_has_value))

    inner_with_value = optional_ops._OptionalImpl(
        gpu_nested_optional_values[0], optional_with_value.value_structure)

    inner_none = optional_ops._OptionalImpl(
        gpu_nested_optional_values[1], optional_none.value_structure)

    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(inner_with_value.get_value()))
    self.assertFalse(self.evaluate(inner_none.has_value()))
    self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))

  def _assertElementValueEqual(self, expected, actual):
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))