Added gradients for RaggedTensorToVariant and RaggedTensorFromVariant. (This allows gradients to pass through map_fn when it is applied to ragged tensors.)
PiperOrigin-RevId: 337108621 Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
This commit is contained in:
parent
be99659bf6
commit
be6b1fdb06
@ -138,6 +138,7 @@
|
||||
stateful ops.
|
||||
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||
usage of the device.
|
||||
* Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
|
||||
* `tf.data`:
|
||||
* tf.data service:
|
||||
* Added new `tf.data.experimental.service.register_dataset` and
|
||||
|
@ -0,0 +1,38 @@
|
||||
op {
|
||||
graph_op_name: "RaggedTensorToVariantGradient"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "encoded_ragged_grad"
|
||||
description: <<END
|
||||
A `variant` Tensor containing encoded `RaggedTensor` gradients.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "row_splits"
|
||||
description: <<END
|
||||
Outermost row-splits that were used as input to the RaggedTensorToVariant op.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "dense_values_shape"
|
||||
description: <<END
|
||||
Shape of the dense_values that was used as an input to the
|
||||
RaggedTensorToVariant op.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "dense_values_grad"
|
||||
description: <<END
|
||||
Gradient for the dense_values of the RaggedTensorToVariant op.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Helper used to compute the gradient for `RaggedTensorToVariant`.
|
||||
END
|
||||
description: <<END
|
||||
Computes the gradient for the dense_values input to the RaggedTensorToVariant
|
||||
op, given the variant-encoded ragged gradients of the outputs, along with
|
||||
the outer row-splits and the shape of the dense-values that were provided as
|
||||
inputs to the RaggedTensorToVariant op.
|
||||
END
|
||||
}
|
@ -1419,10 +1419,22 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ragged_tensor_variant",
|
||||
srcs = ["ragged_tensor_variant.cc"],
|
||||
hdrs = ["ragged_tensor_variant.h"],
|
||||
deps = [
|
||||
":cwise_op",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "ragged_tensor_to_variant_op",
|
||||
srcs = ["ragged_tensor_to_variant_op.cc"],
|
||||
deps = [
|
||||
":concat_lib",
|
||||
":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -1432,6 +1444,7 @@ tf_kernel_library(
|
||||
name = "ragged_tensor_from_variant_op",
|
||||
srcs = ["ragged_tensor_from_variant_op.cc"],
|
||||
deps = [
|
||||
":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -1444,6 +1457,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ragged_tensor_to_variant_op",
|
||||
":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1460,6 +1474,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ragged_tensor_from_variant_op",
|
||||
":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -424,6 +424,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/kernels:ragged_tensor_variant",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
"//tensorflow/core/kernels/data:parallel_map_dataset_op",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/stats_utils.h"
|
||||
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/stringprintf.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
@ -678,12 +679,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
for (int d = 0; d < dataset()->ragged_keys_.size(); ++d) {
|
||||
int output_index =
|
||||
dataset()->key_to_output_index_.at(dataset()->ragged_keys_[d]);
|
||||
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
|
||||
Tensor serialized_ragged =
|
||||
Tensor(ctx->allocator({}), DT_VARIANT, {2});
|
||||
auto serialized_ragged_t = serialized_ragged.vec<Variant>();
|
||||
serialized_ragged_t(0) = example_result.ragged_splits[d];
|
||||
serialized_ragged_t(1) = example_result.ragged_values[d];
|
||||
RaggedTensorVariant serialized_ragged;
|
||||
serialized_ragged.append_splits(example_result.ragged_splits[d]);
|
||||
serialized_ragged.set_values(example_result.ragged_values[d]);
|
||||
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
|
||||
Tensor& ragged_wrapper = (*output)[output_index];
|
||||
ragged_wrapper.scalar<Variant>()() = serialized_ragged;
|
||||
|
@ -20,110 +20,76 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
struct RaggedTensor {
|
||||
Tensor values;
|
||||
std::vector<Tensor> nested_splits;
|
||||
};
|
||||
|
||||
Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
|
||||
int ragged_rank, DataType value_dtype,
|
||||
DataType split_dtype,
|
||||
std::vector<RaggedTensor>* decoded_ragged) {
|
||||
Status RaggedComponentsFromVariant(
|
||||
const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
|
||||
DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
|
||||
const auto& flat_variants = encoded_variant.flat<Variant>();
|
||||
decoded_ragged->resize(flat_variants.size());
|
||||
// Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
|
||||
// input.
|
||||
decoded_ragged->reserve(flat_variants.size());
|
||||
|
||||
for (int i = 0; i < flat_variants.size(); i++) {
|
||||
const auto& flat_variant = flat_variants(i);
|
||||
const Tensor* encoded_list = flat_variant.get<Tensor>();
|
||||
if (encoded_list == nullptr) {
|
||||
const RaggedTensorVariant* decoded =
|
||||
flat_variant.get<RaggedTensorVariant>();
|
||||
if (decoded == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Input Variant element at index ", i,
|
||||
" doesn't hold a Tensor: ", flat_variant.DebugString());
|
||||
" doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
|
||||
}
|
||||
if (encoded_list->dims() != 1) {
|
||||
decoded_ragged->push_back(*decoded);
|
||||
decoded = &decoded_ragged->back();
|
||||
// Check ragged rank & types
|
||||
if (decoded->ragged_rank() != ragged_rank) {
|
||||
return errors::InvalidArgument(
|
||||
"Encoded input Variant must have rank 1, but found rank: ",
|
||||
encoded_list->dims(),
|
||||
". encoded input Variant: ", encoded_list->DebugString());
|
||||
"Encoded input RaggedTensorVariant has ragged_rank=",
|
||||
decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, ".");
|
||||
}
|
||||
if (encoded_list->NumElements() != (ragged_rank + 1) &&
|
||||
encoded_list->NumElements() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Encoded input Variant must hold either input_ragged_rank + 1 "
|
||||
"Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
|
||||
"input_ragged_rank: ",
|
||||
ragged_rank,
|
||||
", encoded input Variant: ", encoded_list->DebugString());
|
||||
}
|
||||
const auto& input_vec = encoded_list->vec<Variant>();
|
||||
|
||||
// Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
|
||||
// to create the component RaggedTensors.
|
||||
(*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
|
||||
for (int j = 0; j < ragged_rank; j++) {
|
||||
const Tensor* split_tensor = input_vec(j).get<Tensor>();
|
||||
if (split_tensor == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Encoded scalar element at index ", i,
|
||||
" doesn't have a splits Tensor at split_index ", j, ": ",
|
||||
input_vec(j).DebugString());
|
||||
}
|
||||
Tensor splits_tensor = *split_tensor;
|
||||
if (splits_tensor.dtype() != split_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected splits Tensor dtype: ", split_dtype,
|
||||
", found: ", splits_tensor.dtype());
|
||||
}
|
||||
if (splits_tensor.dims() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Ragged splits must have rank 1; encoded scalar element at index ",
|
||||
i, " has splits Tensor at split_index ", j, ": ",
|
||||
splits_tensor.DebugString());
|
||||
}
|
||||
(*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
|
||||
}
|
||||
const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
|
||||
if (values_tensor == nullptr) {
|
||||
return errors::InvalidArgument("Encoded scalar element at index ", i,
|
||||
" doesn't have a values Tensor: ",
|
||||
input_vec(ragged_rank).DebugString());
|
||||
}
|
||||
if (values_tensor->dtype() != value_dtype) {
|
||||
if (decoded->values().dtype() != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected values Tensor dtype: ", DataTypeString(value_dtype),
|
||||
", found: ", DataTypeString(values_tensor->dtype()));
|
||||
", found: ", DataTypeString(decoded->values().dtype()));
|
||||
}
|
||||
if (values_tensor->dims() < 1) {
|
||||
if (decoded->values().dims() < 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Ragged values must have rank >= 1; encoded scalar element at index ",
|
||||
i, " has values Tensor: ", values_tensor->DebugString());
|
||||
i, " has values Tensor: ", decoded->values().DebugString());
|
||||
}
|
||||
for (const auto& splits : decoded->nested_splits()) {
|
||||
if (splits.dtype() != split_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
|
||||
", found: ", DataTypeString(splits.dtype()));
|
||||
}
|
||||
if (splits.dims() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Ragged splits must have rank 1; encoded scalar element at index ",
|
||||
i, " has splits Tensor ", splits.DebugString());
|
||||
}
|
||||
}
|
||||
(*decoded_ragged)[i].values = *values_tensor;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
Status NestedStackRaggedTensors(
|
||||
const std::vector<RaggedTensor>& ragged_components,
|
||||
const std::vector<RaggedTensorVariant>& ragged_components,
|
||||
const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
|
||||
const int output_ragged_rank, RaggedTensor* output_ragged) {
|
||||
output_ragged->nested_splits.reserve(output_ragged_rank);
|
||||
const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
|
||||
output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
|
||||
const int dims = nested_dim_sizes.size();
|
||||
|
||||
// Populate first `dims - 1` splits.
|
||||
for (int i = 0; i < dims - 1; i++) {
|
||||
int dims_splits_size = nested_dim_sizes[i] + 1;
|
||||
output_ragged->nested_splits.push_back(Tensor(
|
||||
DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
|
||||
auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
|
||||
output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
|
||||
TensorShape({dims_splits_size})));
|
||||
auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
|
||||
int split_diff = nested_dim_sizes[i + 1];
|
||||
for (int j = 0; j < dims_splits_size; j++) {
|
||||
splits_vec(j) = j * split_diff;
|
||||
@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
|
||||
|
||||
// Populate `dims`-th split.
|
||||
int splits_size = ragged_components.size() + 1;
|
||||
output_ragged->nested_splits.push_back(
|
||||
output_ragged->append_splits(
|
||||
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
|
||||
auto dims_splits_vec =
|
||||
output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
|
||||
output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
|
||||
dims_splits_vec(0) = 0;
|
||||
for (int i = 0; i < ragged_components.size(); i++) {
|
||||
int split_val = ragged_components[i].values.shape().dim_size(0);
|
||||
if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
|
||||
split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
|
||||
int split_val = ragged_components[i].values().shape().dim_size(0);
|
||||
if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
|
||||
split_val = ragged_components[i].splits(0).NumElements() - 1;
|
||||
}
|
||||
dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
|
||||
}
|
||||
@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
|
||||
int split_index = dims + i;
|
||||
int split_size = 1;
|
||||
for (int j = 0; j < ragged_components.size(); j++) {
|
||||
if (!ragged_components[j].nested_splits.empty()) {
|
||||
split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
|
||||
if (!ragged_components[j].nested_splits().empty()) {
|
||||
split_size += ragged_components[j].splits(i).NumElements() - 1;
|
||||
}
|
||||
}
|
||||
output_ragged->nested_splits.push_back(
|
||||
output_ragged->append_splits(
|
||||
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
|
||||
auto splits_vec =
|
||||
output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
|
||||
output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
|
||||
splits_vec(0) = 0;
|
||||
SPLIT_TYPE last_split_value = 0;
|
||||
int index = 1;
|
||||
for (int j = 0; j < ragged_components.size(); j++) {
|
||||
if (ragged_components[j].nested_splits.empty()) {
|
||||
if (ragged_components[j].nested_splits().empty()) {
|
||||
// Corner case: empty row. e.g [ [[x], [x]], [] ]
|
||||
continue;
|
||||
}
|
||||
auto component_splits_vec =
|
||||
ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
|
||||
ragged_components[j].splits(i).vec<SPLIT_TYPE>();
|
||||
for (int k = 1; k < component_splits_vec.size(); k++, index++) {
|
||||
splits_vec(index) = component_splits_vec(k) + last_split_value;
|
||||
}
|
||||
@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
|
||||
if (ragged_components.empty()) {
|
||||
component_values_shape = TensorShape({0});
|
||||
} else {
|
||||
component_values_shape = ragged_components[0].values.shape();
|
||||
component_values_shape = ragged_components[0].values().shape();
|
||||
}
|
||||
|
||||
// Populate values.
|
||||
int values_size = component_values_shape.dim_size(0);
|
||||
for (int i = 1; i < ragged_components.size(); i++) {
|
||||
if (ragged_components[i].values.dims() != component_values_shape.dims()) {
|
||||
if (ragged_components[i].values().dims() != component_values_shape.dims()) {
|
||||
return errors::InvalidArgument(
|
||||
"Rank of values must match for all "
|
||||
"components; values shape at index 0: ",
|
||||
component_values_shape.DebugString(), ", values shape at index ", i,
|
||||
": ", ragged_components[i].values.shape().DebugString());
|
||||
": ", ragged_components[i].values().shape().DebugString());
|
||||
}
|
||||
values_size += ragged_components[i].values.shape().dim_size(0);
|
||||
values_size += ragged_components[i].values().shape().dim_size(0);
|
||||
}
|
||||
component_values_shape.set_dim(0, values_size);
|
||||
output_ragged->values =
|
||||
Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
|
||||
output_ragged->set_values(
|
||||
Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
|
||||
auto output_values_flat =
|
||||
output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
|
||||
output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
|
||||
int values_index = 0;
|
||||
for (int i = 0; i < ragged_components.size(); i++) {
|
||||
auto component_values_flat =
|
||||
ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
|
||||
int num_inner_elements = ragged_components[i].values.NumElements();
|
||||
if (ragged_components[i].values.dim_size(0) > 0) {
|
||||
num_inner_elements /= ragged_components[i].values.dim_size(0);
|
||||
ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
|
||||
int num_inner_elements = ragged_components[i].values().NumElements();
|
||||
if (ragged_components[i].values().dim_size(0) > 0) {
|
||||
num_inner_elements /= ragged_components[i].values().dim_size(0);
|
||||
}
|
||||
for (int j = 0; j < ragged_components[i].values.dim_size(0);
|
||||
for (int j = 0; j < ragged_components[i].values().dim_size(0);
|
||||
j++, values_index++) {
|
||||
for (int k = 0; k < num_inner_elements; k++) {
|
||||
output_values_flat(values_index, k) = component_values_flat(j, k);
|
||||
@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||||
// Decode all variants.
|
||||
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
|
||||
const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
|
||||
std::vector<RaggedTensor> decoded_components;
|
||||
std::vector<RaggedTensorVariant> decoded_components;
|
||||
OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
|
||||
encoded_variant, input_ragged_rank_,
|
||||
value_dtype, split_dtype, &decoded_components));
|
||||
@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||||
for (int i = 0; i < encoded_variant.dims(); i++) {
|
||||
encoded_dim_sizes[i] = encoded_variant.dim_size(i);
|
||||
}
|
||||
RaggedTensor output_ragged;
|
||||
RaggedTensorVariant output_ragged;
|
||||
OP_REQUIRES_OK(
|
||||
context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
|
||||
decoded_components, encoded_dim_sizes, input_ragged_rank_,
|
||||
@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||||
int output_ragged_rank_;
|
||||
|
||||
void ReturnRaggedTensor(OpKernelContext* context,
|
||||
RaggedTensor ragged_tensor) {
|
||||
int ragged_rank = ragged_tensor.nested_splits.size();
|
||||
const RaggedTensorVariant& ragged_tensor) {
|
||||
int ragged_rank = ragged_tensor.ragged_rank();
|
||||
OpOutputList splits_out;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->output_list("output_nested_splits", &splits_out));
|
||||
for (int i = 0; i < ragged_rank; i++) {
|
||||
splits_out.set(i, ragged_tensor.nested_splits[i]);
|
||||
splits_out.set(i, ragged_tensor.splits(i));
|
||||
}
|
||||
context->set_output(ragged_rank, ragged_tensor.values);
|
||||
context->set_output(ragged_rank, ragged_tensor.values());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -55,28 +56,22 @@ class RaggedTensorFromVariantKernelTest : public ::tensorflow::OpsTestBase {
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
Tensor CreateVariantFromRagged(
|
||||
RaggedTensorVariant CreateVariantFromRagged(
|
||||
const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
|
||||
const TensorShape& ragged_values_shape,
|
||||
const std::vector<VALUE_TYPE>& ragged_values) {
|
||||
// Step 1: Create Tensors out of ragged splits and values.
|
||||
std::vector<Variant> ragged_components;
|
||||
RaggedTensorVariant encoded;
|
||||
for (auto ragged_split : ragged_splits) {
|
||||
int splits_size = ragged_split.size();
|
||||
Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(),
|
||||
TensorShape({splits_size}));
|
||||
test::FillValues<SPLIT_TYPE>(&splits, ragged_split);
|
||||
ragged_components.push_back(splits);
|
||||
encoded.append_splits(splits);
|
||||
}
|
||||
Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape);
|
||||
test::FillValues<VALUE_TYPE>(&values, ragged_values);
|
||||
ragged_components.push_back(values);
|
||||
|
||||
// Step 2: Encode into a 1-D Variant Tensor.
|
||||
int num_splits = ragged_splits.size();
|
||||
Tensor encoded_list(DT_VARIANT, TensorShape({num_splits + 1}));
|
||||
test::FillValues<Variant>(&encoded_list, ragged_components);
|
||||
return encoded_list;
|
||||
encoded.set_values(values);
|
||||
return encoded;
|
||||
}
|
||||
};
|
||||
|
||||
@ -85,7 +80,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, ScalarInput) {
|
||||
const std::vector<int64> split_2 = {0, 1, 2, 5, 6, 7};
|
||||
const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4};
|
||||
|
||||
Tensor encoded_variant = CreateVariantFromRagged<int, int64>(
|
||||
auto encoded_variant = CreateVariantFromRagged<int, int64>(
|
||||
{split_1, split_2}, TensorShape({7}), values);
|
||||
Tensor expected_splits_1(DT_INT64, TensorShape({6}));
|
||||
Tensor expected_splits_2(DT_INT64, TensorShape({6}));
|
||||
@ -113,7 +108,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, OneInputElement) {
|
||||
const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4};
|
||||
const std::vector<int64> batched_splits_1 = {0, 5};
|
||||
|
||||
Tensor encoded_variant = CreateVariantFromRagged<int, int64>(
|
||||
auto encoded_variant = CreateVariantFromRagged<int, int64>(
|
||||
{split_1, split_2}, TensorShape({7}), values);
|
||||
Tensor expected_splits_1(DT_INT64, TensorShape({2}));
|
||||
Tensor expected_splits_2(DT_INT64, TensorShape({6}));
|
||||
@ -157,13 +152,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, TensorIn2DOut) {
|
||||
const std::vector<int64> batched_splits_2 = {0, 3, 3, 5, 6};
|
||||
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6};
|
||||
|
||||
Tensor component_variant_1 =
|
||||
auto component_variant_1 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({3}), values_1);
|
||||
Tensor component_variant_2 =
|
||||
auto component_variant_2 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({0}), values_2);
|
||||
Tensor component_variant_3 =
|
||||
auto component_variant_3 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({2}), values_3);
|
||||
Tensor component_variant_4 =
|
||||
auto component_variant_4 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({1}), values_4);
|
||||
|
||||
Tensor expected_splits_1(DT_INT64, TensorShape({3}));
|
||||
@ -223,15 +218,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOut) {
|
||||
test::FillValues<int64>(&expected_splits_3, batched_splits_3);
|
||||
test::FillValues<int>(&expected_values, batched_values);
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1}), component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_2_1}, TensorShape({2}), component_values_2);
|
||||
Tensor variant_component_3 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_3 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_3_1}, TensorShape({2}), component_values_3);
|
||||
Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_4 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_4_1}, TensorShape({3}), component_values_4);
|
||||
Tensor variant_component_5 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_5 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_5_1}, TensorShape({3}), component_values_5);
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 3;
|
||||
@ -297,10 +292,10 @@ TEST_F(RaggedTensorFromVariantKernelTest,
|
||||
test::FillValues<int64>(&expected_splits_4, batched_splits_4);
|
||||
test::FillValues<int>(&expected_values, batched_values);
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1, component_split_1_2}, TensorShape({11}),
|
||||
component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_2_1, component_split_2_2}, TensorShape({11}),
|
||||
component_values_2);
|
||||
int input_ragged_rank = -1;
|
||||
@ -336,9 +331,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, EmptyRow1DIn2DOut) {
|
||||
test::FillValues<int64>(&expected_splits_2, batched_splits_2);
|
||||
test::FillValues<int>(&expected_values, batched_values);
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({3}), component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_2_1}, TensorShape({0}), {}); // Empty row.
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
@ -371,9 +366,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, NDValues1DIn2DOut) {
|
||||
test::FillValues<int64>(&expected_splits_2, batched_splits_2);
|
||||
test::FillValues<int>(&expected_values, batched_values);
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1, 2}), component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_2_1}, TensorShape({2, 2}), component_values_2);
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
@ -423,15 +418,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) {
|
||||
test::FillValues<int>(&expected_splits_3, batched_splits_3);
|
||||
test::FillValues<int>(&expected_values, batched_values);
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int>(
|
||||
{component_split_1_1}, TensorShape({1}), component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int>(
|
||||
{component_split_2_1}, TensorShape({2}), component_values_2);
|
||||
Tensor variant_component_3 = CreateVariantFromRagged<int, int>(
|
||||
auto variant_component_3 = CreateVariantFromRagged<int, int>(
|
||||
{component_split_3_1}, TensorShape({2}), component_values_3);
|
||||
Tensor variant_component_4 = CreateVariantFromRagged<int, int>(
|
||||
auto variant_component_4 = CreateVariantFromRagged<int, int>(
|
||||
{component_split_4_1}, TensorShape({3}), component_values_4);
|
||||
Tensor variant_component_5 = CreateVariantFromRagged<int, int>(
|
||||
auto variant_component_5 = CreateVariantFromRagged<int, int>(
|
||||
{component_split_5_1}, TensorShape({3}), component_values_5);
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 3;
|
||||
@ -451,13 +446,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) {
|
||||
|
||||
// Tests for invalid inputs.
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, InvalidInferredInputRaggedRank) {
|
||||
Tensor component_variant_1 =
|
||||
auto component_variant_1 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({3}), {1, 2, 3});
|
||||
Tensor component_variant_2 =
|
||||
auto component_variant_2 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({0}), {});
|
||||
Tensor component_variant_3 =
|
||||
auto component_variant_3 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({2}), {1, 2});
|
||||
Tensor component_variant_4 =
|
||||
auto component_variant_4 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({1}), {1});
|
||||
|
||||
int input_ragged_rank = -1;
|
||||
@ -478,9 +473,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) {
|
||||
const std::vector<int> component_values_1 = {0};
|
||||
const std::vector<int> component_values_2 = {0, 1};
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1}), component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_2_1}, TensorShape({2}), component_values_2);
|
||||
|
||||
int input_ragged_rank = 1;
|
||||
@ -493,33 +488,21 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) {
|
||||
"input_ragged_rank + encoded_ragged.dims()"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldTensors) {
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldRaggedTensorVariant) {
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
BuildDecodeRaggedTensorGraph<int, int64>(
|
||||
input_ragged_rank, output_ragged_rank, TensorShape({2}), {1, 2});
|
||||
EXPECT_TRUE(absl::StartsWith(
|
||||
RunOpKernel().error_message(),
|
||||
"Input Variant element at index 0 doesn't hold a Tensor"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, InputVariantTensorRankNotOne) {
|
||||
Tensor variant_list(DT_VARIANT, TensorShape({2, 1}));
|
||||
test::FillValues<Variant>(&variant_list, {1, 2});
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
BuildDecodeRaggedTensorGraph<int, int64>(
|
||||
input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list});
|
||||
EXPECT_TRUE(absl::StartsWith(
|
||||
RunOpKernel().error_message(),
|
||||
"Encoded input Variant must have rank 1, but found rank: 2"));
|
||||
"Input Variant element at index 0 doesn't hold a RaggedTensorVariant"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest,
|
||||
InputScalarElementDoesNotMatchInputRaggedRank) {
|
||||
const std::vector<int64> component_split_1_1 = {0, 1};
|
||||
const std::vector<int> component_values_1 = {1, 2};
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1, 2}), component_values_1);
|
||||
|
||||
int input_ragged_rank = 2;
|
||||
@ -527,31 +510,17 @@ TEST_F(RaggedTensorFromVariantKernelTest,
|
||||
BuildDecodeRaggedTensorGraph<int, int64>(input_ragged_rank,
|
||||
output_ragged_rank, TensorShape({1}),
|
||||
{variant_component_1});
|
||||
EXPECT_TRUE(absl::StartsWith(
|
||||
RunOpKernel().error_message(),
|
||||
"Encoded input Variant must hold either input_ragged_rank + 1 "
|
||||
"Tensors or an empty Tensor"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitNotATensor) {
|
||||
Tensor variant_list(DT_VARIANT, TensorShape({2}));
|
||||
test::FillValues<Variant>(&variant_list, {1, 2});
|
||||
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank,
|
||||
TensorShape({1}), {variant_list});
|
||||
EXPECT_TRUE(
|
||||
absl::StartsWith(RunOpKernel().error_message(),
|
||||
"Encoded scalar element at index 0 doesn't have a "
|
||||
"splits Tensor at split_index 0"));
|
||||
"Encoded input RaggedTensorVariant has ragged_rank=1. "
|
||||
"Expected ragged_rank=2."));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) {
|
||||
const std::vector<int64> component_split_1_1 = {0, 1};
|
||||
const std::vector<int> component_values_1 = {0};
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1}), component_values_1);
|
||||
|
||||
int input_ragged_rank = 1;
|
||||
@ -559,46 +528,29 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) {
|
||||
BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank,
|
||||
TensorShape({1}),
|
||||
{variant_component_1});
|
||||
EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
|
||||
"Expected splits Tensor dtype: 3, found: 9"));
|
||||
EXPECT_TRUE(absl::StartsWith(
|
||||
RunOpKernel().error_message(),
|
||||
"Expected row_splits Tensor dtype: int32, found: int64"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitRankNotOne) {
|
||||
Tensor splits(DT_INT64, TensorShape({2, 1}));
|
||||
test::FillValues<int64>(&splits, {1, 2});
|
||||
Tensor values(DT_INT32, {2});
|
||||
test::FillValues<int>(&values, {1, 2});
|
||||
Tensor encoded_list(DT_VARIANT, TensorShape({2}));
|
||||
test::FillValues<Variant>(&encoded_list, {splits, values});
|
||||
RaggedTensorVariant encoded(Tensor(DT_INT32, {2}),
|
||||
{Tensor(DT_INT64, {2, 1})});
|
||||
test::FillValues<int64>(encoded.mutable_splits(0), {1, 2});
|
||||
test::FillValues<int>(encoded.mutable_values(), {1, 2});
|
||||
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
BuildDecodeRaggedTensorGraph<int, int64>(
|
||||
input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded_list});
|
||||
input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded});
|
||||
EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
|
||||
"Ragged splits must have rank 1"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesNotATensor) {
|
||||
Tensor splits(DT_INT64, TensorShape({3}));
|
||||
test::FillValues<int64>(&splits, {0, 2, 3});
|
||||
Tensor variant_list(DT_VARIANT, TensorShape({2}));
|
||||
test::FillValues<Variant>(&variant_list, {splits, 2});
|
||||
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
BuildDecodeRaggedTensorGraph<int, int64>(
|
||||
input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list});
|
||||
EXPECT_TRUE(
|
||||
absl::StartsWith(RunOpKernel().error_message(),
|
||||
"Encoded scalar element at index 0 doesn't have a "
|
||||
"values Tensor"));
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) {
|
||||
const std::vector<int64> component_split_1_1 = {0, 1};
|
||||
const std::vector<int> component_values_1 = {0};
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1}), component_values_1);
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
@ -611,7 +563,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) {
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) {
|
||||
Tensor variant_component_1 =
|
||||
auto variant_component_1 =
|
||||
CreateVariantFromRagged<int, int64>({{0, 1}}, TensorShape({}), {1});
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
@ -628,9 +580,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankMismatch) {
|
||||
const std::vector<int> component_values_1 = {0};
|
||||
const std::vector<int> component_values_2 = {0, 1, 2, 3};
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_1_1}, TensorShape({1}), component_values_1);
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{component_split_2_1}, TensorShape({2, 2}), component_values_2);
|
||||
int input_ragged_rank = 1;
|
||||
int output_ragged_rank = 2;
|
||||
@ -711,13 +663,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) {
|
||||
const std::vector<int> batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3,
|
||||
3, 3, 4, 4, 4, 4, 5, 5, 5, 5};
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2});
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{}, TensorShape({1, 2, 2}), {3, 3, 3, 3});
|
||||
Tensor variant_component_3 =
|
||||
auto variant_component_3 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({0, 2, 2}), {});
|
||||
Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
|
||||
auto variant_component_4 = CreateVariantFromRagged<int, int64>(
|
||||
{}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5});
|
||||
|
||||
Tensor expected_splits_1(DT_INT64, TensorShape({5}));
|
||||
|
@ -18,50 +18,38 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/kernels/concat_lib.h"
|
||||
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/tensor_ops_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
struct RaggedTensor {
|
||||
Tensor values;
|
||||
std::vector<Tensor> nested_splits;
|
||||
};
|
||||
|
||||
Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
|
||||
// Encode as a rank-1 Variant Tensor.
|
||||
int ragged_rank = ragged.nested_splits.size();
|
||||
*encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
|
||||
auto encoded_vec = encoded_list->vec<Variant>();
|
||||
for (int i = 0; i < ragged_rank; i++) {
|
||||
encoded_vec(i) = ragged.nested_splits[i];
|
||||
}
|
||||
encoded_vec(ragged_rank) = ragged.values;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||||
std::vector<RaggedTensor>* ragged_components) {
|
||||
Status UnbatchRaggedZerothDim(
|
||||
const RaggedTensorVariant& batched_ragged,
|
||||
std::vector<RaggedTensorVariant>* ragged_components) {
|
||||
// Set up the component Ragged Tensors.
|
||||
int ragged_rank = batched_ragged.nested_splits.size();
|
||||
auto batched_splits_top_vec =
|
||||
batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
|
||||
int ragged_rank = batched_ragged.ragged_rank();
|
||||
auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
|
||||
int num_components = batched_splits_top_vec.size() - 1;
|
||||
int num_splits = ragged_rank - 1;
|
||||
ragged_components->resize(num_components);
|
||||
for (RaggedTensor ragged_component : *ragged_components) {
|
||||
ragged_component.nested_splits.reserve(num_splits);
|
||||
for (RaggedTensorVariant& ragged_component : *ragged_components) {
|
||||
ragged_component.mutable_nested_splits()->reserve(num_splits);
|
||||
}
|
||||
const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
|
||||
int num_inner_elems = batched_ragged.values.NumElements();
|
||||
if (batched_ragged.values.dim_size(0) > 1) {
|
||||
num_inner_elems /= batched_ragged.values.dim_size(0);
|
||||
const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
|
||||
int num_inner_elems = batched_ragged.values().NumElements();
|
||||
if (batched_ragged.values().dim_size(0) > 1) {
|
||||
num_inner_elems /= batched_ragged.values().dim_size(0);
|
||||
}
|
||||
TensorShape values_shape = batched_ragged.values.shape();
|
||||
TensorShape values_shape = batched_ragged.values().shape();
|
||||
|
||||
// Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
|
||||
if (num_splits == 0) {
|
||||
@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||||
int limit = batched_splits_top_vec(i + 1);
|
||||
int num_values = limit - start;
|
||||
values_shape.set_dim(0, num_values);
|
||||
(*ragged_components)[i].values =
|
||||
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
|
||||
(*ragged_components)[i].set_values(
|
||||
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
|
||||
auto ragged_component_values_flat =
|
||||
(*ragged_components)[i].values.flat<VALUE_TYPE>();
|
||||
(*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
|
||||
for (int j = 0; j < num_values * num_inner_elems; j++) {
|
||||
ragged_component_values_flat(j) =
|
||||
batched_flat(j + start * num_inner_elems);
|
||||
@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||||
std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
|
||||
batched_splits_vec.reserve(ragged_rank);
|
||||
for (int i = 0; i < ragged_rank; i++) {
|
||||
batched_splits_vec.push_back(
|
||||
batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
|
||||
batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
|
||||
}
|
||||
std::vector<int> index(num_splits, 1);
|
||||
std::vector<int> ragged_component_values_size(num_components, 0);
|
||||
@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||||
int last_index = ragged_component_splits_vec[j - 1].size() - 1;
|
||||
split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
|
||||
}
|
||||
(*ragged_components)[i].nested_splits.push_back(
|
||||
(*ragged_components)[i].append_splits(
|
||||
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
|
||||
ragged_component_splits_vec.push_back(
|
||||
(*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
|
||||
(*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
|
||||
SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
|
||||
ragged_component_splits_vec[j](0) = 0;
|
||||
for (int k = 1; k < split_size; k++, index[j]++) {
|
||||
@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||||
for (int i = 0; i < num_components; i++) {
|
||||
int num_values = ragged_component_values_size[i];
|
||||
values_shape.set_dim(0, num_values);
|
||||
(*ragged_components)[i].values =
|
||||
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
|
||||
(*ragged_components)[i].set_values(
|
||||
Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
|
||||
auto ragged_component_values_flat =
|
||||
(*ragged_components)[i].values.flat<VALUE_TYPE>();
|
||||
(*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
|
||||
for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
|
||||
ragged_component_values_flat(j) = batched_flat(value_index);
|
||||
}
|
||||
@ -152,46 +139,38 @@ class RaggedTensorToVariantOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
|
||||
&ragged_nested_splits_in));
|
||||
const int ragged_nested_splits_len = ragged_nested_splits_in.size();
|
||||
RaggedTensor batched_ragged_input;
|
||||
RaggedTensorVariant batched_ragged_input;
|
||||
// Read ragged_values input.
|
||||
batched_ragged_input.values = context->input(ragged_nested_splits_len);
|
||||
batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
|
||||
batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
|
||||
batched_ragged_input.mutable_nested_splits()->reserve(
|
||||
ragged_nested_splits_len);
|
||||
for (int i = 0; i < ragged_nested_splits_len; i++) {
|
||||
batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
|
||||
batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
|
||||
}
|
||||
|
||||
if (!batched_input_) {
|
||||
// Encode the input as is.
|
||||
Tensor encoded_list;
|
||||
OP_REQUIRES_OK(context,
|
||||
RaggedToVariant(batched_ragged_input, &encoded_list));
|
||||
// Encode as a Scalar Variant Tensor.
|
||||
Tensor* encoded_scalar;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
|
||||
&encoded_scalar));
|
||||
encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
|
||||
encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unbatch the Ragged Tensor and encode the components.
|
||||
std::vector<RaggedTensor> ragged_components;
|
||||
std::vector<RaggedTensorVariant> unbatched_ragged_input;
|
||||
OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
|
||||
batched_ragged_input, &ragged_components));
|
||||
std::vector<Tensor> encoded_components(ragged_components.size());
|
||||
for (int i = 0; i < ragged_components.size(); i++) {
|
||||
OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
|
||||
&encoded_components[i]));
|
||||
}
|
||||
batched_ragged_input, &unbatched_ragged_input));
|
||||
|
||||
// Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
|
||||
Tensor* encoded_ragged;
|
||||
int output_size = ragged_components.size();
|
||||
Tensor* encoded_vector;
|
||||
int output_size = unbatched_ragged_input.size();
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, TensorShape({output_size}),
|
||||
&encoded_ragged));
|
||||
auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
|
||||
&encoded_vector));
|
||||
auto encoded_vector_t = encoded_vector->vec<Variant>();
|
||||
for (int i = 0; i < output_size; i++) {
|
||||
encoded_ragged_vec(i) = encoded_components[i];
|
||||
encoded_vector_t(i) = unbatched_ragged_input[i];
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,12 +178,81 @@ class RaggedTensorToVariantOp : public OpKernel {
|
||||
bool batched_input_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<value_type>("Tvalues") \
|
||||
.TypeConstraint<split_type>("Tsplits"), \
|
||||
RaggedTensorToVariantOp<value_type, split_type>);
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
class RaggedTensorToVariantGradientOp : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Read inputs.
|
||||
Tensor encoded_variant = context->input(0);
|
||||
Tensor row_splits = context->input(1);
|
||||
auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
|
||||
TensorShape dense_values_shape;
|
||||
OP_REQUIRES_OK(context,
|
||||
TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
|
||||
&dense_values_shape));
|
||||
|
||||
const auto& flat_variants = encoded_variant.flat<Variant>();
|
||||
|
||||
// Get a Tensor containing the flat_values for each variant.
|
||||
std::vector<Tensor> values;
|
||||
for (int i = 0; i < flat_variants.size(); ++i) {
|
||||
if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
|
||||
values.push_back(encoded->values());
|
||||
} else {
|
||||
// Missing value: this happens if only some of the variant values
|
||||
// generated by ragged_tensor_to_variant impacted the value that we're
|
||||
// calculating the gradient for. In this case, we will see a
|
||||
// default-constructed variant; so treat it as a zero tensor with the
|
||||
// appropriate shape.
|
||||
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
|
||||
int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
|
||||
TensorShape zeros_shape = dense_values_shape;
|
||||
zeros_shape.set_dim(0, piece_size);
|
||||
Tensor zero(value_dtype, zeros_shape);
|
||||
zero.flat<VALUE_TYPE>() =
|
||||
zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
|
||||
values.push_back(zero);
|
||||
}
|
||||
}
|
||||
|
||||
if (values.size() == 1) {
|
||||
// Just one flat_value tensor: return as-is.
|
||||
context->set_output(0, values[0]);
|
||||
} else {
|
||||
// Multiple flat_values tensors: concatenate them together.
|
||||
using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
|
||||
using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
|
||||
std::vector<std::unique_ptr<ConstPiece>> pieces;
|
||||
pieces.reserve(values.size());
|
||||
for (const Tensor& t : values) {
|
||||
pieces.emplace_back(
|
||||
new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
|
||||
}
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, dense_values_shape, &out));
|
||||
Piece out_flat =
|
||||
out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
|
||||
ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<value_type>("Tvalues") \
|
||||
.TypeConstraint<split_type>("Tsplits"), \
|
||||
RaggedTensorToVariantOp<value_type, split_type>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RaggedTensorToVariantGradient") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<value_type>("Tvalues") \
|
||||
.TypeConstraint<split_type>("Tsplits"), \
|
||||
RaggedTensorToVariantGradientOp<value_type, split_type>);
|
||||
|
||||
#define REGISTER_KERNELS(value_type) \
|
||||
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
|
||||
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -60,6 +61,43 @@ class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase {
|
||||
}
|
||||
AddInputFromArray<VALUE_TYPE>(ragged_values_shape, ragged_values);
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
RaggedTensorVariant CreateVariantFromRagged(
|
||||
const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
|
||||
const TensorShape& ragged_values_shape,
|
||||
const std::vector<VALUE_TYPE>& ragged_values) {
|
||||
RaggedTensorVariant encoded;
|
||||
for (auto ragged_split : ragged_splits) {
|
||||
int splits_size = ragged_split.size();
|
||||
Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(),
|
||||
TensorShape({splits_size}));
|
||||
test::FillValues<SPLIT_TYPE>(&splits, ragged_split);
|
||||
encoded.append_splits(splits);
|
||||
}
|
||||
Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape);
|
||||
test::FillValues<VALUE_TYPE>(&values, ragged_values);
|
||||
encoded.set_values(values);
|
||||
return encoded;
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
RaggedTensorVariant CreateVariantFromRagged(
|
||||
const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
|
||||
const std::vector<VALUE_TYPE>& ragged_values) {
|
||||
int num_values = ragged_values.size();
|
||||
return CreateVariantFromRagged(ragged_splits, {num_values}, ragged_values);
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
void ExpectRaggedTensorVariantEqual(const RaggedTensorVariant& expected,
|
||||
const RaggedTensorVariant& actual) {
|
||||
test::ExpectTensorEqual<VALUE_TYPE>(actual.values(), expected.values());
|
||||
EXPECT_EQ(actual.ragged_rank(), expected.ragged_rank());
|
||||
for (int i = 0; i < actual.ragged_rank(); ++i) {
|
||||
test::ExpectTensorEqual<SPLIT_TYPE>(actual.splits(i), expected.splits(i));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
|
||||
@ -67,18 +105,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
|
||||
const std::vector<int64> batched_splits_1 = {0, 2, 3, 3};
|
||||
const std::vector<int64> batched_splits_2 = {0, 0, 0, 0};
|
||||
|
||||
const std::vector<int64> component_splits_1_1 = {0, 0, 0};
|
||||
const std::vector<int64> component_splits_2_1 = {0, 0};
|
||||
const std::vector<int64> component_splits_3_1 = {0};
|
||||
|
||||
Tensor expected_splits_1_1(DT_INT64, TensorShape({3}));
|
||||
Tensor expected_splits_2_1(DT_INT64, TensorShape({2}));
|
||||
Tensor expected_splits_3_1(DT_INT64, TensorShape({1}));
|
||||
|
||||
test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
|
||||
test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
|
||||
test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
|
||||
TensorShape({0}), {}, true);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -86,55 +112,26 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 3);
|
||||
|
||||
const Variant& encoded_splits_1_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_splits_2_1 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_splits_3_1 =
|
||||
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_3 =
|
||||
encoded_list(2).get<Tensor>()->vec<Variant>()(1);
|
||||
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
|
||||
expected_splits_1_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
|
||||
expected_splits_2_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(),
|
||||
expected_splits_3_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
Tensor(DT_INT32, TensorShape({0})));
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
Tensor(DT_INT32, TensorShape({0})));
|
||||
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
|
||||
Tensor(DT_INT32, TensorShape({0})));
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 0, 0}}, {}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 0}}, {}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0}}, {}),
|
||||
*encoded_list(2).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) {
|
||||
// ragged_tensor=
|
||||
// [ [x, x, x],
|
||||
// [ [1, 2, 3],
|
||||
// [ ],
|
||||
// [x, x ],
|
||||
// [x ]]
|
||||
// [4, 5 ],
|
||||
// [6 ]]
|
||||
const std::vector<int64> batched_splits = {0, 3, 3, 5, 6};
|
||||
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6};
|
||||
|
||||
const std::vector<int> component_values_1 = {1, 2, 3};
|
||||
const std::vector<int> component_values_3 = {4, 5};
|
||||
const std::vector<int> component_values_4 = {6};
|
||||
|
||||
Tensor expected_values_1(DT_INT32, TensorShape({3}));
|
||||
Tensor expected_values_2(DT_INT32, TensorShape({0}));
|
||||
Tensor expected_values_3(DT_INT32, TensorShape({2}));
|
||||
Tensor expected_values_4(DT_INT32, TensorShape({1}));
|
||||
|
||||
test::FillValues<int>(&expected_values_1, component_values_1);
|
||||
test::FillValues<int>(&expected_values_3, component_values_3);
|
||||
test::FillValues<int>(&expected_values_4, component_values_4);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits}, TensorShape({6}),
|
||||
batched_values, true);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -142,45 +139,28 @@ TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 4);
|
||||
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_3 =
|
||||
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_4 =
|
||||
encoded_list(3).get<Tensor>()->vec<Variant>()(0);
|
||||
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
expected_values_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
expected_values_2);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
|
||||
expected_values_3);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(),
|
||||
expected_values_4);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {1, 2, 3}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {4, 5}),
|
||||
*encoded_list(2).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {6}),
|
||||
*encoded_list(3).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) {
|
||||
// ragged_tensor=
|
||||
// [[x, x],
|
||||
// [x, x],
|
||||
// [x, x]]
|
||||
// [[1, 2],
|
||||
// [4, 5],
|
||||
// [6, 7]]
|
||||
const std::vector<int64> batched_splits = {0, 1, 2, 3};
|
||||
const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7};
|
||||
|
||||
const std::vector<int> component_values_1 = {1, 2};
|
||||
const std::vector<int> component_values_2 = {4, 5};
|
||||
const std::vector<int> component_values_3 = {6, 7};
|
||||
|
||||
Tensor expected_values_1(DT_INT32, TensorShape({1, 2}));
|
||||
Tensor expected_values_2(DT_INT32, TensorShape({1, 2}));
|
||||
Tensor expected_values_3(DT_INT32, TensorShape({1, 2}));
|
||||
|
||||
test::FillValues<int>(&expected_values_1, component_values_1);
|
||||
test::FillValues<int>(&expected_values_2, component_values_2);
|
||||
test::FillValues<int>(&expected_values_3, component_values_3);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>(
|
||||
{batched_splits}, TensorShape({3, 2}), batched_values, true);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -188,44 +168,25 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 3);
|
||||
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_3 =
|
||||
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
|
||||
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
expected_values_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
expected_values_2);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
|
||||
expected_values_3);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {1, 2}, {1, 2}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {1, 2}, {4, 5}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, {1, 2}, {6, 7}),
|
||||
*encoded_list(2).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) {
|
||||
// ragged_tensor=[
|
||||
// [ [[x, x], [x, x]],
|
||||
// [[x, x] ] ]
|
||||
// ragged_tensor=
|
||||
// [ [[[1, 2], [4, 5]]],
|
||||
// [[[6 7]]] ]
|
||||
const std::vector<int64> batched_splits_1 = {0, 1, 2};
|
||||
const std::vector<int64> batched_splits_2 = {0, 2, 3};
|
||||
const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7};
|
||||
|
||||
const std::vector<int64> component_splits_1_1 = {0, 2};
|
||||
const std::vector<int64> component_splits_2_1 = {0, 1};
|
||||
const std::vector<int> component_values_1 = {1, 2, 4, 5};
|
||||
const std::vector<int> component_values_2 = {6, 7};
|
||||
|
||||
Tensor expected_splits_1_1(DT_INT64, TensorShape({2}));
|
||||
Tensor expected_splits_2_1(DT_INT64, TensorShape({2}));
|
||||
Tensor expected_values_1(DT_INT32, TensorShape({2, 2}));
|
||||
Tensor expected_values_2(DT_INT32, TensorShape({1, 2}));
|
||||
|
||||
test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
|
||||
test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
|
||||
test::FillValues<int>(&expected_values_1, component_values_1);
|
||||
test::FillValues<int>(&expected_values_2, component_values_2);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
|
||||
TensorShape({3, 2}), batched_values,
|
||||
true);
|
||||
@ -234,23 +195,12 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 2);
|
||||
|
||||
const Variant& encoded_splits_1_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_splits_2_1 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
|
||||
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
|
||||
expected_splits_1_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
expected_values_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
|
||||
expected_splits_2_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
expected_values_2);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 2}}, {2, 2}, {1, 2, 4, 5}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 1}}, {1, 2}, {6, 7}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
|
||||
@ -263,30 +213,6 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
|
||||
const std::vector<int64> batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15};
|
||||
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||
9, 10, 11, 12, 13, 14, 15};
|
||||
const std::vector<int64> component_splits_1_1 = {0, 1, 3, 3};
|
||||
const std::vector<int64> component_splits_2_1 = {0};
|
||||
const std::vector<int64> component_splits_3_1 = {0, 5, 8};
|
||||
const std::vector<int64> component_splits_4_1 = {0, 0, 4};
|
||||
const std::vector<int> component_values_1 = {1, 2, 3};
|
||||
const std::vector<int> component_values_3 = {4, 5, 6, 7, 8, 9, 10, 11};
|
||||
const std::vector<int> component_values_4 = {12, 13, 14, 15};
|
||||
|
||||
Tensor expected_splits_1_1(DT_INT64, TensorShape({4}));
|
||||
Tensor expected_splits_2_1(DT_INT64, TensorShape({1}));
|
||||
Tensor expected_splits_3_1(DT_INT64, TensorShape({3}));
|
||||
Tensor expected_splits_4_1(DT_INT64, TensorShape({3}));
|
||||
Tensor expected_values_1(DT_INT32, TensorShape({3}));
|
||||
Tensor expected_values_2(DT_INT32, TensorShape({0}));
|
||||
Tensor expected_values_3(DT_INT32, TensorShape({8}));
|
||||
Tensor expected_values_4(DT_INT32, TensorShape({4}));
|
||||
|
||||
test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
|
||||
test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
|
||||
test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1);
|
||||
test::FillValues<int64>(&expected_splits_4_1, component_splits_4_1);
|
||||
test::FillValues<int>(&expected_values_1, component_values_1);
|
||||
test::FillValues<int>(&expected_values_3, component_values_3);
|
||||
test::FillValues<int>(&expected_values_4, component_values_4);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
|
||||
TensorShape({15}), batched_values,
|
||||
@ -296,39 +222,19 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 4);
|
||||
|
||||
const Variant& encoded_splits_1_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_splits_2_1 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_splits_3_1 =
|
||||
encoded_list(2).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_3 =
|
||||
encoded_list(2).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_splits_4_1 =
|
||||
encoded_list(3).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_values_4 =
|
||||
encoded_list(3).get<Tensor>()->vec<Variant>()(1);
|
||||
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
|
||||
expected_splits_1_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
expected_values_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
|
||||
expected_splits_2_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
expected_values_2);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(),
|
||||
expected_splits_3_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
|
||||
expected_values_3);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_4_1.get<Tensor>(),
|
||||
expected_splits_4_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(),
|
||||
expected_values_4);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 1, 3, 3}}, {1, 2, 3}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0}}, {}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 5, 8}},
|
||||
{4, 5, 6, 7, 8, 9, 10, 11}),
|
||||
*encoded_list(2).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({{0, 0, 4}}, {12, 13, 14, 15}),
|
||||
*encoded_list(3).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
|
||||
@ -350,26 +256,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
|
||||
7, 8, 9, 12, 13, 14};
|
||||
const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4,
|
||||
5, 6, 7, 8, 9, 8, 9};
|
||||
const std::vector<int64> component_split_1_1 = {0, 1, 3, 4, 5, 6};
|
||||
const std::vector<int64> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7};
|
||||
const std::vector<int64> component_split_2_1 = {0, 1, 2, 3, 4, 5};
|
||||
const std::vector<int64> component_split_2_2 = {0, 1, 2, 5, 6, 7};
|
||||
const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4};
|
||||
const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9};
|
||||
|
||||
Tensor expected_splits_1_1(DT_INT64, TensorShape({6}));
|
||||
Tensor expected_splits_1_2(DT_INT64, TensorShape({7}));
|
||||
Tensor expected_splits_2_1(DT_INT64, TensorShape({6}));
|
||||
Tensor expected_splits_2_2(DT_INT64, TensorShape({6}));
|
||||
Tensor expected_values_1(DT_INT32, TensorShape({7}));
|
||||
Tensor expected_values_2(DT_INT32, TensorShape({7}));
|
||||
|
||||
test::FillValues<int64>(&expected_splits_1_1, component_split_1_1);
|
||||
test::FillValues<int64>(&expected_splits_1_2, component_split_1_2);
|
||||
test::FillValues<int64>(&expected_splits_2_1, component_split_2_1);
|
||||
test::FillValues<int64>(&expected_splits_2_2, component_split_2_2);
|
||||
test::FillValues<int>(&expected_values_1, component_values_1);
|
||||
test::FillValues<int>(&expected_values_2, component_values_2);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>(
|
||||
{batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}),
|
||||
@ -379,31 +265,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 2);
|
||||
|
||||
const Variant& encoded_splits_1_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_splits_1_2 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(2);
|
||||
const Variant& encoded_splits_2_1 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_splits_2_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(2);
|
||||
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
|
||||
expected_splits_1_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_1_2.get<Tensor>(),
|
||||
expected_splits_1_2);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
|
||||
expected_splits_2_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_2_2.get<Tensor>(),
|
||||
expected_splits_2_2);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
expected_values_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
expected_values_2);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>(
|
||||
{{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>(
|
||||
{{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
|
||||
@ -424,28 +293,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
|
||||
7, 8, 9, 12, 13, 14};
|
||||
const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4,
|
||||
5, 6, 7, 8, 9, 8, 9};
|
||||
const std::vector<int> component_split_1_1 = {0, 1, 3, 4, 5, 6};
|
||||
const std::vector<int> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7};
|
||||
const std::vector<int> component_split_2_1 = {0, 1, 2, 3, 4, 5};
|
||||
const std::vector<int> component_split_2_2 = {0, 1, 2, 5, 6, 7};
|
||||
const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4};
|
||||
const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9};
|
||||
|
||||
Tensor expected_splits_1_1(DT_INT32, TensorShape({6}));
|
||||
Tensor expected_splits_1_2(DT_INT32, TensorShape({7}));
|
||||
Tensor expected_splits_2_1(DT_INT32, TensorShape({6}));
|
||||
Tensor expected_splits_2_2(DT_INT32, TensorShape({6}));
|
||||
Tensor expected_values_1(DT_INT32, TensorShape({7}));
|
||||
Tensor expected_values_2(DT_INT32, TensorShape({7}));
|
||||
|
||||
test::FillValues<int>(&expected_splits_1_1, component_split_1_1);
|
||||
test::FillValues<int>(&expected_splits_1_2, component_split_1_2);
|
||||
test::FillValues<int>(&expected_splits_2_1, component_split_2_1);
|
||||
test::FillValues<int>(&expected_splits_2_2, component_split_2_2);
|
||||
test::FillValues<int>(&expected_values_1, component_values_1);
|
||||
test::FillValues<int>(&expected_values_2, component_values_2);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int>(
|
||||
BuildEncodeRaggedTensorGraph<int, int32>(
|
||||
{batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}),
|
||||
batched_values, true);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -453,31 +302,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
|
||||
const auto& encoded_list = GetOutput(0)->vec<Variant>();
|
||||
EXPECT_EQ(encoded_list.size(), 2);
|
||||
|
||||
const Variant& encoded_splits_1_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_splits_1_2 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_values_1 =
|
||||
encoded_list(0).get<Tensor>()->vec<Variant>()(2);
|
||||
const Variant& encoded_splits_2_1 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_splits_2_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_values_2 =
|
||||
encoded_list(1).get<Tensor>()->vec<Variant>()(2);
|
||||
|
||||
test::ExpectTensorEqual<int>(*encoded_splits_1_1.get<Tensor>(),
|
||||
expected_splits_1_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_splits_1_2.get<Tensor>(),
|
||||
expected_splits_1_2);
|
||||
test::ExpectTensorEqual<int>(*encoded_splits_2_1.get<Tensor>(),
|
||||
expected_splits_2_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_splits_2_2.get<Tensor>(),
|
||||
expected_splits_2_2);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
|
||||
expected_values_1);
|
||||
test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
|
||||
expected_values_2);
|
||||
ExpectRaggedTensorVariantEqual<int, int32>(
|
||||
CreateVariantFromRagged<int, int32>(
|
||||
{{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}),
|
||||
*encoded_list(0).get<RaggedTensorVariant>());
|
||||
ExpectRaggedTensorVariantEqual<int, int32>(
|
||||
CreateVariantFromRagged<int, int32>(
|
||||
{{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}),
|
||||
*encoded_list(1).get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) {
|
||||
@ -491,33 +323,17 @@ TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) {
|
||||
const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||
9, 10, 11, 12, 13, 14, 15};
|
||||
|
||||
Tensor batched_ragged_splits_1(DT_INT64, TensorShape({5}));
|
||||
Tensor batched_ragged_splits_2(DT_INT64, TensorShape({8}));
|
||||
Tensor batched_ragged_values(DT_INT32, TensorShape({15}));
|
||||
|
||||
test::FillValues<int64>(&batched_ragged_splits_1, batched_splits_1);
|
||||
test::FillValues<int64>(&batched_ragged_splits_2, batched_splits_2);
|
||||
test::FillValues<int>(&batched_ragged_values, batched_values);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
|
||||
TensorShape({15}), batched_values,
|
||||
false);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
|
||||
const Variant& encoded_splits_1 =
|
||||
encoded_scalar.get<Tensor>()->vec<Variant>()(0);
|
||||
const Variant& encoded_splits_2 =
|
||||
encoded_scalar.get<Tensor>()->vec<Variant>()(1);
|
||||
const Variant& encoded_values =
|
||||
encoded_scalar.get<Tensor>()->vec<Variant>()(2);
|
||||
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_1.get<Tensor>(),
|
||||
batched_ragged_splits_1);
|
||||
test::ExpectTensorEqual<int64>(*encoded_splits_2.get<Tensor>(),
|
||||
batched_ragged_splits_2);
|
||||
test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(),
|
||||
batched_ragged_values);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({batched_splits_1, batched_splits_2},
|
||||
batched_values),
|
||||
*encoded_scalar.get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) {
|
||||
@ -598,17 +414,14 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
|
||||
|
||||
TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) {
|
||||
const std::vector<int> values = {1, 2, 3, 4, 5, 6};
|
||||
Tensor expected_values(DT_INT32, TensorShape({6}));
|
||||
test::FillValues<int>(&expected_values, values);
|
||||
|
||||
BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
|
||||
const Variant& encoded_values =
|
||||
encoded_scalar.get<Tensor>()->vec<Variant>()(0);
|
||||
|
||||
test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values);
|
||||
ExpectRaggedTensorVariantEqual<int, int64>(
|
||||
CreateVariantFromRagged<int, int64>({}, values),
|
||||
*encoded_scalar.get<RaggedTensorVariant>());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
86
tensorflow/core/kernels/ragged_tensor_variant.cc
Normal file
86
tensorflow/core/kernels/ragged_tensor_variant.cc
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; }
|
||||
|
||||
string RaggedTensorVariant::DebugString() const {
|
||||
return absl::StrCat(
|
||||
"RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()),
|
||||
", ragged_rank=", nested_splits_.size(), ", splits_dtype=",
|
||||
DataTypeString(nested_splits_.empty() ? DT_INVALID
|
||||
: nested_splits_.back().dtype()));
|
||||
}
|
||||
|
||||
void RaggedTensorVariant::Encode(VariantTensorData* data) const {
|
||||
data->set_type_name(TypeName());
|
||||
for (const auto& splits : nested_splits_) {
|
||||
*data->add_tensors() = splits;
|
||||
}
|
||||
*data->add_tensors() = values_;
|
||||
}
|
||||
|
||||
bool RaggedTensorVariant::Decode(const VariantTensorData& data) {
|
||||
if (data.tensors_size() < 1) {
|
||||
return false;
|
||||
}
|
||||
nested_splits_.assign(data.tensors().begin(),
|
||||
std::prev(data.tensors().end()));
|
||||
values_ = data.tensors().back();
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
Status RaggedTensorVariantDeviceCopy(
|
||||
const RaggedTensorVariant& from, RaggedTensorVariant* to,
|
||||
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
|
||||
TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values()));
|
||||
// TODO(b/170415165) Should we use `copy` to move splits from device<->host?
|
||||
*to->mutable_nested_splits() = from.nested_splits();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
|
||||
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
|
||||
RaggedTensorVariantZerosLike<CPUDevice>);
|
||||
|
||||
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
|
||||
ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant,
|
||||
RaggedTensorVariantBinaryAdd<CPUDevice>);
|
||||
|
||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant,
|
||||
"RaggedTensorVariant");
|
||||
|
||||
#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \
|
||||
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
|
||||
RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy)
|
||||
|
||||
REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
|
||||
REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
|
||||
REGISTER_RAGGED_TENSOR_VARIANT_COPY(
|
||||
VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
|
||||
|
||||
} // namespace tensorflow
|
110
tensorflow/core/kernels/ragged_tensor_variant.h
Normal file
110
tensorflow/core/kernels/ragged_tensor_variant.h
Normal file
@ -0,0 +1,110 @@
|
||||
#include "tensorflow/core/framework/tensor_key.h"
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
#include "tensorflow/core/util/tensor_ops_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Class used to store a RaggedTensor as a Variant scalar.
|
||||
class RaggedTensorVariant {
|
||||
public:
|
||||
RaggedTensorVariant() {}
|
||||
RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
|
||||
: values_(std::move(values)), nested_splits_(nested_splits) {}
|
||||
|
||||
// Variant support methods.
|
||||
string TypeName() const;
|
||||
string DebugString() const;
|
||||
void Encode(VariantTensorData* data) const;
|
||||
bool Decode(const VariantTensorData& data);
|
||||
|
||||
// The flat_values of the RaggedTensor.
|
||||
const Tensor& values() const { return values_; }
|
||||
Tensor* mutable_values() { return &values_; }
|
||||
void set_values(const Tensor& new_values) { values_ = new_values; }
|
||||
|
||||
// The nested row_splits of the RaggedTensor.
|
||||
int ragged_rank() const { return nested_splits_.size(); }
|
||||
const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
|
||||
std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
|
||||
const Tensor& splits(int i) const { return nested_splits_[i]; }
|
||||
Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
|
||||
void set_nested_splits(const std::vector<Tensor>& nested_splits) {
|
||||
nested_splits_ = nested_splits;
|
||||
}
|
||||
void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
|
||||
|
||||
private:
|
||||
Tensor values_;
|
||||
std::vector<Tensor> nested_splits_;
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
Status RaggedTensorVariantZerosLike(OpKernelContext* c,
|
||||
const RaggedTensorVariant& x,
|
||||
RaggedTensorVariant* y) {
|
||||
y->set_nested_splits(x.nested_splits());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Device>
|
||||
Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
|
||||
const RaggedTensorVariant& x,
|
||||
const RaggedTensorVariant& y,
|
||||
RaggedTensorVariant* out) {
|
||||
if (x.values().dtype() != y.values().dtype()) {
|
||||
return errors::InvalidArgument(
|
||||
"Can't add RaggedTensorVariants of different dtypes. One is ",
|
||||
DataTypeString(x.values().dtype()), " and the other is ",
|
||||
DataTypeString(y.values().dtype()));
|
||||
}
|
||||
if (x.ragged_rank() != y.ragged_rank()) {
|
||||
return errors::InvalidArgument(
|
||||
"Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
|
||||
x.ragged_rank(), " and the other is ", y.ragged_rank());
|
||||
}
|
||||
for (int i = 0; i < x.ragged_rank(); ++i) {
|
||||
if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
|
||||
return errors::InvalidArgument(
|
||||
"Can't add RaggedTensorVariants with different row_splits.");
|
||||
}
|
||||
}
|
||||
out->set_nested_splits(x.nested_splits());
|
||||
TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
|
||||
out->mutable_values()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
|
@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes(
|
||||
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
|
||||
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
|
||||
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
|
||||
tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
|
||||
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
|
||||
Status RaggedTensorToTensorShapeFn(InferenceContext* c);
|
||||
|
||||
//==============================================================================
|
||||
// Registered Ops
|
||||
@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant")
|
||||
.Attr("Tsplits: {int32, int64} = DT_INT64")
|
||||
.SetShapeFn(RaggedTensorFromVariantShapeFn);
|
||||
|
||||
REGISTER_OP("RaggedTensorToVariantGradient")
|
||||
.Input("encoded_ragged_grad: variant")
|
||||
.Input("row_splits: Tsplits")
|
||||
.Input("dense_values_shape: int32")
|
||||
.Output("dense_values_grad: Tvalues")
|
||||
.Attr("Tvalues: type")
|
||||
.Attr("Tsplits: {int32, int64} = DT_INT64")
|
||||
.SetShapeFn(RaggedTensorToVariantGradientShapeFn);
|
||||
|
||||
REGISTER_OP("RaggedTensorToTensor")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindex: {int64, int32}")
|
||||
@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
|
||||
ShapeHandle shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
|
||||
c->set_output(0, shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
|
||||
int64 input_ragged_rank;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -510,6 +510,7 @@ py_test(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_array_grad",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
|
@ -143,3 +143,42 @@ def to_sparse(rt_input, name=None):
|
||||
|
||||
def from_sparse(st_input, name=None):
|
||||
return ragged_tensor.RaggedTensor.from_sparse(st_input, name)
|
||||
|
||||
|
||||
@ops.RegisterGradient("RaggedTensorFromVariant")
|
||||
def _ragged_tensor_from_variant_grad(op, *grads):
|
||||
"""Gradient for RaggedTensorFromVariant op."""
|
||||
|
||||
variant_rank = op.inputs[0].shape.rank
|
||||
if variant_rank == 0:
|
||||
batched_input = False
|
||||
elif variant_rank == 1:
|
||||
batched_input = True
|
||||
elif variant_rank is None:
|
||||
batched_input = (op.get_attr("output_ragged_rank") > 0)
|
||||
else:
|
||||
# TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so
|
||||
# we can support this.
|
||||
raise ValueError("Unable to compute gradient: RaggedTensorToVariant "
|
||||
"can currently only generate 0D or 1D output.")
|
||||
return [
|
||||
gen_ragged_conversion_ops.ragged_tensor_to_variant(
|
||||
rt_nested_splits=op.outputs[:-1],
|
||||
rt_dense_values=grads[-1],
|
||||
batched_input=batched_input)
|
||||
]
|
||||
|
||||
|
||||
@ops.RegisterGradient("RaggedTensorToVariant")
|
||||
def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad):
|
||||
"""Gradient for RaggedTensorToVariant op."""
|
||||
dense_values = op.inputs[-1]
|
||||
ragged_rank = len(op.inputs) - 1
|
||||
row_splits = 0 if ragged_rank == 0 else op.inputs[0]
|
||||
values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient(
|
||||
encoded_ragged_grad=encoded_ragged_grad,
|
||||
row_splits=row_splits,
|
||||
dense_values_shape=array_ops.shape(dense_values),
|
||||
Tvalues=op.inputs[-1].dtype)
|
||||
result = [None] * ragged_rank + [values_grad]
|
||||
return result
|
||||
|
@ -2863,9 +2863,6 @@ def _get_optional_partition_dtype(values):
|
||||
return None
|
||||
|
||||
|
||||
ops.no_gradient("RaggedTensorToVariant")
|
||||
|
||||
|
||||
_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
|
||||
|
||||
|
||||
|
@ -18,10 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -30,8 +32,15 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_ragged_conversion_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -1233,19 +1242,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(factory(**kwargs))
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Variant conversion
|
||||
#=============================================================================
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Variant conversion
|
||||
#=============================================================================
|
||||
|
||||
@parameterized.parameters(
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
'testcase_name': 'Shape_5_none',
|
||||
'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
|
||||
'ragged_rank': 1
|
||||
}, {
|
||||
'testcase_name': 'Shape_4_none_2',
|
||||
'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
|
||||
'ragged_rank': 1
|
||||
}, {
|
||||
'testcase_name': 'Shape_1_none_none',
|
||||
'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
|
||||
'ragged_rank': 2
|
||||
})
|
||||
@ -1432,6 +1443,131 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
output_ragged_rank=1,
|
||||
input_ragged_rank=1)
|
||||
|
||||
def _testRaggedVarientGradient(self, func, x, expected_grad):
|
||||
x = constant_op.constant(x)
|
||||
if context.executing_eagerly():
|
||||
with backprop.GradientTape() as t:
|
||||
t.watch(x)
|
||||
y = func(x)
|
||||
g = t.gradient(y, x)
|
||||
else:
|
||||
y = func(x)
|
||||
g = gradients_impl.gradients(ys=y, xs=x)[0]
|
||||
self.assertAllClose(g, expected_grad)
|
||||
|
||||
def testRaggedVariantGradients(self):
|
||||
def func(x):
|
||||
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
|
||||
rt2 = rt1 * [[10], [100], [1000]]
|
||||
v = rt2._to_variant(batched_input=False)
|
||||
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
|
||||
return rt3.flat_values
|
||||
|
||||
self._testRaggedVarientGradient(
|
||||
func,
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[10., 10., 10., 10., 100., 100., 100., 1000.])
|
||||
|
||||
def testRaggedVariantGradientsBatched(self):
|
||||
def func(x):
|
||||
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
|
||||
rt2 = rt1 * [[10], [100], [1000]]
|
||||
v = rt2._to_variant(batched_input=True)
|
||||
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
|
||||
return rt3.flat_values
|
||||
|
||||
self._testRaggedVarientGradient(
|
||||
func,
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[10., 10., 10., 10., 100., 100., 100., 1000.])
|
||||
|
||||
def testRaggedVariantGradientsBatchedAndSliced(self):
|
||||
def func(x, i):
|
||||
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
|
||||
rt2 = rt1 * [[10], [100], [1000]]
|
||||
v_slice = rt2._to_variant(batched_input=True)[i]
|
||||
return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype,
|
||||
output_ragged_rank=0)
|
||||
|
||||
self._testRaggedVarientGradient(
|
||||
functools.partial(func, i=0),
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[10., 10., 10., 10., 0., 0., 0., 0.])
|
||||
self._testRaggedVarientGradient(
|
||||
functools.partial(func, i=1),
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[0., 0., 0., 0., 100., 100., 100., 0.])
|
||||
self._testRaggedVarientGradient(
|
||||
functools.partial(func, i=2),
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[0., 0., 0., 0., 0., 0., 0., 1000.])
|
||||
|
||||
def testRaggedVariantGradientsRaggedRank0(self):
|
||||
def func(x):
|
||||
x2 = x * 2
|
||||
v = gen_ragged_conversion_ops.ragged_tensor_to_variant(
|
||||
[], x2, batched_input=False)
|
||||
return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0)
|
||||
|
||||
self._testRaggedVarientGradient(
|
||||
func,
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
|
||||
|
||||
def testRaggedVariantGradientsRaggedRank3(self):
|
||||
def func(x):
|
||||
x2 = x * 2
|
||||
rt1 = RaggedTensor.from_nested_row_splits(
|
||||
x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8]))
|
||||
v = rt1._to_variant(batched_input=False)
|
||||
rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3)
|
||||
return rt3.flat_values
|
||||
|
||||
self._testRaggedVarientGradient(
|
||||
func,
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
|
||||
|
||||
def testRaggedVariantGradientsViaMapFn(self):
|
||||
rt = RaggedTensor.from_row_splits(
|
||||
values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8])
|
||||
|
||||
def func(x):
|
||||
|
||||
def transform_row(row):
|
||||
return math_ops.sqrt(
|
||||
math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
|
||||
|
||||
return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
|
||||
|
||||
self._testRaggedVarientGradient(func, 3.0, 14.653377)
|
||||
|
||||
def testRaggedVariantGradientsViaMapFnReduce(self):
|
||||
def func(x):
|
||||
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
|
||||
return map_fn.map_fn(
|
||||
math_ops.reduce_max, rt1,
|
||||
fn_output_signature=tensor_spec.TensorSpec((), x.dtype))
|
||||
|
||||
self._testRaggedVarientGradient(
|
||||
func,
|
||||
[3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])
|
||||
|
||||
def testRaggedVariantGradientsErrors(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
|
||||
rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2])
|
||||
v1 = rt._to_variant()
|
||||
v2 = array_ops.stack([array_ops.stack([v1])])
|
||||
y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Unable to compute gradient: RaggedTensorToVariant '
|
||||
'can currently only generate 0D or 1D output.'):
|
||||
gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values)
|
||||
|
||||
def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg):
|
||||
"""Check that two numpy arrays are equal.
|
||||
|
||||
|
@ -3212,6 +3212,10 @@ tf_module {
|
||||
name: "RaggedTensorToVariant"
|
||||
argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedTensorToVariantGradient"
|
||||
argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RandomCrop"
|
||||
argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
|
@ -3212,6 +3212,10 @@ tf_module {
|
||||
name: "RaggedTensorToVariant"
|
||||
argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedTensorToVariantGradient"
|
||||
argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RandomCrop"
|
||||
argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user