Added gradients for RaggedTensorToVariant and RaggedTensorFromVariant. (This allows gradients to pass through map_fn when it is applied to ragged tensors.)

PiperOrigin-RevId: 337108621
Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
This commit is contained in:
Edward Loper 2020-10-14 09:41:17 -07:00 committed by TensorFlower Gardener
parent be99659bf6
commit be6b1fdb06
18 changed files with 820 additions and 593 deletions

View File

@ -138,6 +138,7 @@
stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory
usage of the device.
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and

View File

@ -0,0 +1,38 @@
op {
graph_op_name: "RaggedTensorToVariantGradient"
visibility: HIDDEN
in_arg {
name: "encoded_ragged_grad"
description: <<END
A `variant` Tensor containing encoded `RaggedTensor` gradients.
END
}
in_arg {
name: "row_splits"
description: <<END
Outermost row-splits that were used as input to the RaggedTensorToVariant op.
END
}
in_arg {
name: "dense_values_shape"
description: <<END
Shape of the dense_values that was used as an input to the
RaggedTensorToVariant op.
END
}
out_arg {
name: "dense_values_grad"
description: <<END
Gradient for the dense_values of the RaggedTensorToVariant op.
END
}
summary: <<END
Helper used to compute the gradient for `RaggedTensorToVariant`.
END
description: <<END
Computes the gradient for the dense_values input to the RaggedTensorToVariant
op, given the variant-encoded ragged gradients of the outputs, along with
the outer row-splits and the shape of the dense-values that were provided as
inputs to the RaggedTensorToVariant op.
END
}

View File

@ -1419,10 +1419,22 @@ tf_cc_test(
],
)
cc_library(
name = "ragged_tensor_variant",
srcs = ["ragged_tensor_variant.cc"],
hdrs = ["ragged_tensor_variant.h"],
deps = [
":cwise_op",
"//tensorflow/core:framework",
],
)
tf_kernel_library(
name = "ragged_tensor_to_variant_op",
srcs = ["ragged_tensor_to_variant_op.cc"],
deps = [
":concat_lib",
":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
@ -1432,6 +1444,7 @@ tf_kernel_library(
name = "ragged_tensor_from_variant_op",
srcs = ["ragged_tensor_from_variant_op.cc"],
deps = [
":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
@ -1444,6 +1457,7 @@ tf_cc_test(
deps = [
":ops_testutil",
":ragged_tensor_to_variant_op",
":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1460,6 +1474,7 @@ tf_cc_test(
deps = [
":ops_testutil",
":ragged_tensor_from_variant_op",
":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",

View File

@ -424,6 +424,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:ragged_tensor_variant",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/kernels/data:parallel_map_dataset_op",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -678,12 +679,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
for (int d = 0; d < dataset()->ragged_keys_.size(); ++d) {
int output_index =
dataset()->key_to_output_index_.at(dataset()->ragged_keys_[d]);
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
Tensor serialized_ragged =
Tensor(ctx->allocator({}), DT_VARIANT, {2});
auto serialized_ragged_t = serialized_ragged.vec<Variant>();
serialized_ragged_t(0) = example_result.ragged_splits[d];
serialized_ragged_t(1) = example_result.ragged_values[d];
RaggedTensorVariant serialized_ragged;
serialized_ragged.append_splits(example_result.ragged_splits[d]);
serialized_ragged.set_values(example_result.ragged_values[d]);
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
Tensor& ragged_wrapper = (*output)[output_index];
ragged_wrapper.scalar<Variant>()() = serialized_ragged;

View File

@ -20,110 +20,76 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace {
struct RaggedTensor {
Tensor values;
std::vector<Tensor> nested_splits;
};
Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
int ragged_rank, DataType value_dtype,
DataType split_dtype,
std::vector<RaggedTensor>* decoded_ragged) {
Status RaggedComponentsFromVariant(
const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
const auto& flat_variants = encoded_variant.flat<Variant>();
decoded_ragged->resize(flat_variants.size());
// Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
// input.
decoded_ragged->reserve(flat_variants.size());
for (int i = 0; i < flat_variants.size(); i++) {
const auto& flat_variant = flat_variants(i);
const Tensor* encoded_list = flat_variant.get<Tensor>();
if (encoded_list == nullptr) {
const RaggedTensorVariant* decoded =
flat_variant.get<RaggedTensorVariant>();
if (decoded == nullptr) {
return errors::InvalidArgument(
"Input Variant element at index ", i,
" doesn't hold a Tensor: ", flat_variant.DebugString());
" doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
}
if (encoded_list->dims() != 1) {
decoded_ragged->push_back(*decoded);
decoded = &decoded_ragged->back();
// Check ragged rank & types
if (decoded->ragged_rank() != ragged_rank) {
return errors::InvalidArgument(
"Encoded input Variant must have rank 1, but found rank: ",
encoded_list->dims(),
". encoded input Variant: ", encoded_list->DebugString());
"Encoded input RaggedTensorVariant has ragged_rank=",
decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, ".");
}
if (encoded_list->NumElements() != (ragged_rank + 1) &&
encoded_list->NumElements() != 1) {
return errors::InvalidArgument(
"Encoded input Variant must hold either input_ragged_rank + 1 "
"Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
"input_ragged_rank: ",
ragged_rank,
", encoded input Variant: ", encoded_list->DebugString());
}
const auto& input_vec = encoded_list->vec<Variant>();
// Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
// to create the component RaggedTensors.
(*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
for (int j = 0; j < ragged_rank; j++) {
const Tensor* split_tensor = input_vec(j).get<Tensor>();
if (split_tensor == nullptr) {
return errors::InvalidArgument(
"Encoded scalar element at index ", i,
" doesn't have a splits Tensor at split_index ", j, ": ",
input_vec(j).DebugString());
}
Tensor splits_tensor = *split_tensor;
if (splits_tensor.dtype() != split_dtype) {
return errors::InvalidArgument(
"Expected splits Tensor dtype: ", split_dtype,
", found: ", splits_tensor.dtype());
}
if (splits_tensor.dims() != 1) {
return errors::InvalidArgument(
"Ragged splits must have rank 1; encoded scalar element at index ",
i, " has splits Tensor at split_index ", j, ": ",
splits_tensor.DebugString());
}
(*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
}
const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
if (values_tensor == nullptr) {
return errors::InvalidArgument("Encoded scalar element at index ", i,
" doesn't have a values Tensor: ",
input_vec(ragged_rank).DebugString());
}
if (values_tensor->dtype() != value_dtype) {
if (decoded->values().dtype() != value_dtype) {
return errors::InvalidArgument(
"Expected values Tensor dtype: ", DataTypeString(value_dtype),
", found: ", DataTypeString(values_tensor->dtype()));
", found: ", DataTypeString(decoded->values().dtype()));
}
if (values_tensor->dims() < 1) {
if (decoded->values().dims() < 1) {
return errors::InvalidArgument(
"Ragged values must have rank >= 1; encoded scalar element at index ",
i, " has values Tensor: ", values_tensor->DebugString());
i, " has values Tensor: ", decoded->values().DebugString());
}
for (const auto& splits : decoded->nested_splits()) {
if (splits.dtype() != split_dtype) {
return errors::InvalidArgument(
"Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
", found: ", DataTypeString(splits.dtype()));
}
if (splits.dims() != 1) {
return errors::InvalidArgument(
"Ragged splits must have rank 1; encoded scalar element at index ",
i, " has splits Tensor ", splits.DebugString());
}
}
(*decoded_ragged)[i].values = *values_tensor;
}
return Status::OK();
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status NestedStackRaggedTensors(
const std::vector<RaggedTensor>& ragged_components,
const std::vector<RaggedTensorVariant>& ragged_components,
const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
const int output_ragged_rank, RaggedTensor* output_ragged) {
output_ragged->nested_splits.reserve(output_ragged_rank);
const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
const int dims = nested_dim_sizes.size();
// Populate first `dims - 1` splits.
for (int i = 0; i < dims - 1; i++) {
int dims_splits_size = nested_dim_sizes[i] + 1;
output_ragged->nested_splits.push_back(Tensor(
DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
TensorShape({dims_splits_size})));
auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
int split_diff = nested_dim_sizes[i + 1];
for (int j = 0; j < dims_splits_size; j++) {
splits_vec(j) = j * split_diff;
@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
// Populate `dims`-th split.
int splits_size = ragged_components.size() + 1;
output_ragged->nested_splits.push_back(
output_ragged->append_splits(
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
auto dims_splits_vec =
output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
dims_splits_vec(0) = 0;
for (int i = 0; i < ragged_components.size(); i++) {
int split_val = ragged_components[i].values.shape().dim_size(0);
if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
int split_val = ragged_components[i].values().shape().dim_size(0);
if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
split_val = ragged_components[i].splits(0).NumElements() - 1;
}
dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
}
@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
int split_index = dims + i;
int split_size = 1;
for (int j = 0; j < ragged_components.size(); j++) {
if (!ragged_components[j].nested_splits.empty()) {
split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
if (!ragged_components[j].nested_splits().empty()) {
split_size += ragged_components[j].splits(i).NumElements() - 1;
}
}
output_ragged->nested_splits.push_back(
output_ragged->append_splits(
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
auto splits_vec =
output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
splits_vec(0) = 0;
SPLIT_TYPE last_split_value = 0;
int index = 1;
for (int j = 0; j < ragged_components.size(); j++) {
if (ragged_components[j].nested_splits.empty()) {
if (ragged_components[j].nested_splits().empty()) {
// Corner case: empty row. e.g [ [[x], [x]], [] ]
continue;
}
auto component_splits_vec =
ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
ragged_components[j].splits(i).vec<SPLIT_TYPE>();
for (int k = 1; k < component_splits_vec.size(); k++, index++) {
splits_vec(index) = component_splits_vec(k) + last_split_value;
}
@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
if (ragged_components.empty()) {
component_values_shape = TensorShape({0});
} else {
component_values_shape = ragged_components[0].values.shape();
component_values_shape = ragged_components[0].values().shape();
}
// Populate values.
int values_size = component_values_shape.dim_size(0);
for (int i = 1; i < ragged_components.size(); i++) {
if (ragged_components[i].values.dims() != component_values_shape.dims()) {
if (ragged_components[i].values().dims() != component_values_shape.dims()) {
return errors::InvalidArgument(
"Rank of values must match for all "
"components; values shape at index 0: ",
component_values_shape.DebugString(), ", values shape at index ", i,
": ", ragged_components[i].values.shape().DebugString());
": ", ragged_components[i].values().shape().DebugString());
}
values_size += ragged_components[i].values.shape().dim_size(0);
values_size += ragged_components[i].values().shape().dim_size(0);
}
component_values_shape.set_dim(0, values_size);
output_ragged->values =
Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
output_ragged->set_values(
Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
auto output_values_flat =
output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
int values_index = 0;
for (int i = 0; i < ragged_components.size(); i++) {
auto component_values_flat =
ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
int num_inner_elements = ragged_components[i].values.NumElements();
if (ragged_components[i].values.dim_size(0) > 0) {
num_inner_elements /= ragged_components[i].values.dim_size(0);
ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
int num_inner_elements = ragged_components[i].values().NumElements();
if (ragged_components[i].values().dim_size(0) > 0) {
num_inner_elements /= ragged_components[i].values().dim_size(0);
}
for (int j = 0; j < ragged_components[i].values.dim_size(0);
for (int j = 0; j < ragged_components[i].values().dim_size(0);
j++, values_index++) {
for (int k = 0; k < num_inner_elements; k++) {
output_values_flat(values_index, k) = component_values_flat(j, k);
@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
// Decode all variants.
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
std::vector<RaggedTensor> decoded_components;
std::vector<RaggedTensorVariant> decoded_components;
OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
encoded_variant, input_ragged_rank_,
value_dtype, split_dtype, &decoded_components));
@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
for (int i = 0; i < encoded_variant.dims(); i++) {
encoded_dim_sizes[i] = encoded_variant.dim_size(i);
}
RaggedTensor output_ragged;
RaggedTensorVariant output_ragged;
OP_REQUIRES_OK(
context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
decoded_components, encoded_dim_sizes, input_ragged_rank_,
@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
int output_ragged_rank_;
void ReturnRaggedTensor(OpKernelContext* context,
RaggedTensor ragged_tensor) {
int ragged_rank = ragged_tensor.nested_splits.size();
const RaggedTensorVariant& ragged_tensor) {
int ragged_rank = ragged_tensor.ragged_rank();
OpOutputList splits_out;
OP_REQUIRES_OK(context,
context->output_list("output_nested_splits", &splits_out));
for (int i = 0; i < ragged_rank; i++) {
splits_out.set(i, ragged_tensor.nested_splits[i]);
splits_out.set(i, ragged_tensor.splits(i));
}
context->set_output(ragged_rank, ragged_tensor.values);
context->set_output(ragged_rank, ragged_tensor.values());
}
};

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -55,28 +56,22 @@ class RaggedTensorFromVariantKernelTest : public ::tensorflow::OpsTestBase {
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Tensor CreateVariantFromRagged(
RaggedTensorVariant CreateVariantFromRagged(
const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
const TensorShape& ragged_values_shape,
const std::vector<VALUE_TYPE>& ragged_values) {
// Step 1: Create Tensors out of ragged splits and values.
std::vector<Variant> ragged_components;
RaggedTensorVariant encoded;
for (auto ragged_split : ragged_splits) {
int splits_size = ragged_split.size();
Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(),
TensorShape({splits_size}));
test::FillValues<SPLIT_TYPE>(&splits, ragged_split);
ragged_components.push_back(splits);
encoded.append_splits(splits);
}
Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape);
test::FillValues<VALUE_TYPE>(&values, ragged_values);
ragged_components.push_back(values);
// Step 2: Encode into a 1-D Variant Tensor.
int num_splits = ragged_splits.size();
Tensor encoded_list(DT_VARIANT, TensorShape({num_splits + 1}));
test::FillValues<Variant>(&encoded_list, ragged_components);
return encoded_list;
encoded.set_values(values);
return encoded;
}
};
@ -85,7 +80,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, ScalarInput) {
const std::vector<int64> split_2 = {0, 1, 2, 5, 6, 7};
const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4};
Tensor encoded_variant = CreateVariantFromRagged<int, int64>(
auto encoded_variant = CreateVariantFromRagged<int, int64>(
{split_1, split_2}, TensorShape({7}), values);
Tensor expected_splits_1(DT_INT64, TensorShape({6}));
Tensor expected_splits_2(DT_INT64, TensorShape({6}));
@ -113,7 +108,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, OneInputElement) {
const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4};
const std::vector<int64> batched_splits_1 = {0, 5};
Tensor encoded_variant = CreateVariantFromRagged<int, int64>(
auto encoded_variant = CreateVariantFromRagged<int, int64>(
{split_1, split_2}, TensorShape({7}), values);
Tensor expected_splits_1(DT_INT64, TensorShape({2}));
Tensor expected_splits_2(DT_INT64, TensorShape({6}));
@ -157,13 +152,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, TensorIn2DOut) {
const std::vector<int64> batched_splits_2 = {0, 3, 3, 5, 6};
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6};
Tensor component_variant_1 =
auto component_variant_1 =
CreateVariantFromRagged<int, int64>({}, TensorShape({3}), values_1);
Tensor component_variant_2 =
auto component_variant_2 =
CreateVariantFromRagged<int, int64>({}, TensorShape({0}), values_2);
Tensor component_variant_3 =
auto component_variant_3 =
CreateVariantFromRagged<int, int64>({}, TensorShape({2}), values_3);
Tensor component_variant_4 =
auto component_variant_4 =
CreateVariantFromRagged<int, int64>({}, TensorShape({1}), values_4);
Tensor expected_splits_1(DT_INT64, TensorShape({3}));
@ -223,15 +218,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOut) {
test::FillValues<int64>(&expected_splits_3, batched_splits_3);
test::FillValues<int>(&expected_values, batched_values);
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1}), component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{component_split_2_1}, TensorShape({2}), component_values_2);
Tensor variant_component_3 = CreateVariantFromRagged<int, int64>(
auto variant_component_3 = CreateVariantFromRagged<int, int64>(
{component_split_3_1}, TensorShape({2}), component_values_3);
Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
auto variant_component_4 = CreateVariantFromRagged<int, int64>(
{component_split_4_1}, TensorShape({3}), component_values_4);
Tensor variant_component_5 = CreateVariantFromRagged<int, int64>(
auto variant_component_5 = CreateVariantFromRagged<int, int64>(
{component_split_5_1}, TensorShape({3}), component_values_5);
int input_ragged_rank = 1;
int output_ragged_rank = 3;
@ -297,10 +292,10 @@ TEST_F(RaggedTensorFromVariantKernelTest,
test::FillValues<int64>(&expected_splits_4, batched_splits_4);
test::FillValues<int>(&expected_values, batched_values);
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1, component_split_1_2}, TensorShape({11}),
component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{component_split_2_1, component_split_2_2}, TensorShape({11}),
component_values_2);
int input_ragged_rank = -1;
@ -336,9 +331,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, EmptyRow1DIn2DOut) {
test::FillValues<int64>(&expected_splits_2, batched_splits_2);
test::FillValues<int>(&expected_values, batched_values);
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({3}), component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{component_split_2_1}, TensorShape({0}), {}); // Empty row.
int input_ragged_rank = 1;
int output_ragged_rank = 2;
@ -371,9 +366,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, NDValues1DIn2DOut) {
test::FillValues<int64>(&expected_splits_2, batched_splits_2);
test::FillValues<int>(&expected_values, batched_values);
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1, 2}), component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{component_split_2_1}, TensorShape({2, 2}), component_values_2);
int input_ragged_rank = 1;
int output_ragged_rank = 2;
@ -423,15 +418,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) {
test::FillValues<int>(&expected_splits_3, batched_splits_3);
test::FillValues<int>(&expected_values, batched_values);
Tensor variant_component_1 = CreateVariantFromRagged<int, int>(
auto variant_component_1 = CreateVariantFromRagged<int, int>(
{component_split_1_1}, TensorShape({1}), component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int>(
auto variant_component_2 = CreateVariantFromRagged<int, int>(
{component_split_2_1}, TensorShape({2}), component_values_2);
Tensor variant_component_3 = CreateVariantFromRagged<int, int>(
auto variant_component_3 = CreateVariantFromRagged<int, int>(
{component_split_3_1}, TensorShape({2}), component_values_3);
Tensor variant_component_4 = CreateVariantFromRagged<int, int>(
auto variant_component_4 = CreateVariantFromRagged<int, int>(
{component_split_4_1}, TensorShape({3}), component_values_4);
Tensor variant_component_5 = CreateVariantFromRagged<int, int>(
auto variant_component_5 = CreateVariantFromRagged<int, int>(
{component_split_5_1}, TensorShape({3}), component_values_5);
int input_ragged_rank = 1;
int output_ragged_rank = 3;
@ -451,13 +446,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) {
// Tests for invalid inputs.
TEST_F(RaggedTensorFromVariantKernelTest, InvalidInferredInputRaggedRank) {
Tensor component_variant_1 =
auto component_variant_1 =
CreateVariantFromRagged<int, int64>({}, TensorShape({3}), {1, 2, 3});
Tensor component_variant_2 =
auto component_variant_2 =
CreateVariantFromRagged<int, int64>({}, TensorShape({0}), {});
Tensor component_variant_3 =
auto component_variant_3 =
CreateVariantFromRagged<int, int64>({}, TensorShape({2}), {1, 2});
Tensor component_variant_4 =
auto component_variant_4 =
CreateVariantFromRagged<int, int64>({}, TensorShape({1}), {1});
int input_ragged_rank = -1;
@ -478,9 +473,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) {
const std::vector<int> component_values_1 = {0};
const std::vector<int> component_values_2 = {0, 1};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1}), component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{component_split_2_1}, TensorShape({2}), component_values_2);
int input_ragged_rank = 1;
@ -493,33 +488,21 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) {
"input_ragged_rank + encoded_ragged.dims()"));
}
TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldTensors) {
TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldRaggedTensorVariant) {
int input_ragged_rank = 1;
int output_ragged_rank = 2;
BuildDecodeRaggedTensorGraph<int, int64>(
input_ragged_rank, output_ragged_rank, TensorShape({2}), {1, 2});
EXPECT_TRUE(absl::StartsWith(
RunOpKernel().error_message(),
"Input Variant element at index 0 doesn't hold a Tensor"));
}
TEST_F(RaggedTensorFromVariantKernelTest, InputVariantTensorRankNotOne) {
Tensor variant_list(DT_VARIANT, TensorShape({2, 1}));
test::FillValues<Variant>(&variant_list, {1, 2});
int input_ragged_rank = 1;
int output_ragged_rank = 2;
BuildDecodeRaggedTensorGraph<int, int64>(
input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list});
EXPECT_TRUE(absl::StartsWith(
RunOpKernel().error_message(),
"Encoded input Variant must have rank 1, but found rank: 2"));
"Input Variant element at index 0 doesn't hold a RaggedTensorVariant"));
}
TEST_F(RaggedTensorFromVariantKernelTest,
InputScalarElementDoesNotMatchInputRaggedRank) {
const std::vector<int64> component_split_1_1 = {0, 1};
const std::vector<int> component_values_1 = {1, 2};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1, 2}), component_values_1);
int input_ragged_rank = 2;
@ -527,31 +510,17 @@ TEST_F(RaggedTensorFromVariantKernelTest,
BuildDecodeRaggedTensorGraph<int, int64>(input_ragged_rank,
output_ragged_rank, TensorShape({1}),
{variant_component_1});
EXPECT_TRUE(absl::StartsWith(
RunOpKernel().error_message(),
"Encoded input Variant must hold either input_ragged_rank + 1 "
"Tensors or an empty Tensor"));
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitNotATensor) {
Tensor variant_list(DT_VARIANT, TensorShape({2}));
test::FillValues<Variant>(&variant_list, {1, 2});
int input_ragged_rank = 1;
int output_ragged_rank = 2;
BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank,
TensorShape({1}), {variant_list});
EXPECT_TRUE(
absl::StartsWith(RunOpKernel().error_message(),
"Encoded scalar element at index 0 doesn't have a "
"splits Tensor at split_index 0"));
"Encoded input RaggedTensorVariant has ragged_rank=1. "
"Expected ragged_rank=2."));
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) {
const std::vector<int64> component_split_1_1 = {0, 1};
const std::vector<int> component_values_1 = {0};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1}), component_values_1);
int input_ragged_rank = 1;
@ -559,46 +528,29 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) {
BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank,
TensorShape({1}),
{variant_component_1});
EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
"Expected splits Tensor dtype: 3, found: 9"));
EXPECT_TRUE(absl::StartsWith(
RunOpKernel().error_message(),
"Expected row_splits Tensor dtype: int32, found: int64"));
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitRankNotOne) {
Tensor splits(DT_INT64, TensorShape({2, 1}));
test::FillValues<int64>(&splits, {1, 2});
Tensor values(DT_INT32, {2});
test::FillValues<int>(&values, {1, 2});
Tensor encoded_list(DT_VARIANT, TensorShape({2}));
test::FillValues<Variant>(&encoded_list, {splits, values});
RaggedTensorVariant encoded(Tensor(DT_INT32, {2}),
{Tensor(DT_INT64, {2, 1})});
test::FillValues<int64>(encoded.mutable_splits(0), {1, 2});
test::FillValues<int>(encoded.mutable_values(), {1, 2});
int input_ragged_rank = 1;
int output_ragged_rank = 2;
BuildDecodeRaggedTensorGraph<int, int64>(
input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded_list});
input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded});
EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
"Ragged splits must have rank 1"));
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesNotATensor) {
Tensor splits(DT_INT64, TensorShape({3}));
test::FillValues<int64>(&splits, {0, 2, 3});
Tensor variant_list(DT_VARIANT, TensorShape({2}));
test::FillValues<Variant>(&variant_list, {splits, 2});
int input_ragged_rank = 1;
int output_ragged_rank = 2;
BuildDecodeRaggedTensorGraph<int, int64>(
input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list});
EXPECT_TRUE(
absl::StartsWith(RunOpKernel().error_message(),
"Encoded scalar element at index 0 doesn't have a "
"values Tensor"));
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) {
const std::vector<int64> component_split_1_1 = {0, 1};
const std::vector<int> component_values_1 = {0};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1}), component_values_1);
int input_ragged_rank = 1;
int output_ragged_rank = 2;
@ -611,7 +563,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) {
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) {
Tensor variant_component_1 =
auto variant_component_1 =
CreateVariantFromRagged<int, int64>({{0, 1}}, TensorShape({}), {1});
int input_ragged_rank = 1;
int output_ragged_rank = 2;
@ -628,9 +580,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankMismatch) {
const std::vector<int> component_values_1 = {0};
const std::vector<int> component_values_2 = {0, 1, 2, 3};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{component_split_1_1}, TensorShape({1}), component_values_1);
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{component_split_2_1}, TensorShape({2, 2}), component_values_2);
int input_ragged_rank = 1;
int output_ragged_rank = 2;
@ -711,13 +663,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) {
const std::vector<int> batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3,
3, 3, 4, 4, 4, 4, 5, 5, 5, 5};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
{}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2});
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
{}, TensorShape({1, 2, 2}), {3, 3, 3, 3});
Tensor variant_component_3 =
auto variant_component_3 =
CreateVariantFromRagged<int, int64>({}, TensorShape({0, 2, 2}), {});
Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
auto variant_component_4 = CreateVariantFromRagged<int, int64>(
{}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5});
Tensor expected_splits_1(DT_INT64, TensorShape({5}));

View File

@ -18,50 +18,38 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
namespace {
struct RaggedTensor {
Tensor values;
std::vector<Tensor> nested_splits;
};
Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
// Encode as a rank-1 Variant Tensor.
int ragged_rank = ragged.nested_splits.size();
*encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
auto encoded_vec = encoded_list->vec<Variant>();
for (int i = 0; i < ragged_rank; i++) {
encoded_vec(i) = ragged.nested_splits[i];
}
encoded_vec(ragged_rank) = ragged.values;
return Status::OK();
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
std::vector<RaggedTensor>* ragged_components) {
Status UnbatchRaggedZerothDim(
const RaggedTensorVariant& batched_ragged,
std::vector<RaggedTensorVariant>* ragged_components) {
// Set up the component Ragged Tensors.
int ragged_rank = batched_ragged.nested_splits.size();
auto batched_splits_top_vec =
batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
int ragged_rank = batched_ragged.ragged_rank();
auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
int num_components = batched_splits_top_vec.size() - 1;
int num_splits = ragged_rank - 1;
ragged_components->resize(num_components);
for (RaggedTensor ragged_component : *ragged_components) {
ragged_component.nested_splits.reserve(num_splits);
for (RaggedTensorVariant& ragged_component : *ragged_components) {
ragged_component.mutable_nested_splits()->reserve(num_splits);
}
const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
int num_inner_elems = batched_ragged.values.NumElements();
if (batched_ragged.values.dim_size(0) > 1) {
num_inner_elems /= batched_ragged.values.dim_size(0);
const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
int num_inner_elems = batched_ragged.values().NumElements();
if (batched_ragged.values().dim_size(0) > 1) {
num_inner_elems /= batched_ragged.values().dim_size(0);
}
TensorShape values_shape = batched_ragged.values.shape();
TensorShape values_shape = batched_ragged.values().shape();
// Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
if (num_splits == 0) {
@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
int limit = batched_splits_top_vec(i + 1);
int num_values = limit - start;
values_shape.set_dim(0, num_values);
(*ragged_components)[i].values =
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
(*ragged_components)[i].set_values(
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
auto ragged_component_values_flat =
(*ragged_components)[i].values.flat<VALUE_TYPE>();
(*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
for (int j = 0; j < num_values * num_inner_elems; j++) {
ragged_component_values_flat(j) =
batched_flat(j + start * num_inner_elems);
@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
batched_splits_vec.reserve(ragged_rank);
for (int i = 0; i < ragged_rank; i++) {
batched_splits_vec.push_back(
batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
}
std::vector<int> index(num_splits, 1);
std::vector<int> ragged_component_values_size(num_components, 0);
@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
int last_index = ragged_component_splits_vec[j - 1].size() - 1;
split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
}
(*ragged_components)[i].nested_splits.push_back(
(*ragged_components)[i].append_splits(
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
ragged_component_splits_vec.push_back(
(*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
(*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
ragged_component_splits_vec[j](0) = 0;
for (int k = 1; k < split_size; k++, index[j]++) {
@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
for (int i = 0; i < num_components; i++) {
int num_values = ragged_component_values_size[i];
values_shape.set_dim(0, num_values);
(*ragged_components)[i].values =
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
(*ragged_components)[i].set_values(
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
auto ragged_component_values_flat =
(*ragged_components)[i].values.flat<VALUE_TYPE>();
(*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
ragged_component_values_flat(j) = batched_flat(value_index);
}
@ -152,46 +139,38 @@ class RaggedTensorToVariantOp : public OpKernel {
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
&ragged_nested_splits_in));
const int ragged_nested_splits_len = ragged_nested_splits_in.size();
RaggedTensor batched_ragged_input;
RaggedTensorVariant batched_ragged_input;
// Read ragged_values input.
batched_ragged_input.values = context->input(ragged_nested_splits_len);
batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
batched_ragged_input.mutable_nested_splits()->reserve(
ragged_nested_splits_len);
for (int i = 0; i < ragged_nested_splits_len; i++) {
batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
}
if (!batched_input_) {
// Encode the input as is.
Tensor encoded_list;
OP_REQUIRES_OK(context,
RaggedToVariant(batched_ragged_input, &encoded_list));
// Encode as a Scalar Variant Tensor.
Tensor* encoded_scalar;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
&encoded_scalar));
encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
return;
}
// Unbatch the Ragged Tensor and encode the components.
std::vector<RaggedTensor> ragged_components;
std::vector<RaggedTensorVariant> unbatched_ragged_input;
OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
batched_ragged_input, &ragged_components));
std::vector<Tensor> encoded_components(ragged_components.size());
for (int i = 0; i < ragged_components.size(); i++) {
OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
&encoded_components[i]));
}
batched_ragged_input, &unbatched_ragged_input));
// Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
Tensor* encoded_ragged;
int output_size = ragged_components.size();
Tensor* encoded_vector;
int output_size = unbatched_ragged_input.size();
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({output_size}),
&encoded_ragged));
auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
&encoded_vector));
auto encoded_vector_t = encoded_vector->vec<Variant>();
for (int i = 0; i < output_size; i++) {
encoded_ragged_vec(i) = encoded_components[i];
encoded_vector_t(i) = unbatched_ragged_input[i];
}
}
@ -199,12 +178,81 @@ class RaggedTensorToVariantOp : public OpKernel {
bool batched_input_;
};
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
.Device(DEVICE_CPU) \
.TypeConstraint<value_type>("Tvalues") \
.TypeConstraint<split_type>("Tsplits"), \
RaggedTensorToVariantOp<value_type, split_type>);
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
// Read inputs.
Tensor encoded_variant = context->input(0);
Tensor row_splits = context->input(1);
auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
TensorShape dense_values_shape;
OP_REQUIRES_OK(context,
TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
&dense_values_shape));
const auto& flat_variants = encoded_variant.flat<Variant>();
// Get a Tensor containing the flat_values for each variant.
std::vector<Tensor> values;
for (int i = 0; i < flat_variants.size(); ++i) {
if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
values.push_back(encoded->values());
} else {
// Missing value: this happens if only some of the variant values
// generated by ragged_tensor_to_variant impacted the value that we're
// calculating the gradient for. In this case, we will see a
// default-constructed variant; so treat it as a zero tensor with the
// appropriate shape.
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
TensorShape zeros_shape = dense_values_shape;
zeros_shape.set_dim(0, piece_size);
Tensor zero(value_dtype, zeros_shape);
zero.flat<VALUE_TYPE>() =
zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
values.push_back(zero);
}
}
if (values.size() == 1) {
// Just one flat_value tensor: return as-is.
context->set_output(0, values[0]);
} else {
// Multiple flat_values tensors: concatenate them together.
using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
std::vector<std::unique_ptr<ConstPiece>> pieces;
pieces.reserve(values.size());
for (const Tensor& t : values) {
pieces.emplace_back(
new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
}
Tensor* out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, dense_values_shape, &out));
Piece out_flat =
out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
}
}
};
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
.Device(DEVICE_CPU) \
.TypeConstraint<value_type>("Tvalues") \
.TypeConstraint<split_type>("Tsplits"), \
RaggedTensorToVariantOp<value_type, split_type>); \
REGISTER_KERNEL_BUILDER( \
Name("RaggedTensorToVariantGradient") \
.Device(DEVICE_CPU) \
.TypeConstraint<value_type>("Tvalues") \
.TypeConstraint<split_type>("Tsplits"), \
RaggedTensorToVariantGradientOp<value_type, split_type>);
#define REGISTER_KERNELS(value_type) \
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -60,6 +61,43 @@ class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase {
}
AddInputFromArray<VALUE_TYPE>(ragged_values_shape, ragged_values);
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
RaggedTensorVariant CreateVariantFromRagged(
const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
const TensorShape& ragged_values_shape,
const std::vector<VALUE_TYPE>& ragged_values) {
RaggedTensorVariant encoded;
for (auto ragged_split : ragged_splits) {
int splits_size = ragged_split.size();
Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(),
TensorShape({splits_size}));
test::FillValues<SPLIT_TYPE>(&splits, ragged_split);
encoded.append_splits(splits);
}
Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape);
test::FillValues<VALUE_TYPE>(&values, ragged_values);
encoded.set_values(values);
return encoded;
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
RaggedTensorVariant CreateVariantFromRagged(
const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
const std::vector<VALUE_TYPE>& ragged_values) {
int num_values = ragged_values.size();
return CreateVariantFromRagged(ragged_splits, {num_values}, ragged_values);
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
void ExpectRaggedTensorVariantEqual(const RaggedTensorVariant& expected,
const RaggedTensorVariant& actual) {
test::ExpectTensorEqual<VALUE_TYPE>(actual.values(), expected.values());
EXPECT_EQ(actual.ragged_rank(), expected.ragged_rank());
for (int i = 0; i < actual.ragged_rank(); ++i) {
test::ExpectTensorEqual<SPLIT_TYPE>(actual.splits(i), expected.splits(i));
}
}
};
TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
@ -67,18 +105,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
const std::vector<int64> batched_splits_1 = {0, 2, 3, 3};
const std::vector<int64> batched_splits_2 = {0, 0, 0, 0};
const std::vector<int64> component_splits_1_1 = {0, 0, 0};
const std::vector<int64> component_splits_2_1 = {0, 0};
const std::vector<int64> component_splits_3_1 = {0};
Tensor expected_splits_1_1(DT_INT64, TensorShape({3}));
Tensor expected_splits_2_1(DT_INT64, TensorShape({2}));
Tensor expected_splits_3_1(DT_INT64, TensorShape({1}));
test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1);
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
TensorShape({0}), {}, true);
TF_ASSERT_OK(RunOpKernel());
@ -86,55 +112,26 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 3);
const Variant& encoded_splits_1_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_splits_2_1 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_splits_3_1 =
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_3 =
encoded_list(2).get<Tensor>()->vec<Variant>()(1);
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
expected_splits_1_1);
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
expected_splits_2_1);
test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(),
expected_splits_3_1);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
Tensor(DT_INT32, TensorShape({0})));
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
Tensor(DT_INT32, TensorShape({0})));
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
Tensor(DT_INT32, TensorShape({0})));
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 0, 0}}, {}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 0}}, {}),
*encoded_list(1).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0}}, {}),
*encoded_list(2).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) {
// ragged_tensor=
// [ [x, x, x],
// [ [1, 2, 3],
// [ ],
// [x, x ],
// [x ]]
// [4, 5 ],
// [6 ]]
const std::vector<int64> batched_splits = {0, 3, 3, 5, 6};
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6};
const std::vector<int> component_values_1 = {1, 2, 3};
const std::vector<int> component_values_3 = {4, 5};
const std::vector<int> component_values_4 = {6};
Tensor expected_values_1(DT_INT32, TensorShape({3}));
Tensor expected_values_2(DT_INT32, TensorShape({0}));
Tensor expected_values_3(DT_INT32, TensorShape({2}));
Tensor expected_values_4(DT_INT32, TensorShape({1}));
test::FillValues<int>(&expected_values_1, component_values_1);
test::FillValues<int>(&expected_values_3, component_values_3);
test::FillValues<int>(&expected_values_4, component_values_4);
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits}, TensorShape({6}),
batched_values, true);
TF_ASSERT_OK(RunOpKernel());
@ -142,45 +139,28 @@ TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 4);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_3 =
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_4 =
encoded_list(3).get<Tensor>()->vec<Variant>()(0);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
expected_values_1);
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
expected_values_2);
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
expected_values_3);
test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(),
expected_values_4);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {1, 2, 3}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {}),
*encoded_list(1).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {4, 5}),
*encoded_list(2).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {6}),
*encoded_list(3).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) {
// ragged_tensor=
// [[x, x],
// [x, x],
// [x, x]]
// [[1, 2],
// [4, 5],
// [6, 7]]
const std::vector<int64> batched_splits = {0, 1, 2, 3};
const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7};
const std::vector<int> component_values_1 = {1, 2};
const std::vector<int> component_values_2 = {4, 5};
const std::vector<int> component_values_3 = {6, 7};
Tensor expected_values_1(DT_INT32, TensorShape({1, 2}));
Tensor expected_values_2(DT_INT32, TensorShape({1, 2}));
Tensor expected_values_3(DT_INT32, TensorShape({1, 2}));
test::FillValues<int>(&expected_values_1, component_values_1);
test::FillValues<int>(&expected_values_2, component_values_2);
test::FillValues<int>(&expected_values_3, component_values_3);
BuildEncodeRaggedTensorGraph<int, int64>(
{batched_splits}, TensorShape({3, 2}), batched_values, true);
TF_ASSERT_OK(RunOpKernel());
@ -188,44 +168,25 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 3);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_3 =
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
expected_values_1);
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
expected_values_2);
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
expected_values_3);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {1, 2}, {1, 2}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {1, 2}, {4, 5}),
*encoded_list(1).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, {1, 2}, {6, 7}),
*encoded_list(2).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) {
// ragged_tensor=[
// [ [[x, x], [x, x]],
// [[x, x] ] ]
// ragged_tensor=
// [ [[[1, 2], [4, 5]]],
// [[[6 7]]] ]
const std::vector<int64> batched_splits_1 = {0, 1, 2};
const std::vector<int64> batched_splits_2 = {0, 2, 3};
const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7};
const std::vector<int64> component_splits_1_1 = {0, 2};
const std::vector<int64> component_splits_2_1 = {0, 1};
const std::vector<int> component_values_1 = {1, 2, 4, 5};
const std::vector<int> component_values_2 = {6, 7};
Tensor expected_splits_1_1(DT_INT64, TensorShape({2}));
Tensor expected_splits_2_1(DT_INT64, TensorShape({2}));
Tensor expected_values_1(DT_INT32, TensorShape({2, 2}));
Tensor expected_values_2(DT_INT32, TensorShape({1, 2}));
test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
test::FillValues<int>(&expected_values_1, component_values_1);
test::FillValues<int>(&expected_values_2, component_values_2);
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
TensorShape({3, 2}), batched_values,
true);
@ -234,23 +195,12 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 2);
const Variant& encoded_splits_1_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_splits_2_1 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
expected_splits_1_1);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
expected_values_1);
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
expected_splits_2_1);
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
expected_values_2);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 2}}, {2, 2}, {1, 2, 4, 5}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 1}}, {1, 2}, {6, 7}),
*encoded_list(1).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
@ -263,30 +213,6 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
const std::vector<int64> batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15};
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15};
const std::vector<int64> component_splits_1_1 = {0, 1, 3, 3};
const std::vector<int64> component_splits_2_1 = {0};
const std::vector<int64> component_splits_3_1 = {0, 5, 8};
const std::vector<int64> component_splits_4_1 = {0, 0, 4};
const std::vector<int> component_values_1 = {1, 2, 3};
const std::vector<int> component_values_3 = {4, 5, 6, 7, 8, 9, 10, 11};
const std::vector<int> component_values_4 = {12, 13, 14, 15};
Tensor expected_splits_1_1(DT_INT64, TensorShape({4}));
Tensor expected_splits_2_1(DT_INT64, TensorShape({1}));
Tensor expected_splits_3_1(DT_INT64, TensorShape({3}));
Tensor expected_splits_4_1(DT_INT64, TensorShape({3}));
Tensor expected_values_1(DT_INT32, TensorShape({3}));
Tensor expected_values_2(DT_INT32, TensorShape({0}));
Tensor expected_values_3(DT_INT32, TensorShape({8}));
Tensor expected_values_4(DT_INT32, TensorShape({4}));
test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1);
test::FillValues<int64>(&expected_splits_4_1, component_splits_4_1);
test::FillValues<int>(&expected_values_1, component_values_1);
test::FillValues<int>(&expected_values_3, component_values_3);
test::FillValues<int>(&expected_values_4, component_values_4);
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
TensorShape({15}), batched_values,
@ -296,39 +222,19 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 4);
const Variant& encoded_splits_1_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_splits_2_1 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_splits_3_1 =
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_3 =
encoded_list(2).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_splits_4_1 =
encoded_list(3).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_values_4 =
encoded_list(3).get<Tensor>()->vec<Variant>()(1);
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
expected_splits_1_1);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
expected_values_1);
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
expected_splits_2_1);
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
expected_values_2);
test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(),
expected_splits_3_1);
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
expected_values_3);
test::ExpectTensorEqual<int64>(*encoded_splits_4_1.get<Tensor>(),
expected_splits_4_1);
test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(),
expected_values_4);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 1, 3, 3}}, {1, 2, 3}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0}}, {}),
*encoded_list(1).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 5, 8}},
{4, 5, 6, 7, 8, 9, 10, 11}),
*encoded_list(2).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({{0, 0, 4}}, {12, 13, 14, 15}),
*encoded_list(3).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
@ -350,26 +256,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
7, 8, 9, 12, 13, 14};
const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4,
5, 6, 7, 8, 9, 8, 9};
const std::vector<int64> component_split_1_1 = {0, 1, 3, 4, 5, 6};
const std::vector<int64> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7};
const std::vector<int64> component_split_2_1 = {0, 1, 2, 3, 4, 5};
const std::vector<int64> component_split_2_2 = {0, 1, 2, 5, 6, 7};
const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4};
const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9};
Tensor expected_splits_1_1(DT_INT64, TensorShape({6}));
Tensor expected_splits_1_2(DT_INT64, TensorShape({7}));
Tensor expected_splits_2_1(DT_INT64, TensorShape({6}));
Tensor expected_splits_2_2(DT_INT64, TensorShape({6}));
Tensor expected_values_1(DT_INT32, TensorShape({7}));
Tensor expected_values_2(DT_INT32, TensorShape({7}));
test::FillValues<int64>(&expected_splits_1_1, component_split_1_1);
test::FillValues<int64>(&expected_splits_1_2, component_split_1_2);
test::FillValues<int64>(&expected_splits_2_1, component_split_2_1);
test::FillValues<int64>(&expected_splits_2_2, component_split_2_2);
test::FillValues<int>(&expected_values_1, component_values_1);
test::FillValues<int>(&expected_values_2, component_values_2);
BuildEncodeRaggedTensorGraph<int, int64>(
{batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}),
@ -379,31 +265,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 2);
const Variant& encoded_splits_1_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_splits_1_2 =
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(2);
const Variant& encoded_splits_2_1 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_splits_2_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(2);
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
expected_splits_1_1);
test::ExpectTensorEqual<int64>(*encoded_splits_1_2.get<Tensor>(),
expected_splits_1_2);
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
expected_splits_2_1);
test::ExpectTensorEqual<int64>(*encoded_splits_2_2.get<Tensor>(),
expected_splits_2_2);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
expected_values_1);
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
expected_values_2);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>(
{{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>(
{{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}),
*encoded_list(1).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
@ -424,28 +293,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
7, 8, 9, 12, 13, 14};
const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4,
5, 6, 7, 8, 9, 8, 9};
const std::vector<int> component_split_1_1 = {0, 1, 3, 4, 5, 6};
const std::vector<int> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7};
const std::vector<int> component_split_2_1 = {0, 1, 2, 3, 4, 5};
const std::vector<int> component_split_2_2 = {0, 1, 2, 5, 6, 7};
const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4};
const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9};
Tensor expected_splits_1_1(DT_INT32, TensorShape({6}));
Tensor expected_splits_1_2(DT_INT32, TensorShape({7}));
Tensor expected_splits_2_1(DT_INT32, TensorShape({6}));
Tensor expected_splits_2_2(DT_INT32, TensorShape({6}));
Tensor expected_values_1(DT_INT32, TensorShape({7}));
Tensor expected_values_2(DT_INT32, TensorShape({7}));
test::FillValues<int>(&expected_splits_1_1, component_split_1_1);
test::FillValues<int>(&expected_splits_1_2, component_split_1_2);
test::FillValues<int>(&expected_splits_2_1, component_split_2_1);
test::FillValues<int>(&expected_splits_2_2, component_split_2_2);
test::FillValues<int>(&expected_values_1, component_values_1);
test::FillValues<int>(&expected_values_2, component_values_2);
BuildEncodeRaggedTensorGraph<int, int>(
BuildEncodeRaggedTensorGraph<int, int32>(
{batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}),
batched_values, true);
TF_ASSERT_OK(RunOpKernel());
@ -453,31 +302,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
const auto& encoded_list = GetOutput(0)->vec<Variant>();
EXPECT_EQ(encoded_list.size(), 2);
const Variant& encoded_splits_1_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_splits_1_2 =
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_values_1 =
encoded_list(0).get<Tensor>()->vec<Variant>()(2);
const Variant& encoded_splits_2_1 =
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_splits_2_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_values_2 =
encoded_list(1).get<Tensor>()->vec<Variant>()(2);
test::ExpectTensorEqual<int>(*encoded_splits_1_1.get<Tensor>(),
expected_splits_1_1);
test::ExpectTensorEqual<int>(*encoded_splits_1_2.get<Tensor>(),
expected_splits_1_2);
test::ExpectTensorEqual<int>(*encoded_splits_2_1.get<Tensor>(),
expected_splits_2_1);
test::ExpectTensorEqual<int>(*encoded_splits_2_2.get<Tensor>(),
expected_splits_2_2);
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
expected_values_1);
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
expected_values_2);
ExpectRaggedTensorVariantEqual<int, int32>(
CreateVariantFromRagged<int, int32>(
{{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}),
*encoded_list(0).get<RaggedTensorVariant>());
ExpectRaggedTensorVariantEqual<int, int32>(
CreateVariantFromRagged<int, int32>(
{{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}),
*encoded_list(1).get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) {
@ -491,33 +323,17 @@ TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) {
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15};
Tensor batched_ragged_splits_1(DT_INT64, TensorShape({5}));
Tensor batched_ragged_splits_2(DT_INT64, TensorShape({8}));
Tensor batched_ragged_values(DT_INT32, TensorShape({15}));
test::FillValues<int64>(&batched_ragged_splits_1, batched_splits_1);
test::FillValues<int64>(&batched_ragged_splits_2, batched_splits_2);
test::FillValues<int>(&batched_ragged_values, batched_values);
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
TensorShape({15}), batched_values,
false);
TF_ASSERT_OK(RunOpKernel());
const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
const Variant& encoded_splits_1 =
encoded_scalar.get<Tensor>()->vec<Variant>()(0);
const Variant& encoded_splits_2 =
encoded_scalar.get<Tensor>()->vec<Variant>()(1);
const Variant& encoded_values =
encoded_scalar.get<Tensor>()->vec<Variant>()(2);
test::ExpectTensorEqual<int64>(*encoded_splits_1.get<Tensor>(),
batched_ragged_splits_1);
test::ExpectTensorEqual<int64>(*encoded_splits_2.get<Tensor>(),
batched_ragged_splits_2);
test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(),
batched_ragged_values);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({batched_splits_1, batched_splits_2},
batched_values),
*encoded_scalar.get<RaggedTensorVariant>());
}
TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) {
@ -598,17 +414,14 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) {
const std::vector<int> values = {1, 2, 3, 4, 5, 6};
Tensor expected_values(DT_INT32, TensorShape({6}));
test::FillValues<int>(&expected_values, values);
BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false);
TF_ASSERT_OK(RunOpKernel());
const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
const Variant& encoded_values =
encoded_scalar.get<Tensor>()->vec<Variant>()(0);
test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values);
ExpectRaggedTensorVariantEqual<int, int64>(
CreateVariantFromRagged<int, int64>({}, values),
*encoded_scalar.get<RaggedTensorVariant>());
}
} // namespace

View File

@ -0,0 +1,86 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
namespace tensorflow {
string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; }
string RaggedTensorVariant::DebugString() const {
return absl::StrCat(
"RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()),
", ragged_rank=", nested_splits_.size(), ", splits_dtype=",
DataTypeString(nested_splits_.empty() ? DT_INVALID
: nested_splits_.back().dtype()));
}
void RaggedTensorVariant::Encode(VariantTensorData* data) const {
data->set_type_name(TypeName());
for (const auto& splits : nested_splits_) {
*data->add_tensors() = splits;
}
*data->add_tensors() = values_;
}
bool RaggedTensorVariant::Decode(const VariantTensorData& data) {
if (data.tensors_size() < 1) {
return false;
}
nested_splits_.assign(data.tensors().begin(),
std::prev(data.tensors().end()));
values_ = data.tensors().back();
return true;
}
namespace {
Status RaggedTensorVariantDeviceCopy(
const RaggedTensorVariant& from, RaggedTensorVariant* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values()));
// TODO(b/170415165) Should we use `copy` to move splits from device<->host?
*to->mutable_nested_splits() = from.nested_splits();
return Status::OK();
}
} // namespace
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
RaggedTensorVariantZerosLike<CPUDevice>);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant,
RaggedTensorVariantBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant,
"RaggedTensorVariant");
#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy)
REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_RAGGED_TENSOR_VARIANT_COPY(
VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
} // namespace tensorflow

View File

@ -0,0 +1,110 @@
#include "tensorflow/core/framework/tensor_key.h"
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
// Class used to store a RaggedTensor as a Variant scalar.
class RaggedTensorVariant {
public:
RaggedTensorVariant() {}
RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
: values_(std::move(values)), nested_splits_(nested_splits) {}
// Variant support methods.
string TypeName() const;
string DebugString() const;
void Encode(VariantTensorData* data) const;
bool Decode(const VariantTensorData& data);
// The flat_values of the RaggedTensor.
const Tensor& values() const { return values_; }
Tensor* mutable_values() { return &values_; }
void set_values(const Tensor& new_values) { values_ = new_values; }
// The nested row_splits of the RaggedTensor.
int ragged_rank() const { return nested_splits_.size(); }
const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
const Tensor& splits(int i) const { return nested_splits_[i]; }
Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
void set_nested_splits(const std::vector<Tensor>& nested_splits) {
nested_splits_ = nested_splits;
}
void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
private:
Tensor values_;
std::vector<Tensor> nested_splits_;
};
template <typename Device>
Status RaggedTensorVariantZerosLike(OpKernelContext* c,
const RaggedTensorVariant& x,
RaggedTensorVariant* y) {
y->set_nested_splits(x.nested_splits());
TF_RETURN_IF_ERROR(
ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
return Status::OK();
}
template <typename Device>
Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
const RaggedTensorVariant& x,
const RaggedTensorVariant& y,
RaggedTensorVariant* out) {
if (x.values().dtype() != y.values().dtype()) {
return errors::InvalidArgument(
"Can't add RaggedTensorVariants of different dtypes. One is ",
DataTypeString(x.values().dtype()), " and the other is ",
DataTypeString(y.values().dtype()));
}
if (x.ragged_rank() != y.ragged_rank()) {
return errors::InvalidArgument(
"Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
x.ragged_rank(), " and the other is ", y.ragged_rank());
}
for (int i = 0; i < x.ragged_rank(); ++i) {
if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
return errors::InvalidArgument(
"Can't add RaggedTensorVariants with different row_splits.");
}
}
out->set_nested_splits(x.nested_splits());
TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
out->mutable_values()));
return Status::OK();
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_

View File

@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes(
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
Status RaggedTensorToTensorShapeFn(InferenceContext* c);
//==============================================================================
// Registered Ops
@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant")
.Attr("Tsplits: {int32, int64} = DT_INT64")
.SetShapeFn(RaggedTensorFromVariantShapeFn);
REGISTER_OP("RaggedTensorToVariantGradient")
.Input("encoded_ragged_grad: variant")
.Input("row_splits: Tsplits")
.Input("dense_values_shape: int32")
.Output("dense_values_grad: Tvalues")
.Attr("Tvalues: type")
.Attr("Tsplits: {int32, int64} = DT_INT64")
.SetShapeFn(RaggedTensorToVariantGradientShapeFn);
REGISTER_OP("RaggedTensorToTensor")
.Attr("T: type")
.Attr("Tindex: {int64, int32}")
@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
return Status::OK();
}
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
ShapeHandle shape;
TF_RETURN_IF_ERROR(
c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
c->set_output(0, shape);
return Status::OK();
}
Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
int64 input_ragged_rank;
TF_RETURN_IF_ERROR(

View File

@ -510,6 +510,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:tensor_array_grad",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/data/ops:dataset_ops",

View File

@ -143,3 +143,42 @@ def to_sparse(rt_input, name=None):
def from_sparse(st_input, name=None):
return ragged_tensor.RaggedTensor.from_sparse(st_input, name)
@ops.RegisterGradient("RaggedTensorFromVariant")
def _ragged_tensor_from_variant_grad(op, *grads):
"""Gradient for RaggedTensorFromVariant op."""
variant_rank = op.inputs[0].shape.rank
if variant_rank == 0:
batched_input = False
elif variant_rank == 1:
batched_input = True
elif variant_rank is None:
batched_input = (op.get_attr("output_ragged_rank") > 0)
else:
# TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so
# we can support this.
raise ValueError("Unable to compute gradient: RaggedTensorToVariant "
"can currently only generate 0D or 1D output.")
return [
gen_ragged_conversion_ops.ragged_tensor_to_variant(
rt_nested_splits=op.outputs[:-1],
rt_dense_values=grads[-1],
batched_input=batched_input)
]
@ops.RegisterGradient("RaggedTensorToVariant")
def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad):
"""Gradient for RaggedTensorToVariant op."""
dense_values = op.inputs[-1]
ragged_rank = len(op.inputs) - 1
row_splits = 0 if ragged_rank == 0 else op.inputs[0]
values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient(
encoded_ragged_grad=encoded_ragged_grad,
row_splits=row_splits,
dense_values_shape=array_ops.shape(dense_values),
Tvalues=op.inputs[-1].dtype)
result = [None] * ragged_rank + [values_grad]
return result

View File

@ -2863,9 +2863,6 @@ def _get_optional_partition_dtype(values):
return None
ops.no_gradient("RaggedTensorToVariant")
_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)

View File

@ -18,10 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -30,8 +32,15 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_ragged_conversion_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
@ -1233,19 +1242,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(factory(**kwargs))
#=============================================================================
# RaggedTensor Variant conversion
#=============================================================================
#=============================================================================
# RaggedTensor Variant conversion
#=============================================================================
@parameterized.parameters(
@parameterized.named_parameters(
{
'testcase_name': 'Shape_5_none',
'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
'ragged_rank': 1
}, {
'testcase_name': 'Shape_4_none_2',
'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
'ragged_rank': 1
}, {
'testcase_name': 'Shape_1_none_none',
'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
'ragged_rank': 2
})
@ -1432,6 +1443,131 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
output_ragged_rank=1,
input_ragged_rank=1)
def _testRaggedVarientGradient(self, func, x, expected_grad):
x = constant_op.constant(x)
if context.executing_eagerly():
with backprop.GradientTape() as t:
t.watch(x)
y = func(x)
g = t.gradient(y, x)
else:
y = func(x)
g = gradients_impl.gradients(ys=y, xs=x)[0]
self.assertAllClose(g, expected_grad)
def testRaggedVariantGradients(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v = rt2._to_variant(batched_input=False)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testRaggedVarientGradient(
func,
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 10., 10., 100., 100., 100., 1000.])
def testRaggedVariantGradientsBatched(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v = rt2._to_variant(batched_input=True)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testRaggedVarientGradient(
func,
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 10., 10., 100., 100., 100., 1000.])
def testRaggedVariantGradientsBatchedAndSliced(self):
def func(x, i):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v_slice = rt2._to_variant(batched_input=True)[i]
return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype,
output_ragged_rank=0)
self._testRaggedVarientGradient(
functools.partial(func, i=0),
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 10., 10., 0., 0., 0., 0.])
self._testRaggedVarientGradient(
functools.partial(func, i=1),
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 100., 100., 100., 0.])
self._testRaggedVarientGradient(
functools.partial(func, i=2),
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 0., 0., 0., 1000.])
def testRaggedVariantGradientsRaggedRank0(self):
def func(x):
x2 = x * 2
v = gen_ragged_conversion_ops.ragged_tensor_to_variant(
[], x2, batched_input=False)
return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0)
self._testRaggedVarientGradient(
func,
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
def testRaggedVariantGradientsRaggedRank3(self):
def func(x):
x2 = x * 2
rt1 = RaggedTensor.from_nested_row_splits(
x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8]))
v = rt1._to_variant(batched_input=False)
rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3)
return rt3.flat_values
self._testRaggedVarientGradient(
func,
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
def testRaggedVariantGradientsViaMapFn(self):
rt = RaggedTensor.from_row_splits(
values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8])
def func(x):
def transform_row(row):
return math_ops.sqrt(
math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
self._testRaggedVarientGradient(func, 3.0, 14.653377)
def testRaggedVariantGradientsViaMapFnReduce(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
return map_fn.map_fn(
math_ops.reduce_max, rt1,
fn_output_signature=tensor_spec.TensorSpec((), x.dtype))
self._testRaggedVarientGradient(
func,
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])
def testRaggedVariantGradientsErrors(self):
if context.executing_eagerly():
return
rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2])
v1 = rt._to_variant()
v2 = array_ops.stack([array_ops.stack([v1])])
y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3)
with self.assertRaisesRegex(
ValueError, 'Unable to compute gradient: RaggedTensorToVariant '
'can currently only generate 0D or 1D output.'):
gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values)
def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg):
"""Check that two numpy arrays are equal.

View File

@ -3212,6 +3212,10 @@ tf_module {
name: "RaggedTensorToVariant"
argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedTensorToVariantGradient"
argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RandomCrop"
argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "

View File

@ -3212,6 +3212,10 @@ tf_module {
name: "RaggedTensorToVariant"
argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedTensorToVariantGradient"
argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RandomCrop"
argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "