Implement some OptionalVariants functionality.
* ZerosLike variant registry function * BinaryAdd variant registry function * Shape variant registry function * Enables copying nested variants between host/devices * Refactors some common code from the corresponding TensorList functions. This should also enable TensorLists containing OptionalVariants (previously some of this code assumed any nested variant was also a TensorList). PiperOrigin-RevId: 223112633
This commit is contained in:
parent
5f111d5a6c
commit
567c0692de
@ -928,6 +928,7 @@ tf_cuda_library(
|
|||||||
"util/stream_executor_util.h",
|
"util/stream_executor_util.h",
|
||||||
"util/strided_slice_op.h",
|
"util/strided_slice_op.h",
|
||||||
"util/tensor_format.h",
|
"util/tensor_format.h",
|
||||||
|
"util/tensor_ops_util.h",
|
||||||
"util/tensor_slice_reader.h",
|
"util/tensor_slice_reader.h",
|
||||||
"util/tensor_slice_reader_cache.h",
|
"util/tensor_slice_reader_cache.h",
|
||||||
"util/tensor_slice_writer.h",
|
"util/tensor_slice_writer.h",
|
||||||
|
@ -621,6 +621,10 @@ tf_kernel_library(
|
|||||||
name = "optional_ops",
|
name = "optional_ops",
|
||||||
srcs = ["optional_ops.cc"],
|
srcs = ["optional_ops.cc"],
|
||||||
hdrs = ["optional_ops.h"],
|
hdrs = ["optional_ops.h"],
|
||||||
|
gpu_srcs = [
|
||||||
|
"optional_ops.cu.cc",
|
||||||
|
"optional_ops.h",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
@ -628,6 +632,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,75 +22,6 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
namespace {
|
||||||
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
|
|
||||||
|
|
||||||
// An `OptionalVariant` can represent either an "actual value" (a tuple of
|
|
||||||
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
|
|
||||||
class OptionalVariant {
|
|
||||||
public:
|
|
||||||
// Create an `OptionalVariant` with no actual value.
|
|
||||||
OptionalVariant() : values_(nullptr) {}
|
|
||||||
|
|
||||||
// Create an `OptionalVariant` with the actual value given by the tuple of
|
|
||||||
// tensors in `values`.
|
|
||||||
explicit OptionalVariant(std::vector<Tensor> values)
|
|
||||||
: values_(new std::vector<Tensor>(std::move(values))) {}
|
|
||||||
|
|
||||||
OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
|
|
||||||
|
|
||||||
// Returns true if `this` represents an actual value.
|
|
||||||
bool has_value() const { return values_ != nullptr; }
|
|
||||||
|
|
||||||
// REQUIRES: `this->has_value()` must be true.
|
|
||||||
const std::vector<Tensor>& get_values() const {
|
|
||||||
CHECK(values_) << "Tried to get values from an empty OptionalVariant";
|
|
||||||
return *values_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implementations of the necessary methods for using `OptionalVariant`
|
|
||||||
// objects in DT_VARIANT tensors.
|
|
||||||
string TypeName() const { return kOptionalVariantTypeName; }
|
|
||||||
void Encode(VariantTensorData* data) const {
|
|
||||||
data->set_metadata(values_ != nullptr);
|
|
||||||
if (values_ != nullptr) {
|
|
||||||
for (const auto& t : *values_) {
|
|
||||||
*(data->add_tensors()) = t;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Decode(const VariantTensorData& data) {
|
|
||||||
if (data.type_name() != TypeName()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool has_value = false;
|
|
||||||
if (!data.get_metadata(&has_value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (has_value) {
|
|
||||||
values_.reset(new std::vector<Tensor>(data.tensors()));
|
|
||||||
} else {
|
|
||||||
values_.reset();
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString() const {
|
|
||||||
if (values_) {
|
|
||||||
return strings::StrCat("OptionalVariant<", "values: (",
|
|
||||||
str_util::Join(*values_, ", ",
|
|
||||||
[](string* s, const Tensor& elem) {
|
|
||||||
*s = elem.DebugString();
|
|
||||||
}),
|
|
||||||
")>");
|
|
||||||
} else {
|
|
||||||
return strings::StrCat("OptionalVariant<None>");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<const std::vector<Tensor>> values_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class OptionalNoneOp : public OpKernel {
|
class OptionalNoneOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@ -143,6 +74,12 @@ class OptionalGetValueOp : public OpKernel {
|
|||||||
explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, output_shapes_.size() == output_types_.size(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"output_types and output_shapes must be same length, got:\n",
|
||||||
|
"output_types: ", output_types_.size(), "\n",
|
||||||
|
"output_shapes: ", output_shapes_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
@ -162,6 +99,10 @@ class OptionalGetValueOp : public OpKernel {
|
|||||||
ctx, optional->has_value(),
|
ctx, optional->has_value(),
|
||||||
errors::InvalidArgument("The given optional does not have a value."));
|
errors::InvalidArgument("The given optional does not have a value."));
|
||||||
const auto& components = optional->get_values();
|
const auto& components = optional->get_values();
|
||||||
|
OP_REQUIRES(ctx, components.size() == output_types_.size(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"The given optional has ", components.size(),
|
||||||
|
" components, expected ", output_types_.size()));
|
||||||
for (int i = 0; i < components.size(); ++i) {
|
for (int i = 0; i < components.size(); ++i) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, components[i].dtype() == output_types_[i],
|
ctx, components[i].dtype() == output_types_[i],
|
||||||
@ -213,15 +154,7 @@ static Status OptionalDeviceCopy(
|
|||||||
std::vector<Tensor> to_values;
|
std::vector<Tensor> to_values;
|
||||||
to_values.reserve(from_values.size());
|
to_values.reserve(from_values.size());
|
||||||
for (const Tensor& t : from_values) {
|
for (const Tensor& t : from_values) {
|
||||||
if (t.dtype() == DT_VARIANT) {
|
if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
|
||||||
// TODO(b/116349787): Implement support for nested variants.
|
|
||||||
return errors::Unimplemented(
|
|
||||||
"Support for copying nested variants to device has not yet been "
|
|
||||||
"implemented.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const Tensor& t : from_values) {
|
|
||||||
if (DMAHelper::CanUseDMA(&t)) {
|
|
||||||
Tensor tmp(t.dtype());
|
Tensor tmp(t.dtype());
|
||||||
TF_RETURN_IF_ERROR(copy(t, &tmp));
|
TF_RETURN_IF_ERROR(copy(t, &tmp));
|
||||||
to_values.push_back(std::move(tmp));
|
to_values.push_back(std::move(tmp));
|
||||||
@ -272,5 +205,20 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
|
DEVICE_CPU, OptionalVariant,
|
||||||
|
OptionalZerosLike<CPUDevice>);
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
|
||||||
|
OptionalVariant,
|
||||||
|
OptionalBinaryAdd<CPUDevice>);
|
||||||
|
|
||||||
|
Status OptionalShape(const OptionalVariant& x, TensorShape* s) {
|
||||||
|
*s = TensorShape({});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(OptionalVariant, OptionalShape);
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
37
tensorflow/core/kernels/data/optional_ops.cu.cc
Normal file
37
tensorflow/core/kernels/data/optional_ops.cu.cc
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/data/optional_ops.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
|
DEVICE_GPU, OptionalVariant,
|
||||||
|
OptionalZerosLike<GPUDevice>);
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
|
||||||
|
OptionalVariant,
|
||||||
|
OptionalBinaryAdd<GPUDevice>);
|
||||||
|
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
@ -19,10 +19,13 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||||
|
#include "tensorflow/core/util/tensor_ops_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
|
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
|
||||||
|
|
||||||
// Stores a DT_VARIANT value representing an Optional with the given value
|
// Stores a DT_VARIANT value representing an Optional with the given value
|
||||||
// in the `output_index`^th output of the given kernel execution context.
|
// in the `output_index`^th output of the given kernel execution context.
|
||||||
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
|
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
|
||||||
@ -32,6 +35,122 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
|
|||||||
// in the `output_index`^th output of the given kernel execution context.
|
// in the `output_index`^th output of the given kernel execution context.
|
||||||
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
|
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
|
||||||
|
|
||||||
|
// An `OptionalVariant` can represent either an "actual value" (a tuple of
|
||||||
|
// tensors) or "none", and may be stored in a DT_VARIANT tensor.
|
||||||
|
class OptionalVariant {
|
||||||
|
public:
|
||||||
|
// Create an `OptionalVariant` with no actual value.
|
||||||
|
OptionalVariant() : values_(nullptr) {}
|
||||||
|
|
||||||
|
// Create an `OptionalVariant` with the actual value given by the tuple of
|
||||||
|
// tensors in `values`.
|
||||||
|
explicit OptionalVariant(std::vector<Tensor> values)
|
||||||
|
: values_(new std::vector<Tensor>(std::move(values))) {}
|
||||||
|
|
||||||
|
OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
|
||||||
|
|
||||||
|
// Returns true if `this` represents an actual value.
|
||||||
|
bool has_value() const { return values_ != nullptr; }
|
||||||
|
|
||||||
|
// REQUIRES: `this->has_value()` must be true.
|
||||||
|
const std::vector<Tensor>& get_values() const {
|
||||||
|
DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
|
||||||
|
return *values_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementations of the necessary methods for using `OptionalVariant`
|
||||||
|
// objects in DT_VARIANT tensors.
|
||||||
|
string TypeName() const { return kOptionalVariantTypeName; }
|
||||||
|
void Encode(VariantTensorData* data) const {
|
||||||
|
data->set_metadata(values_ != nullptr);
|
||||||
|
if (values_ != nullptr) {
|
||||||
|
for (const auto& t : *values_) {
|
||||||
|
*(data->add_tensors()) = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Decode(const VariantTensorData& data) {
|
||||||
|
if (data.type_name() != TypeName()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool has_value = false;
|
||||||
|
if (!data.get_metadata(&has_value)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (has_value) {
|
||||||
|
values_.reset(new std::vector<Tensor>(data.tensors()));
|
||||||
|
} else {
|
||||||
|
values_.reset();
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString() const {
|
||||||
|
if (values_) {
|
||||||
|
return strings::StrCat("OptionalVariant<", "values: (",
|
||||||
|
str_util::Join(*values_, ", ",
|
||||||
|
[](string* s, const Tensor& elem) {
|
||||||
|
*s = elem.DebugString();
|
||||||
|
}),
|
||||||
|
")>");
|
||||||
|
} else {
|
||||||
|
return strings::StrCat("OptionalVariant<None>");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<const std::vector<Tensor>> values_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
|
||||||
|
OptionalVariant* y) {
|
||||||
|
if (!x.has_value()) {
|
||||||
|
*y = x;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
std::vector<Tensor> zero_tensors;
|
||||||
|
for (const Tensor& tensor : x.get_values()) {
|
||||||
|
Tensor zero_t;
|
||||||
|
TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
|
||||||
|
zero_tensors.push_back(std::move(zero_t));
|
||||||
|
}
|
||||||
|
*y = OptionalVariant(zero_tensors);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
|
||||||
|
const OptionalVariant& b, OptionalVariant* out) {
|
||||||
|
// TODO(skyewm): should adding a value to a non-value be a no-op instead?
|
||||||
|
if (a.has_value() != b.has_value()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot add optionals because one has a value and the other doesn't.");
|
||||||
|
}
|
||||||
|
if (!a.has_value()) {
|
||||||
|
*out = a;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
if (a.get_values().size() != b.get_values().size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot add optionals because they have different numbers of "
|
||||||
|
"components (",
|
||||||
|
a.get_values().size(), " vs. ", b.get_values().size(), ").");
|
||||||
|
}
|
||||||
|
std::vector<Tensor> out_tensors;
|
||||||
|
for (int i = 0; i < a.get_values().size(); ++i) {
|
||||||
|
const Tensor& a_tensor = a.get_values()[i];
|
||||||
|
const Tensor& b_tensor = b.get_values()[i];
|
||||||
|
Tensor out_tensor;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor));
|
||||||
|
out_tensors.push_back(std::move(out_tensor));
|
||||||
|
}
|
||||||
|
*out = OptionalVariant(out_tensors);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/concat_lib.h"
|
#include "tensorflow/core/kernels/concat_lib.h"
|
||||||
#include "tensorflow/core/lib/core/coding.h"
|
#include "tensorflow/core/lib/core/coding.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/util/tensor_ops_util.h"
|
||||||
#include "tensorflow/core/util/util.h"
|
#include "tensorflow/core/util/util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -357,51 +358,10 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
|
|||||||
for (int i = 0; i < a.tensors.size(); ++i) {
|
for (int i = 0; i < a.tensors.size(); ++i) {
|
||||||
const Tensor& a_tensor = a.tensors[i];
|
const Tensor& a_tensor = a.tensors[i];
|
||||||
const Tensor& b_tensor = b.tensors[i];
|
const Tensor& b_tensor = b.tensors[i];
|
||||||
if (a_tensor.dtype() == DT_INVALID) {
|
|
||||||
out->tensors.push_back(b_tensor);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (b_tensor.dtype() == DT_INVALID) {
|
|
||||||
out->tensors.push_back(a_tensor);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (a_tensor.shape() != b_tensor.shape()) {
|
|
||||||
// TODO(apassos) support broadcasting additions here?
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Trying to add two tensors with incompatible element shapes. "
|
|
||||||
"One is ",
|
|
||||||
a_tensor.shape().DebugString(), " and the other is ",
|
|
||||||
b_tensor.shape().DebugString(), " in position ", i);
|
|
||||||
}
|
|
||||||
Tensor out_tensor;
|
Tensor out_tensor;
|
||||||
AllocatorAttributes attr;
|
TF_RETURN_IF_ERROR(
|
||||||
if (a_tensor.dtype() == DT_VARIANT) {
|
BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
|
||||||
attr.set_on_host(true);
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(c->allocate_temp(a_tensor.dtype(), a_tensor.shape(),
|
|
||||||
&out_tensor, attr));
|
|
||||||
out->tensors.push_back(out_tensor);
|
out->tensors.push_back(out_tensor);
|
||||||
switch (out_tensor.dtype()) {
|
|
||||||
#define DTYPE_CASE(dtype) \
|
|
||||||
case DataTypeToEnum<dtype>::value: \
|
|
||||||
out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
|
|
||||||
a_tensor.flat<dtype>() + b_tensor.flat<dtype>(); \
|
|
||||||
break;
|
|
||||||
|
|
||||||
TF_CALL_NUMBER_TYPES(DTYPE_CASE)
|
|
||||||
|
|
||||||
#undef DTYPE_CASE
|
|
||||||
case DataTypeToEnum<Variant>::value: {
|
|
||||||
Variant* v_out = &(out_tensor.scalar<Variant>()());
|
|
||||||
TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
|
|
||||||
c, ADD_VARIANT_BINARY_OP, a_tensor.scalar<Variant>()(),
|
|
||||||
b_tensor.scalar<Variant>()(), v_out));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errors::InvalidArgument("Trying to add unsupported dtype ",
|
|
||||||
out_tensor.dtype());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -414,46 +374,7 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
|
|||||||
y->tensors.reserve(x.tensors.size());
|
y->tensors.reserve(x.tensors.size());
|
||||||
for (const Tensor& t : x.tensors) {
|
for (const Tensor& t : x.tensors) {
|
||||||
Tensor out_tensor;
|
Tensor out_tensor;
|
||||||
AllocatorAttributes attr;
|
TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
|
||||||
if (t.dtype() == DT_VARIANT) {
|
|
||||||
attr.set_on_host(true);
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
|
|
||||||
switch (out_tensor.dtype()) {
|
|
||||||
#define DTYPE_CASE(dtype) \
|
|
||||||
case DataTypeToEnum<dtype>::value: \
|
|
||||||
out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
|
|
||||||
out_tensor.flat<dtype>().constant(dtype(0)); \
|
|
||||||
break;
|
|
||||||
|
|
||||||
TF_CALL_POD_TYPES(DTYPE_CASE)
|
|
||||||
|
|
||||||
#undef DTYPE_CASE
|
|
||||||
|
|
||||||
case DT_INVALID: {
|
|
||||||
// Uninitialized tensor in the TensorList.
|
|
||||||
out_tensor = Tensor(DT_INVALID);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case DataTypeToEnum<Variant>::value: {
|
|
||||||
const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
|
|
||||||
if (inner_x == nullptr) {
|
|
||||||
return errors::InvalidArgument("Input handle is not a list. Saw: '",
|
|
||||||
t.scalar<Variant>()().DebugString(),
|
|
||||||
"'");
|
|
||||||
}
|
|
||||||
TensorList inner_y;
|
|
||||||
TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
|
|
||||||
out_tensor.scalar<Variant>()() = std::move(inner_y);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Trying to compute zeros_like for unsupported dtype ",
|
|
||||||
DataTypeString(out_tensor.dtype()));
|
|
||||||
}
|
|
||||||
y->tensors.emplace_back(out_tensor);
|
y->tensors.emplace_back(out_tensor);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
128
tensorflow/core/util/tensor_ops_util.h
Normal file
128
tensorflow/core/util/tensor_ops_util.h
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
|
||||||
|
#define TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) {
|
||||||
|
AllocatorAttributes attr;
|
||||||
|
if (x.dtype() == DT_VARIANT) {
|
||||||
|
attr.set_on_host(true);
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(x.dtype(), x.shape(), out, attr));
|
||||||
|
|
||||||
|
switch (out->dtype()) {
|
||||||
|
#define DTYPE_CASE(dtype) \
|
||||||
|
case DataTypeToEnum<dtype>::value: \
|
||||||
|
/* TODO(skyewm): use SetZeroFunctor like in ZerosLikeOp? */ \
|
||||||
|
out->flat<dtype>().device(ctx->eigen_device<Device>()) = \
|
||||||
|
out->flat<dtype>().constant(dtype(0)); \
|
||||||
|
break;
|
||||||
|
|
||||||
|
TF_CALL_POD_TYPES(DTYPE_CASE)
|
||||||
|
#undef DTYPE_CASE
|
||||||
|
|
||||||
|
case DT_INVALID: {
|
||||||
|
*out = Tensor(DT_INVALID);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataTypeToEnum<Variant>::value: {
|
||||||
|
Variant* out_variant = out->scalar<Variant>().data();
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
UnaryOpVariant<Device>(ctx, ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
|
x.scalar<Variant>()(), out_variant));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Trying to compute zeros_like for unsupported dtype ",
|
||||||
|
DataTypeString(out->dtype()));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, const Tensor& b,
|
||||||
|
Tensor* out) {
|
||||||
|
if (a.dtype() == DT_INVALID) {
|
||||||
|
*out = b;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
if (b.dtype() == DT_INVALID) {
|
||||||
|
*out = a;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
if (a.dtype() != b.dtype()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Trying to add two tensors with incompatible element types. ",
|
||||||
|
"One is ", DataTypeString(a.dtype()), " and the other is ",
|
||||||
|
DataTypeString(b.dtype()));
|
||||||
|
}
|
||||||
|
if (a.shape() != b.shape()) {
|
||||||
|
// TODO(apassos) support broadcasting additions here?
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Trying to add two tensors with incompatible element shapes. ",
|
||||||
|
"One is ", a.shape().DebugString(), " and the other is ",
|
||||||
|
b.shape().DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
AllocatorAttributes attr;
|
||||||
|
if (a.dtype() == DT_VARIANT) {
|
||||||
|
attr.set_on_host(true);
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(a.dtype(), a.shape(), out, attr));
|
||||||
|
|
||||||
|
switch (out->dtype()) {
|
||||||
|
#define DTYPE_CASE(dtype) \
|
||||||
|
case DataTypeToEnum<dtype>::value: \
|
||||||
|
out->flat<dtype>().device(ctx->eigen_device<Device>()) = \
|
||||||
|
a.flat<dtype>() + b.flat<dtype>(); \
|
||||||
|
break;
|
||||||
|
|
||||||
|
TF_CALL_NUMBER_TYPES(DTYPE_CASE)
|
||||||
|
#undef DTYPE_CASE
|
||||||
|
|
||||||
|
case DataTypeToEnum<Variant>::value: {
|
||||||
|
Variant* out_variant = out->scalar<Variant>().data();
|
||||||
|
TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
|
||||||
|
ctx, ADD_VARIANT_BINARY_OP, a.scalar<Variant>()(),
|
||||||
|
b.scalar<Variant>()(), out_variant));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors::InvalidArgument("Trying to add unsupported dtype ",
|
||||||
|
out->dtype());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -87,6 +88,90 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.evaluate(opt.get_value())
|
self.evaluate(opt.get_value())
|
||||||
|
|
||||||
|
def testAddN(self):
|
||||||
|
devices = ["/cpu:0"]
|
||||||
|
if test_util.is_gpu_available():
|
||||||
|
devices.append("/gpu:0")
|
||||||
|
for device in devices:
|
||||||
|
with ops.device(device):
|
||||||
|
# With value
|
||||||
|
opt1 = optional_ops.Optional.from_value((1.0, 2.0))
|
||||||
|
opt2 = optional_ops.Optional.from_value((3.0, 4.0))
|
||||||
|
|
||||||
|
add_tensor = math_ops.add_n([opt1._variant_tensor,
|
||||||
|
opt2._variant_tensor])
|
||||||
|
add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure)
|
||||||
|
self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))
|
||||||
|
|
||||||
|
# Without value
|
||||||
|
opt_none1 = optional_ops.Optional.none_from_structure(
|
||||||
|
opt1.value_structure)
|
||||||
|
opt_none2 = optional_ops.Optional.none_from_structure(
|
||||||
|
opt2.value_structure)
|
||||||
|
add_tensor = math_ops.add_n([opt_none1._variant_tensor,
|
||||||
|
opt_none2._variant_tensor])
|
||||||
|
add_opt = optional_ops._OptionalImpl(add_tensor,
|
||||||
|
opt_none1.value_structure)
|
||||||
|
self.assertFalse(self.evaluate(add_opt.has_value()))
|
||||||
|
|
||||||
|
def testNestedAddN(self):
|
||||||
|
devices = ["/cpu:0"]
|
||||||
|
if test_util.is_gpu_available():
|
||||||
|
devices.append("/gpu:0")
|
||||||
|
for device in devices:
|
||||||
|
with ops.device(device):
|
||||||
|
opt1 = optional_ops.Optional.from_value([1, 2.0])
|
||||||
|
opt2 = optional_ops.Optional.from_value([3, 4.0])
|
||||||
|
opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
|
||||||
|
opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))
|
||||||
|
|
||||||
|
add_tensor = math_ops.add_n([opt3._variant_tensor,
|
||||||
|
opt4._variant_tensor])
|
||||||
|
add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure)
|
||||||
|
self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)
|
||||||
|
|
||||||
|
inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
|
||||||
|
opt1.value_structure)
|
||||||
|
self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
|
||||||
|
|
||||||
|
def testZerosLike(self):
|
||||||
|
devices = ["/cpu:0"]
|
||||||
|
if test_util.is_gpu_available():
|
||||||
|
devices.append("/gpu:0")
|
||||||
|
for device in devices:
|
||||||
|
with ops.device(device):
|
||||||
|
# With value
|
||||||
|
opt = optional_ops.Optional.from_value((1.0, 2.0))
|
||||||
|
zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
|
||||||
|
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
|
||||||
|
opt.value_structure)
|
||||||
|
self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
|
||||||
|
(0.0, 0.0))
|
||||||
|
|
||||||
|
# Without value
|
||||||
|
opt_none = optional_ops.Optional.none_from_structure(
|
||||||
|
opt.value_structure)
|
||||||
|
zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
|
||||||
|
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
|
||||||
|
opt_none.value_structure)
|
||||||
|
self.assertFalse(self.evaluate(zeros_opt.has_value()))
|
||||||
|
|
||||||
|
def testNestedZerosLike(self):
|
||||||
|
devices = ["/cpu:0"]
|
||||||
|
if test_util.is_gpu_available():
|
||||||
|
devices.append("/gpu:0")
|
||||||
|
for device in devices:
|
||||||
|
with ops.device(device):
|
||||||
|
opt1 = optional_ops.Optional.from_value(1.0)
|
||||||
|
opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)
|
||||||
|
|
||||||
|
zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
|
||||||
|
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
|
||||||
|
opt2.value_structure)
|
||||||
|
inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
|
||||||
|
opt1.value_structure)
|
||||||
|
self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
|
||||||
|
|
||||||
def testCopyToGPU(self):
|
def testCopyToGPU(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
@ -116,6 +201,41 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.evaluate(gpu_optional_with_value_values))
|
self.evaluate(gpu_optional_with_value_values))
|
||||||
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
|
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
|
||||||
|
|
||||||
|
def testNestedCopyToGPU(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
with ops.device("/cpu:0"):
|
||||||
|
optional_with_value = optional_ops.Optional.from_value(
|
||||||
|
(constant_op.constant(37.0), constant_op.constant("Foo"),
|
||||||
|
constant_op.constant(42)))
|
||||||
|
optional_none = optional_ops.Optional.none_from_structure(
|
||||||
|
structure.TensorStructure(dtypes.float32, []))
|
||||||
|
nested_optional = optional_ops.Optional.from_value(
|
||||||
|
(optional_with_value._variant_tensor, optional_none._variant_tensor,
|
||||||
|
1.0))
|
||||||
|
|
||||||
|
with ops.device("/gpu:0"):
|
||||||
|
gpu_nested_optional = optional_ops._OptionalImpl(
|
||||||
|
array_ops.identity(nested_optional._variant_tensor),
|
||||||
|
nested_optional.value_structure)
|
||||||
|
|
||||||
|
gpu_nested_optional_has_value = gpu_nested_optional.has_value()
|
||||||
|
gpu_nested_optional_values = gpu_nested_optional.get_value()
|
||||||
|
|
||||||
|
self.assertTrue(self.evaluate(gpu_nested_optional_has_value))
|
||||||
|
|
||||||
|
inner_with_value = optional_ops._OptionalImpl(
|
||||||
|
gpu_nested_optional_values[0], optional_with_value.value_structure)
|
||||||
|
|
||||||
|
inner_none = optional_ops._OptionalImpl(
|
||||||
|
gpu_nested_optional_values[1], optional_none.value_structure)
|
||||||
|
|
||||||
|
self.assertEqual((37.0, b"Foo", 42),
|
||||||
|
self.evaluate(inner_with_value.get_value()))
|
||||||
|
self.assertFalse(self.evaluate(inner_none.has_value()))
|
||||||
|
self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
|
||||||
|
|
||||||
def _assertElementValueEqual(self, expected, actual):
|
def _assertElementValueEqual(self, expected, actual):
|
||||||
if isinstance(expected, dict):
|
if isinstance(expected, dict):
|
||||||
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
|
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
|
||||||
|
Loading…
Reference in New Issue
Block a user