From be6b1fdb0699d4000b70ad32cc23d1503e5c7511 Mon Sep 17 00:00:00 2001 From: Edward Loper <edloper@google.com> Date: Wed, 14 Oct 2020 09:41:17 -0700 Subject: [PATCH] Added gradients for RaggedTensorToVariant and RaggedTensorFromVariant. (This allows gradients to pass through map_fn when it is applied to ragged tensors.) PiperOrigin-RevId: 337108621 Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab --- RELEASE.md | 1 + ...pi_def_RaggedTensorToVariantGradient.pbtxt | 38 ++ tensorflow/core/kernels/BUILD | 15 + .../core/kernels/data/experimental/BUILD | 1 + .../experimental/parse_example_dataset_op.cc | 10 +- .../kernels/ragged_tensor_from_variant_op.cc | 168 +++---- .../ragged_tensor_from_variant_op_test.cc | 160 +++---- .../kernels/ragged_tensor_to_variant_op.cc | 180 +++++--- .../ragged_tensor_to_variant_op_test.cc | 427 +++++------------- .../core/kernels/ragged_tensor_variant.cc | 86 ++++ .../core/kernels/ragged_tensor_variant.h | 110 +++++ tensorflow/core/ops/ragged_conversion_ops.cc | 20 +- tensorflow/python/ops/ragged/BUILD | 1 + .../ops/ragged/ragged_conversion_ops.py | 39 ++ tensorflow/python/ops/ragged/ragged_tensor.py | 3 - .../python/ops/ragged/ragged_tensor_test.py | 146 +++++- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 18 files changed, 820 insertions(+), 593 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.cc create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.h diff --git a/RELEASE.md b/RELEASE.md index 23324d56ca7..0886b6e116c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -138,6 +138,7 @@ stateful ops. * Added `tf.config.experimental.get_memory_usage` to return total memory usage of the device. + * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`. * `tf.data`: * tf.data service: * Added new `tf.data.experimental.service.register_dataset` and diff --git a/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt new file mode 100644 index 00000000000..066d6b5eae4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt @@ -0,0 +1,38 @@ +op { + graph_op_name: "RaggedTensorToVariantGradient" + visibility: HIDDEN + in_arg { + name: "encoded_ragged_grad" + description: <<END +A `variant` Tensor containing encoded `RaggedTensor` gradients. +END + } + in_arg { + name: "row_splits" + description: <<END +Outermost row-splits that were used as input to the RaggedTensorToVariant op. +END + } + in_arg { + name: "dense_values_shape" + description: <<END +Shape of the dense_values that was used as an input to the +RaggedTensorToVariant op. +END + } + out_arg { + name: "dense_values_grad" + description: <<END +Gradient for the dense_values of the RaggedTensorToVariant op. +END + } + summary: <<END +Helper used to compute the gradient for `RaggedTensorToVariant`. +END + description: <<END +Computes the gradient for the dense_values input to the RaggedTensorToVariant +op, given the variant-encoded ragged gradients of the outputs, along with +the outer row-splits and the shape of the dense-values that were provided as +inputs to the RaggedTensorToVariant op. 
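+
+For example (illustrative only, using hypothetical sizes): if the forward
+`RaggedTensorToVariant` op (with `batched_input=true`) encoded a ragged tensor
+whose outermost row_splits were [0, 2, 3, 5] and whose dense_values had shape
+[5, 2], then `encoded_ragged_grad` holds three encoded gradient components with
+values of shapes [2, 2], [1, 2], and [2, 2]; this op concatenates those values
+(substituting zeros, sized from `row_splits`, for any missing component) to
+produce a `dense_values_grad` of shape [5, 2].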
+END +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 6c045152434..f5874474ef8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1419,10 +1419,22 @@ tf_cc_test( ], ) +cc_library( + name = "ragged_tensor_variant", + srcs = ["ragged_tensor_variant.cc"], + hdrs = ["ragged_tensor_variant.h"], + deps = [ + ":cwise_op", + "//tensorflow/core:framework", + ], +) + tf_kernel_library( name = "ragged_tensor_to_variant_op", srcs = ["ragged_tensor_to_variant_op.cc"], deps = [ + ":concat_lib", + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1432,6 +1444,7 @@ tf_kernel_library( name = "ragged_tensor_from_variant_op", srcs = ["ragged_tensor_from_variant_op.cc"], deps = [ + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1444,6 +1457,7 @@ tf_cc_test( deps = [ ":ops_testutil", ":ragged_tensor_to_variant_op", + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1460,6 +1474,7 @@ tf_cc_test( deps = [ ":ops_testutil", ":ragged_tensor_from_variant_op", + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index c3ef8122ebb..92bc48159ad 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -424,6 +424,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core/kernels:ragged_tensor_variant", "//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:name_utils", "//tensorflow/core/kernels/data:parallel_map_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index 16cf7fe6416..80f23bb5a0c 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h" #include "tensorflow/core/kernels/data/stats_utils.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -678,12 +679,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { for (int d = 0; d < dataset()->ragged_keys_.size(); ++d) { int output_index = dataset()->key_to_output_index_.at(dataset()->ragged_keys_[d]); - (*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {}); - Tensor serialized_ragged = - Tensor(ctx->allocator({}), DT_VARIANT, {2}); - auto serialized_ragged_t = serialized_ragged.vec<Variant>(); - serialized_ragged_t(0) = example_result.ragged_splits[d]; - serialized_ragged_t(1) = example_result.ragged_values[d]; + RaggedTensorVariant serialized_ragged; + serialized_ragged.append_splits(example_result.ragged_splits[d]); + serialized_ragged.set_values(example_result.ragged_values[d]); (*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {}); Tensor& ragged_wrapper = (*output)[output_index]; ragged_wrapper.scalar<Variant>()() = serialized_ragged; diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc index aa736ad7f60..d9993bb6d39 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc @@ -20,110 +20,76 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace { -struct RaggedTensor { - Tensor values; - std::vector<Tensor> nested_splits; -}; - -Status RaggedComponentsFromVariant(const Tensor& encoded_variant, - int ragged_rank, DataType value_dtype, - DataType split_dtype, - std::vector<RaggedTensor>* decoded_ragged) { +Status RaggedComponentsFromVariant( + const Tensor& encoded_variant, int ragged_rank, DataType value_dtype, + DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) { const auto& flat_variants = encoded_variant.flat<Variant>(); - decoded_ragged->resize(flat_variants.size()); - // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the - // input. + decoded_ragged->reserve(flat_variants.size()); + for (int i = 0; i < flat_variants.size(); i++) { const auto& flat_variant = flat_variants(i); - const Tensor* encoded_list = flat_variant.get<Tensor>(); - if (encoded_list == nullptr) { + const RaggedTensorVariant* decoded = + flat_variant.get<RaggedTensorVariant>(); + if (decoded == nullptr) { return errors::InvalidArgument( "Input Variant element at index ", i, - " doesn't hold a Tensor: ", flat_variant.DebugString()); + " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString()); } - if (encoded_list->dims() != 1) { + decoded_ragged->push_back(*decoded); + decoded = &decoded_ragged->back(); + // Check ragged rank & types + if (decoded->ragged_rank() != ragged_rank) { return errors::InvalidArgument( - "Encoded input Variant must have rank 1, but found rank: ", - encoded_list->dims(), - ". 
encoded input Variant: ", encoded_list->DebugString()); + "Encoded input RaggedTensorVariant has ragged_rank=", + decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, "."); } - if (encoded_list->NumElements() != (ragged_rank + 1) && - encoded_list->NumElements() != 1) { - return errors::InvalidArgument( - "Encoded input Variant must hold either input_ragged_rank + 1 " - "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), " - "input_ragged_rank: ", - ragged_rank, - ", encoded input Variant: ", encoded_list->DebugString()); - } - const auto& input_vec = encoded_list->vec<Variant>(); - - // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor - // to create the component RaggedTensors. - (*decoded_ragged)[i].nested_splits.reserve(ragged_rank); - for (int j = 0; j < ragged_rank; j++) { - const Tensor* split_tensor = input_vec(j).get<Tensor>(); - if (split_tensor == nullptr) { - return errors::InvalidArgument( - "Encoded scalar element at index ", i, - " doesn't have a splits Tensor at split_index ", j, ": ", - input_vec(j).DebugString()); - } - Tensor splits_tensor = *split_tensor; - if (splits_tensor.dtype() != split_dtype) { - return errors::InvalidArgument( - "Expected splits Tensor dtype: ", split_dtype, - ", found: ", splits_tensor.dtype()); - } - if (splits_tensor.dims() != 1) { - return errors::InvalidArgument( - "Ragged splits must have rank 1; encoded scalar element at index ", - i, " has splits Tensor at split_index ", j, ": ", - splits_tensor.DebugString()); - } - (*decoded_ragged)[i].nested_splits.push_back(splits_tensor); - } - const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>(); - if (values_tensor == nullptr) { - return errors::InvalidArgument("Encoded scalar element at index ", i, - " doesn't have a values Tensor: ", - input_vec(ragged_rank).DebugString()); - } - if (values_tensor->dtype() != value_dtype) { + if (decoded->values().dtype() != value_dtype) { return errors::InvalidArgument( "Expected values Tensor dtype: ", DataTypeString(value_dtype), - ", found: ", DataTypeString(values_tensor->dtype())); + ", found: ", DataTypeString(decoded->values().dtype())); } - if (values_tensor->dims() < 1) { + if (decoded->values().dims() < 1) { return errors::InvalidArgument( "Ragged values must have rank >= 1; encoded scalar element at index ", - i, " has values Tensor: ", values_tensor->DebugString()); + i, " has values Tensor: ", decoded->values().DebugString()); + } + for (const auto& splits : decoded->nested_splits()) { + if (splits.dtype() != split_dtype) { + return errors::InvalidArgument( + "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype), + ", found: ", DataTypeString(splits.dtype())); + } + if (splits.dims() != 1) { + return errors::InvalidArgument( + "Ragged splits must have rank 1; encoded scalar element at index ", + i, " has splits Tensor ", splits.DebugString()); + } } - (*decoded_ragged)[i].values = *values_tensor; } return Status::OK(); } template <typename VALUE_TYPE, typename SPLIT_TYPE> Status NestedStackRaggedTensors( - const std::vector<RaggedTensor>& ragged_components, + const std::vector<RaggedTensorVariant>& ragged_components, const std::vector<int>& nested_dim_sizes, const int input_ragged_rank, - const int output_ragged_rank, RaggedTensor* output_ragged) { - output_ragged->nested_splits.reserve(output_ragged_rank); + const int output_ragged_rank, RaggedTensorVariant* output_ragged) { + output_ragged->mutable_nested_splits()->reserve(output_ragged_rank); const int dims = 
nested_dim_sizes.size(); // Populate first `dims - 1` splits. for (int i = 0; i < dims - 1; i++) { int dims_splits_size = nested_dim_sizes[i] + 1; - output_ragged->nested_splits.push_back(Tensor( - DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size}))); - auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>(); + output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value, + TensorShape({dims_splits_size}))); + auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>(); int split_diff = nested_dim_sizes[i + 1]; for (int j = 0; j < dims_splits_size; j++) { splits_vec(j) = j * split_diff; @@ -132,15 +98,15 @@ Status NestedStackRaggedTensors( // Populate `dims`-th split. int splits_size = ragged_components.size() + 1; - output_ragged->nested_splits.push_back( + output_ragged->append_splits( Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size}))); auto dims_splits_vec = - output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>(); + output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>(); dims_splits_vec(0) = 0; for (int i = 0; i < ragged_components.size(); i++) { - int split_val = ragged_components[i].values.shape().dim_size(0); - if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) { - split_val = ragged_components[i].nested_splits[0].NumElements() - 1; + int split_val = ragged_components[i].values().shape().dim_size(0); + if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) { + split_val = ragged_components[i].splits(0).NumElements() - 1; } dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val; } @@ -150,24 +116,24 @@ Status NestedStackRaggedTensors( int split_index = dims + i; int split_size = 1; for (int j = 0; j < ragged_components.size(); j++) { - if (!ragged_components[j].nested_splits.empty()) { - split_size += ragged_components[j].nested_splits[i].NumElements() - 1; + if (!ragged_components[j].nested_splits().empty()) { + split_size += ragged_components[j].splits(i).NumElements() - 1; } } - output_ragged->nested_splits.push_back( + output_ragged->append_splits( Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size}))); auto splits_vec = - output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>(); + output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>(); splits_vec(0) = 0; SPLIT_TYPE last_split_value = 0; int index = 1; for (int j = 0; j < ragged_components.size(); j++) { - if (ragged_components[j].nested_splits.empty()) { + if (ragged_components[j].nested_splits().empty()) { // Corner case: empty row. e.g [ [[x], [x]], [] ] continue; } auto component_splits_vec = - ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>(); + ragged_components[j].splits(i).vec<SPLIT_TYPE>(); for (int k = 1; k < component_splits_vec.size(); k++, index++) { splits_vec(index) = component_splits_vec(k) + last_split_value; } @@ -187,35 +153,35 @@ Status NestedStackRaggedTensors( if (ragged_components.empty()) { component_values_shape = TensorShape({0}); } else { - component_values_shape = ragged_components[0].values.shape(); + component_values_shape = ragged_components[0].values().shape(); } // Populate values. 
int values_size = component_values_shape.dim_size(0); for (int i = 1; i < ragged_components.size(); i++) { - if (ragged_components[i].values.dims() != component_values_shape.dims()) { + if (ragged_components[i].values().dims() != component_values_shape.dims()) { return errors::InvalidArgument( "Rank of values must match for all " "components; values shape at index 0: ", component_values_shape.DebugString(), ", values shape at index ", i, - ": ", ragged_components[i].values.shape().DebugString()); + ": ", ragged_components[i].values().shape().DebugString()); } - values_size += ragged_components[i].values.shape().dim_size(0); + values_size += ragged_components[i].values().shape().dim_size(0); } component_values_shape.set_dim(0, values_size); - output_ragged->values = - Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape); + output_ragged->set_values( + Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape)); auto output_values_flat = - output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>(); + output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>(); int values_index = 0; for (int i = 0; i < ragged_components.size(); i++) { auto component_values_flat = - ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>(); - int num_inner_elements = ragged_components[i].values.NumElements(); - if (ragged_components[i].values.dim_size(0) > 0) { - num_inner_elements /= ragged_components[i].values.dim_size(0); + ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>(); + int num_inner_elements = ragged_components[i].values().NumElements(); + if (ragged_components[i].values().dim_size(0) > 0) { + num_inner_elements /= ragged_components[i].values().dim_size(0); } - for (int j = 0; j < ragged_components[i].values.dim_size(0); + for (int j = 0; j < ragged_components[i].values().dim_size(0); j++, values_index++) { for (int k = 0; k < num_inner_elements; k++) { output_values_flat(values_index, k) = component_values_flat(j, k); @@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel { // Decode all variants. 
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v(); const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v(); - std::vector<RaggedTensor> decoded_components; + std::vector<RaggedTensorVariant> decoded_components; OP_REQUIRES_OK(context, RaggedComponentsFromVariant( encoded_variant, input_ragged_rank_, value_dtype, split_dtype, &decoded_components)); @@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel { for (int i = 0; i < encoded_variant.dims(); i++) { encoded_dim_sizes[i] = encoded_variant.dim_size(i); } - RaggedTensor output_ragged; + RaggedTensorVariant output_ragged; OP_REQUIRES_OK( context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>( decoded_components, encoded_dim_sizes, input_ragged_rank_, @@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel { int output_ragged_rank_; void ReturnRaggedTensor(OpKernelContext* context, - RaggedTensor ragged_tensor) { - int ragged_rank = ragged_tensor.nested_splits.size(); + const RaggedTensorVariant& ragged_tensor) { + int ragged_rank = ragged_tensor.ragged_rank(); OpOutputList splits_out; OP_REQUIRES_OK(context, context->output_list("output_nested_splits", &splits_out)); for (int i = 0; i < ragged_rank; i++) { - splits_out.set(i, ragged_tensor.nested_splits[i]); + splits_out.set(i, ragged_tensor.splits(i)); } - context->set_output(ragged_rank, ragged_tensor.values); + context->set_output(ragged_rank, ragged_tensor.values()); } }; diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc index bdf321d0515..fc46283c90e 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -55,28 +56,22 @@ class RaggedTensorFromVariantKernelTest : public ::tensorflow::OpsTestBase { } template <typename VALUE_TYPE, typename SPLIT_TYPE> - Tensor CreateVariantFromRagged( + RaggedTensorVariant CreateVariantFromRagged( const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits, const TensorShape& ragged_values_shape, const std::vector<VALUE_TYPE>& ragged_values) { - // Step 1: Create Tensors out of ragged splits and values. - std::vector<Variant> ragged_components; + RaggedTensorVariant encoded; for (auto ragged_split : ragged_splits) { int splits_size = ragged_split.size(); Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(), TensorShape({splits_size})); test::FillValues<SPLIT_TYPE>(&splits, ragged_split); - ragged_components.push_back(splits); + encoded.append_splits(splits); } Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape); test::FillValues<VALUE_TYPE>(&values, ragged_values); - ragged_components.push_back(values); - - // Step 2: Encode into a 1-D Variant Tensor. 
- int num_splits = ragged_splits.size(); - Tensor encoded_list(DT_VARIANT, TensorShape({num_splits + 1})); - test::FillValues<Variant>(&encoded_list, ragged_components); - return encoded_list; + encoded.set_values(values); + return encoded; } }; @@ -85,7 +80,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, ScalarInput) { const std::vector<int64> split_2 = {0, 1, 2, 5, 6, 7}; const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4}; - Tensor encoded_variant = CreateVariantFromRagged<int, int64>( + auto encoded_variant = CreateVariantFromRagged<int, int64>( {split_1, split_2}, TensorShape({7}), values); Tensor expected_splits_1(DT_INT64, TensorShape({6})); Tensor expected_splits_2(DT_INT64, TensorShape({6})); @@ -113,7 +108,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, OneInputElement) { const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4}; const std::vector<int64> batched_splits_1 = {0, 5}; - Tensor encoded_variant = CreateVariantFromRagged<int, int64>( + auto encoded_variant = CreateVariantFromRagged<int, int64>( {split_1, split_2}, TensorShape({7}), values); Tensor expected_splits_1(DT_INT64, TensorShape({2})); Tensor expected_splits_2(DT_INT64, TensorShape({6})); @@ -157,13 +152,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, TensorIn2DOut) { const std::vector<int64> batched_splits_2 = {0, 3, 3, 5, 6}; const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6}; - Tensor component_variant_1 = + auto component_variant_1 = CreateVariantFromRagged<int, int64>({}, TensorShape({3}), values_1); - Tensor component_variant_2 = + auto component_variant_2 = CreateVariantFromRagged<int, int64>({}, TensorShape({0}), values_2); - Tensor component_variant_3 = + auto component_variant_3 = CreateVariantFromRagged<int, int64>({}, TensorShape({2}), values_3); - Tensor component_variant_4 = + auto component_variant_4 = CreateVariantFromRagged<int, int64>({}, TensorShape({1}), values_4); Tensor expected_splits_1(DT_INT64, TensorShape({3})); @@ -223,15 +218,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOut) { test::FillValues<int64>(&expected_splits_3, batched_splits_3); test::FillValues<int>(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {component_split_2_1}, TensorShape({2}), component_values_2); - Tensor variant_component_3 = CreateVariantFromRagged<int, int64>( + auto variant_component_3 = CreateVariantFromRagged<int, int64>( {component_split_3_1}, TensorShape({2}), component_values_3); - Tensor variant_component_4 = CreateVariantFromRagged<int, int64>( + auto variant_component_4 = CreateVariantFromRagged<int, int64>( {component_split_4_1}, TensorShape({3}), component_values_4); - Tensor variant_component_5 = CreateVariantFromRagged<int, int64>( + auto variant_component_5 = CreateVariantFromRagged<int, int64>( {component_split_5_1}, TensorShape({3}), component_values_5); int input_ragged_rank = 1; int output_ragged_rank = 3; @@ -297,10 +292,10 @@ TEST_F(RaggedTensorFromVariantKernelTest, test::FillValues<int64>(&expected_splits_4, batched_splits_4); test::FillValues<int>(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1, component_split_1_2}, 
TensorShape({11}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {component_split_2_1, component_split_2_2}, TensorShape({11}), component_values_2); int input_ragged_rank = -1; @@ -336,9 +331,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, EmptyRow1DIn2DOut) { test::FillValues<int64>(&expected_splits_2, batched_splits_2); test::FillValues<int>(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({3}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {component_split_2_1}, TensorShape({0}), {}); // Empty row. int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -371,9 +366,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, NDValues1DIn2DOut) { test::FillValues<int64>(&expected_splits_2, batched_splits_2); test::FillValues<int>(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1, 2}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {component_split_2_1}, TensorShape({2, 2}), component_values_2); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -423,15 +418,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) { test::FillValues<int>(&expected_splits_3, batched_splits_3); test::FillValues<int>(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged<int, int>( + auto variant_component_1 = CreateVariantFromRagged<int, int>( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int>( + auto variant_component_2 = CreateVariantFromRagged<int, int>( {component_split_2_1}, TensorShape({2}), component_values_2); - Tensor variant_component_3 = CreateVariantFromRagged<int, int>( + auto variant_component_3 = CreateVariantFromRagged<int, int>( {component_split_3_1}, TensorShape({2}), component_values_3); - Tensor variant_component_4 = CreateVariantFromRagged<int, int>( + auto variant_component_4 = CreateVariantFromRagged<int, int>( {component_split_4_1}, TensorShape({3}), component_values_4); - Tensor variant_component_5 = CreateVariantFromRagged<int, int>( + auto variant_component_5 = CreateVariantFromRagged<int, int>( {component_split_5_1}, TensorShape({3}), component_values_5); int input_ragged_rank = 1; int output_ragged_rank = 3; @@ -451,13 +446,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) { // Tests for invalid inputs. 
TEST_F(RaggedTensorFromVariantKernelTest, InvalidInferredInputRaggedRank) { - Tensor component_variant_1 = + auto component_variant_1 = CreateVariantFromRagged<int, int64>({}, TensorShape({3}), {1, 2, 3}); - Tensor component_variant_2 = + auto component_variant_2 = CreateVariantFromRagged<int, int64>({}, TensorShape({0}), {}); - Tensor component_variant_3 = + auto component_variant_3 = CreateVariantFromRagged<int, int64>({}, TensorShape({2}), {1, 2}); - Tensor component_variant_4 = + auto component_variant_4 = CreateVariantFromRagged<int, int64>({}, TensorShape({1}), {1}); int input_ragged_rank = -1; @@ -478,9 +473,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) { const std::vector<int> component_values_1 = {0}; const std::vector<int> component_values_2 = {0, 1}; - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {component_split_2_1}, TensorShape({2}), component_values_2); int input_ragged_rank = 1; @@ -493,33 +488,21 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) { "input_ragged_rank + encoded_ragged.dims()")); } -TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldTensors) { +TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldRaggedTensorVariant) { int input_ragged_rank = 1; int output_ragged_rank = 2; BuildDecodeRaggedTensorGraph<int, int64>( input_ragged_rank, output_ragged_rank, TensorShape({2}), {1, 2}); EXPECT_TRUE(absl::StartsWith( RunOpKernel().error_message(), - "Input Variant element at index 0 doesn't hold a Tensor")); -} - -TEST_F(RaggedTensorFromVariantKernelTest, InputVariantTensorRankNotOne) { - Tensor variant_list(DT_VARIANT, TensorShape({2, 1})); - test::FillValues<Variant>(&variant_list, {1, 2}); - int input_ragged_rank = 1; - int output_ragged_rank = 2; - BuildDecodeRaggedTensorGraph<int, int64>( - input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list}); - EXPECT_TRUE(absl::StartsWith( - RunOpKernel().error_message(), - "Encoded input Variant must have rank 1, but found rank: 2")); + "Input Variant element at index 0 doesn't hold a RaggedTensorVariant")); } TEST_F(RaggedTensorFromVariantKernelTest, InputScalarElementDoesNotMatchInputRaggedRank) { const std::vector<int64> component_split_1_1 = {0, 1}; const std::vector<int> component_values_1 = {1, 2}; - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1, 2}), component_values_1); int input_ragged_rank = 2; @@ -527,31 +510,17 @@ TEST_F(RaggedTensorFromVariantKernelTest, BuildDecodeRaggedTensorGraph<int, int64>(input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_component_1}); - EXPECT_TRUE(absl::StartsWith( - RunOpKernel().error_message(), - "Encoded input Variant must hold either input_ragged_rank + 1 " - "Tensors or an empty Tensor")); -} - -TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitNotATensor) { - Tensor variant_list(DT_VARIANT, TensorShape({2})); - test::FillValues<Variant>(&variant_list, {1, 2}); - - int input_ragged_rank = 1; - int output_ragged_rank = 2; - BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank, - TensorShape({1}), {variant_list}); EXPECT_TRUE( 
absl::StartsWith(RunOpKernel().error_message(), - "Encoded scalar element at index 0 doesn't have a " - "splits Tensor at split_index 0")); + "Encoded input RaggedTensorVariant has ragged_rank=1. " + "Expected ragged_rank=2.")); } TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) { const std::vector<int64> component_split_1_1 = {0, 1}; const std::vector<int> component_values_1 = {0}; - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1}), component_values_1); int input_ragged_rank = 1; @@ -559,46 +528,29 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) { BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_component_1}); - EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), - "Expected splits Tensor dtype: 3, found: 9")); + EXPECT_TRUE(absl::StartsWith( + RunOpKernel().error_message(), + "Expected row_splits Tensor dtype: int32, found: int64")); } TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitRankNotOne) { - Tensor splits(DT_INT64, TensorShape({2, 1})); - test::FillValues<int64>(&splits, {1, 2}); - Tensor values(DT_INT32, {2}); - test::FillValues<int>(&values, {1, 2}); - Tensor encoded_list(DT_VARIANT, TensorShape({2})); - test::FillValues<Variant>(&encoded_list, {splits, values}); + RaggedTensorVariant encoded(Tensor(DT_INT32, {2}), + {Tensor(DT_INT64, {2, 1})}); + test::FillValues<int64>(encoded.mutable_splits(0), {1, 2}); + test::FillValues<int>(encoded.mutable_values(), {1, 2}); int input_ragged_rank = 1; int output_ragged_rank = 2; BuildDecodeRaggedTensorGraph<int, int64>( - input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded_list}); + input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded}); EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), "Ragged splits must have rank 1")); } -TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesNotATensor) { - Tensor splits(DT_INT64, TensorShape({3})); - test::FillValues<int64>(&splits, {0, 2, 3}); - Tensor variant_list(DT_VARIANT, TensorShape({2})); - test::FillValues<Variant>(&variant_list, {splits, 2}); - - int input_ragged_rank = 1; - int output_ragged_rank = 2; - BuildDecodeRaggedTensorGraph<int, int64>( - input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list}); - EXPECT_TRUE( - absl::StartsWith(RunOpKernel().error_message(), - "Encoded scalar element at index 0 doesn't have a " - "values Tensor")); -} - TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) { const std::vector<int64> component_split_1_1 = {0, 1}; const std::vector<int> component_values_1 = {0}; - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1}), component_values_1); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -611,7 +563,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) { } TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) { - Tensor variant_component_1 = + auto variant_component_1 = CreateVariantFromRagged<int, int64>({{0, 1}}, TensorShape({}), {1}); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -628,9 +580,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankMismatch) { const std::vector<int> component_values_1 = {0}; const std::vector<int> component_values_2 = {0, 1, 2, 3}; - Tensor 
variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {component_split_2_1}, TensorShape({2, 2}), component_values_2); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -711,13 +663,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) { const std::vector<int> batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5}; - Tensor variant_component_1 = CreateVariantFromRagged<int, int64>( + auto variant_component_1 = CreateVariantFromRagged<int, int64>( {}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2}); - Tensor variant_component_2 = CreateVariantFromRagged<int, int64>( + auto variant_component_2 = CreateVariantFromRagged<int, int64>( {}, TensorShape({1, 2, 2}), {3, 3, 3, 3}); - Tensor variant_component_3 = + auto variant_component_3 = CreateVariantFromRagged<int, int64>({}, TensorShape({0, 2, 2}), {}); - Tensor variant_component_4 = CreateVariantFromRagged<int, int64>( + auto variant_component_4 = CreateVariantFromRagged<int, int64>( {}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5}); Tensor expected_splits_1(DT_INT64, TensorShape({5})); diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index 64c372b005e..549dc68dfbf 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -18,50 +18,38 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/tensor_ops_util.h" namespace tensorflow { namespace { -struct RaggedTensor { - Tensor values; - std::vector<Tensor> nested_splits; -}; - -Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) { - // Encode as a rank-1 Variant Tensor. - int ragged_rank = ragged.nested_splits.size(); - *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1})); - auto encoded_vec = encoded_list->vec<Variant>(); - for (int i = 0; i < ragged_rank; i++) { - encoded_vec(i) = ragged.nested_splits[i]; - } - encoded_vec(ragged_rank) = ragged.values; - return Status::OK(); -} - template <typename VALUE_TYPE, typename SPLIT_TYPE> -Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, - std::vector<RaggedTensor>* ragged_components) { +Status UnbatchRaggedZerothDim( + const RaggedTensorVariant& batched_ragged, + std::vector<RaggedTensorVariant>* ragged_components) { // Set up the component Ragged Tensors. 
- int ragged_rank = batched_ragged.nested_splits.size(); - auto batched_splits_top_vec = - batched_ragged.nested_splits[0].vec<SPLIT_TYPE>(); + int ragged_rank = batched_ragged.ragged_rank(); + auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>(); int num_components = batched_splits_top_vec.size() - 1; int num_splits = ragged_rank - 1; ragged_components->resize(num_components); - for (RaggedTensor ragged_component : *ragged_components) { - ragged_component.nested_splits.reserve(num_splits); + for (RaggedTensorVariant& ragged_component : *ragged_components) { + ragged_component.mutable_nested_splits()->reserve(num_splits); } - const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>(); - int num_inner_elems = batched_ragged.values.NumElements(); - if (batched_ragged.values.dim_size(0) > 1) { - num_inner_elems /= batched_ragged.values.dim_size(0); + const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>(); + int num_inner_elems = batched_ragged.values().NumElements(); + if (batched_ragged.values().dim_size(0) > 1) { + num_inner_elems /= batched_ragged.values().dim_size(0); } - TensorShape values_shape = batched_ragged.values.shape(); + TensorShape values_shape = batched_ragged.values().shape(); // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]] if (num_splits == 0) { @@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, int limit = batched_splits_top_vec(i + 1); int num_values = limit - start; values_shape.set_dim(0, num_values); - (*ragged_components)[i].values = - Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape); + (*ragged_components)[i].set_values( + Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape)); auto ragged_component_values_flat = - (*ragged_components)[i].values.flat<VALUE_TYPE>(); + (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>(); for (int j = 0; j < num_values * num_inner_elems; j++) { ragged_component_values_flat(j) = batched_flat(j + start * num_inner_elems); @@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec; batched_splits_vec.reserve(ragged_rank); for (int i = 0; i < ragged_rank; i++) { - batched_splits_vec.push_back( - batched_ragged.nested_splits[i].vec<SPLIT_TYPE>()); + batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>()); } std::vector<int> index(num_splits, 1); std::vector<int> ragged_component_values_size(num_components, 0); @@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, int last_index = ragged_component_splits_vec[j - 1].size() - 1; split_size = ragged_component_splits_vec[j - 1](last_index) + 1; } - (*ragged_components)[i].nested_splits.push_back( + (*ragged_components)[i].append_splits( Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size}))); ragged_component_splits_vec.push_back( - (*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>()); + (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>()); SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1); ragged_component_splits_vec[j](0) = 0; for (int k = 1; k < split_size; k++, index[j]++) { @@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, for (int i = 0; i < num_components; i++) { int num_values = ragged_component_values_size[i]; values_shape.set_dim(0, num_values); - (*ragged_components)[i].values = - Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape); + 
(*ragged_components)[i].set_values( + Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape)); auto ragged_component_values_flat = - (*ragged_components)[i].values.flat<VALUE_TYPE>(); + (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>(); for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) { ragged_component_values_flat(j) = batched_flat(value_index); } @@ -152,46 +139,38 @@ class RaggedTensorToVariantOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("rt_nested_splits", &ragged_nested_splits_in)); const int ragged_nested_splits_len = ragged_nested_splits_in.size(); - RaggedTensor batched_ragged_input; + RaggedTensorVariant batched_ragged_input; // Read ragged_values input. - batched_ragged_input.values = context->input(ragged_nested_splits_len); - batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len); + batched_ragged_input.set_values(context->input(ragged_nested_splits_len)); + batched_ragged_input.mutable_nested_splits()->reserve( + ragged_nested_splits_len); for (int i = 0; i < ragged_nested_splits_len; i++) { - batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]); + batched_ragged_input.append_splits(ragged_nested_splits_in[i]); } if (!batched_input_) { - // Encode the input as is. - Tensor encoded_list; - OP_REQUIRES_OK(context, - RaggedToVariant(batched_ragged_input, &encoded_list)); // Encode as a Scalar Variant Tensor. Tensor* encoded_scalar; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &encoded_scalar)); - encoded_scalar->scalar<Variant>()() = std::move(encoded_list); + encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input); return; } // Unbatch the Ragged Tensor and encode the components. - std::vector<RaggedTensor> ragged_components; + std::vector<RaggedTensorVariant> unbatched_ragged_input; OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>( - batched_ragged_input, &ragged_components)); - std::vector<Tensor> encoded_components(ragged_components.size()); - for (int i = 0; i < ragged_components.size(); i++) { - OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i], - &encoded_components[i])); - } + batched_ragged_input, &unbatched_ragged_input)); // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor. - Tensor* encoded_ragged; - int output_size = ragged_components.size(); + Tensor* encoded_vector; + int output_size = unbatched_ragged_input.size(); OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({output_size}), - &encoded_ragged)); - auto encoded_ragged_vec = encoded_ragged->vec<Variant>(); + &encoded_vector)); + auto encoded_vector_t = encoded_vector->vec<Variant>(); for (int i = 0; i < output_size; i++) { - encoded_ragged_vec(i) = encoded_components[i]; + encoded_vector_t(i) = unbatched_ragged_input[i]; } } @@ -199,12 +178,81 @@ class RaggedTensorToVariantOp : public OpKernel { bool batched_input_; }; -#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ - REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<value_type>("Tvalues") \ - .TypeConstraint<split_type>("Tsplits"), \ - RaggedTensorToVariantOp<value_type, split_type>); +template <typename VALUE_TYPE, typename SPLIT_TYPE> +class RaggedTensorToVariantGradientOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + // Read inputs. 
+ Tensor encoded_variant = context->input(0); + Tensor row_splits = context->input(1); + auto flat_row_splits = row_splits.flat<SPLIT_TYPE>(); + TensorShape dense_values_shape; + OP_REQUIRES_OK(context, + TensorShapeUtils::MakeShape(context->input(2).vec<int32>(), + &dense_values_shape)); + + const auto& flat_variants = encoded_variant.flat<Variant>(); + + // Get a Tensor containing the flat_values for each variant. + std::vector<Tensor> values; + for (int i = 0; i < flat_variants.size(); ++i) { + if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) { + values.push_back(encoded->values()); + } else { + // Missing value: this happens if only some of the variant values + // generated by ragged_tensor_to_variant impacted the value that we're + // calculating the gradient for. In this case, we will see a + // default-constructed variant; so treat it as a zero tensor with the + // appropriate shape. + const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v(); + int piece_size = flat_row_splits(i + 1) - flat_row_splits(i); + TensorShape zeros_shape = dense_values_shape; + zeros_shape.set_dim(0, piece_size); + Tensor zero(value_dtype, zeros_shape); + zero.flat<VALUE_TYPE>() = + zero.flat<VALUE_TYPE>().constant(VALUE_TYPE()); + values.push_back(zero); + } + } + + if (values.size() == 1) { + // Just one flat_value tensor: return as-is. + context->set_output(0, values[0]); + } else { + // Multiple flat_values tensors: concatenate them together. + using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix; + using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix; + std::vector<std::unique_ptr<ConstPiece>> pieces; + pieces.reserve(values.size()); + for (const Tensor& t : values) { + pieces.emplace_back( + new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()}))); + } + Tensor* out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, dense_values_shape, &out)); + Piece out_flat = + out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()}); + ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat); + } + } +}; + +#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ + REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<value_type>("Tvalues") \ + .TypeConstraint<split_type>("Tsplits"), \ + RaggedTensorToVariantOp<value_type, split_type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("RaggedTensorToVariantGradient") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<value_type>("Tvalues") \ + .TypeConstraint<split_type>("Tsplits"), \ + RaggedTensorToVariantGradientOp<value_type, split_type>); + #define REGISTER_KERNELS(value_type) \ REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \ REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64) diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc index c1438dd7af9..94f35673c8b 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -60,6 +61,43 @@ class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase { } AddInputFromArray<VALUE_TYPE>(ragged_values_shape, ragged_values); } + + template <typename VALUE_TYPE, typename SPLIT_TYPE> + RaggedTensorVariant CreateVariantFromRagged( + const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector<VALUE_TYPE>& ragged_values) { + RaggedTensorVariant encoded; + for (auto ragged_split : ragged_splits) { + int splits_size = ragged_split.size(); + Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(), + TensorShape({splits_size})); + test::FillValues<SPLIT_TYPE>(&splits, ragged_split); + encoded.append_splits(splits); + } + Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape); + test::FillValues<VALUE_TYPE>(&values, ragged_values); + encoded.set_values(values); + return encoded; + } + + template <typename VALUE_TYPE, typename SPLIT_TYPE> + RaggedTensorVariant CreateVariantFromRagged( + const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits, + const std::vector<VALUE_TYPE>& ragged_values) { + int num_values = ragged_values.size(); + return CreateVariantFromRagged(ragged_splits, {num_values}, ragged_values); + } + + template <typename VALUE_TYPE, typename SPLIT_TYPE> + void ExpectRaggedTensorVariantEqual(const RaggedTensorVariant& expected, + const RaggedTensorVariant& actual) { + test::ExpectTensorEqual<VALUE_TYPE>(actual.values(), expected.values()); + EXPECT_EQ(actual.ragged_rank(), expected.ragged_rank()); + for (int i = 0; i < actual.ragged_rank(); ++i) { + test::ExpectTensorEqual<SPLIT_TYPE>(actual.splits(i), expected.splits(i)); + } + } }; TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { @@ -67,18 +105,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { const std::vector<int64> batched_splits_1 = {0, 2, 3, 3}; const std::vector<int64> batched_splits_2 = {0, 0, 0, 0}; - const std::vector<int64> component_splits_1_1 = {0, 0, 0}; - const std::vector<int64> component_splits_2_1 = {0, 0}; - const std::vector<int64> component_splits_3_1 = {0}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({3})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({2})); - Tensor expected_splits_3_1(DT_INT64, TensorShape({1})); - - test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1); - test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1); - test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1); - BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2}, TensorShape({0}), {}, true); TF_ASSERT_OK(RunOpKernel()); @@ -86,55 +112,26 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { const auto& encoded_list = GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 3); - const Variant& encoded_splits_1_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_splits_2_1 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_splits_3_1 = - encoded_list(2).get<Tensor>()->vec<Variant>()(0); 
- const Variant& encoded_values_3 = - encoded_list(2).get<Tensor>()->vec<Variant>()(1); - - test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(), - expected_splits_1_1); - test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(), - expected_splits_2_1); - test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(), - expected_splits_3_1); - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - Tensor(DT_INT32, TensorShape({0}))); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - Tensor(DT_INT32, TensorShape({0}))); - test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(), - Tensor(DT_INT32, TensorShape({0}))); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 0, 0}}, {}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 0}}, {}), + *encoded_list(1).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0}}, {}), + *encoded_list(2).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) { // ragged_tensor= - // [ [x, x, x], + // [ [1, 2, 3], // [ ], - // [x, x ], - // [x ]] + // [4, 5 ], + // [6 ]] const std::vector<int64> batched_splits = {0, 3, 3, 5, 6}; const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6}; - const std::vector<int> component_values_1 = {1, 2, 3}; - const std::vector<int> component_values_3 = {4, 5}; - const std::vector<int> component_values_4 = {6}; - - Tensor expected_values_1(DT_INT32, TensorShape({3})); - Tensor expected_values_2(DT_INT32, TensorShape({0})); - Tensor expected_values_3(DT_INT32, TensorShape({2})); - Tensor expected_values_4(DT_INT32, TensorShape({1})); - - test::FillValues<int>(&expected_values_1, component_values_1); - test::FillValues<int>(&expected_values_3, component_values_3); - test::FillValues<int>(&expected_values_4, component_values_4); - BuildEncodeRaggedTensorGraph<int, int64>({batched_splits}, TensorShape({6}), batched_values, true); TF_ASSERT_OK(RunOpKernel()); @@ -142,45 +139,28 @@ TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) { const auto& encoded_list = GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 4); - const Variant& encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_3 = - encoded_list(2).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_4 = - encoded_list(3).get<Tensor>()->vec<Variant>()(0); - - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - expected_values_1); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - expected_values_2); - test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(), - expected_values_3); - test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(), - expected_values_4); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {1, 2, 3}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {}), + *encoded_list(1).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {4, 5}), + *encoded_list(2).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {6}), + 
*encoded_list(3).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) { // ragged_tensor= - // [[x, x], - // [x, x], - // [x, x]] + // [[1, 2], + // [4, 5], + // [6, 7]] const std::vector<int64> batched_splits = {0, 1, 2, 3}; const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7}; - const std::vector<int> component_values_1 = {1, 2}; - const std::vector<int> component_values_2 = {4, 5}; - const std::vector<int> component_values_3 = {6, 7}; - - Tensor expected_values_1(DT_INT32, TensorShape({1, 2})); - Tensor expected_values_2(DT_INT32, TensorShape({1, 2})); - Tensor expected_values_3(DT_INT32, TensorShape({1, 2})); - - test::FillValues<int>(&expected_values_1, component_values_1); - test::FillValues<int>(&expected_values_2, component_values_2); - test::FillValues<int>(&expected_values_3, component_values_3); - BuildEncodeRaggedTensorGraph<int, int64>( {batched_splits}, TensorShape({3, 2}), batched_values, true); TF_ASSERT_OK(RunOpKernel()); @@ -188,44 +168,25 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) { const auto& encoded_list = GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 3); - const Variant& encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_3 = - encoded_list(2).get<Tensor>()->vec<Variant>()(0); - - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - expected_values_1); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - expected_values_2); - test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(), - expected_values_3); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {1, 2}, {1, 2}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {1, 2}, {4, 5}), + *encoded_list(1).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, {1, 2}, {6, 7}), + *encoded_list(2).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) { - // ragged_tensor=[ - // [ [[x, x], [x, x]], - // [[x, x] ] ] + // ragged_tensor= + // [ [[[1, 2], [4, 5]]], + // [[[6 7]]] ] const std::vector<int64> batched_splits_1 = {0, 1, 2}; const std::vector<int64> batched_splits_2 = {0, 2, 3}; const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7}; - const std::vector<int64> component_splits_1_1 = {0, 2}; - const std::vector<int64> component_splits_2_1 = {0, 1}; - const std::vector<int> component_values_1 = {1, 2, 4, 5}; - const std::vector<int> component_values_2 = {6, 7}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({2})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({2})); - Tensor expected_values_1(DT_INT32, TensorShape({2, 2})); - Tensor expected_values_2(DT_INT32, TensorShape({1, 2})); - - test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1); - test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1); - test::FillValues<int>(&expected_values_1, component_values_1); - test::FillValues<int>(&expected_values_2, component_values_2); - BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2}, TensorShape({3, 2}), batched_values, true); @@ -234,23 +195,12 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) { const auto& encoded_list = 
GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 2); - const Variant& encoded_splits_1_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_splits_2_1 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(1); - - test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(), - expected_splits_1_1); - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - expected_values_1); - test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(), - expected_splits_2_1); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - expected_values_2); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 2}}, {2, 2}, {1, 2, 4, 5}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 1}}, {1, 2}, {6, 7}), + *encoded_list(1).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { @@ -263,30 +213,6 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { const std::vector<int64> batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15}; const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const std::vector<int64> component_splits_1_1 = {0, 1, 3, 3}; - const std::vector<int64> component_splits_2_1 = {0}; - const std::vector<int64> component_splits_3_1 = {0, 5, 8}; - const std::vector<int64> component_splits_4_1 = {0, 0, 4}; - const std::vector<int> component_values_1 = {1, 2, 3}; - const std::vector<int> component_values_3 = {4, 5, 6, 7, 8, 9, 10, 11}; - const std::vector<int> component_values_4 = {12, 13, 14, 15}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({4})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({1})); - Tensor expected_splits_3_1(DT_INT64, TensorShape({3})); - Tensor expected_splits_4_1(DT_INT64, TensorShape({3})); - Tensor expected_values_1(DT_INT32, TensorShape({3})); - Tensor expected_values_2(DT_INT32, TensorShape({0})); - Tensor expected_values_3(DT_INT32, TensorShape({8})); - Tensor expected_values_4(DT_INT32, TensorShape({4})); - - test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1); - test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1); - test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1); - test::FillValues<int64>(&expected_splits_4_1, component_splits_4_1); - test::FillValues<int>(&expected_values_1, component_values_1); - test::FillValues<int>(&expected_values_3, component_values_3); - test::FillValues<int>(&expected_values_4, component_values_4); BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2}, TensorShape({15}), batched_values, @@ -296,39 +222,19 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { const auto& encoded_list = GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 4); - const Variant& encoded_splits_1_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_splits_2_1 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_splits_3_1 = - encoded_list(2).get<Tensor>()->vec<Variant>()(0); - const 
Variant& encoded_values_3 = - encoded_list(2).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_splits_4_1 = - encoded_list(3).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_values_4 = - encoded_list(3).get<Tensor>()->vec<Variant>()(1); - - test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(), - expected_splits_1_1); - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - expected_values_1); - test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(), - expected_splits_2_1); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - expected_values_2); - test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(), - expected_splits_3_1); - test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(), - expected_values_3); - test::ExpectTensorEqual<int64>(*encoded_splits_4_1.get<Tensor>(), - expected_splits_4_1); - test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(), - expected_values_4); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 1, 3, 3}}, {1, 2, 3}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0}}, {}), + *encoded_list(1).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 5, 8}}, + {4, 5, 6, 7, 8, 9, 10, 11}), + *encoded_list(2).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({{0, 0, 4}}, {12, 13, 14, 15}), + *encoded_list(3).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { @@ -350,26 +256,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { 7, 8, 9, 12, 13, 14}; const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 8, 9}; - const std::vector<int64> component_split_1_1 = {0, 1, 3, 4, 5, 6}; - const std::vector<int64> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7}; - const std::vector<int64> component_split_2_1 = {0, 1, 2, 3, 4, 5}; - const std::vector<int64> component_split_2_2 = {0, 1, 2, 5, 6, 7}; - const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4}; - const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({6})); - Tensor expected_splits_1_2(DT_INT64, TensorShape({7})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({6})); - Tensor expected_splits_2_2(DT_INT64, TensorShape({6})); - Tensor expected_values_1(DT_INT32, TensorShape({7})); - Tensor expected_values_2(DT_INT32, TensorShape({7})); - - test::FillValues<int64>(&expected_splits_1_1, component_split_1_1); - test::FillValues<int64>(&expected_splits_1_2, component_split_1_2); - test::FillValues<int64>(&expected_splits_2_1, component_split_2_1); - test::FillValues<int64>(&expected_splits_2_2, component_split_2_2); - test::FillValues<int>(&expected_values_1, component_values_1); - test::FillValues<int>(&expected_values_2, component_values_2); BuildEncodeRaggedTensorGraph<int, int64>( {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}), @@ -379,31 +265,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { const auto& encoded_list = GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 2); - const Variant& encoded_splits_1_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_splits_1_2 = - encoded_list(0).get<Tensor>()->vec<Variant>()(1); - const Variant& 
encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(2); - const Variant& encoded_splits_2_1 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_splits_2_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(2); - - test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(), - expected_splits_1_1); - test::ExpectTensorEqual<int64>(*encoded_splits_1_2.get<Tensor>(), - expected_splits_1_2); - test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(), - expected_splits_2_1); - test::ExpectTensorEqual<int64>(*encoded_splits_2_2.get<Tensor>(), - expected_splits_2_2); - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - expected_values_1); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - expected_values_2); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>( + {{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>( + {{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}), + *encoded_list(1).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { @@ -424,28 +293,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { 7, 8, 9, 12, 13, 14}; const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 8, 9}; - const std::vector<int> component_split_1_1 = {0, 1, 3, 4, 5, 6}; - const std::vector<int> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7}; - const std::vector<int> component_split_2_1 = {0, 1, 2, 3, 4, 5}; - const std::vector<int> component_split_2_2 = {0, 1, 2, 5, 6, 7}; - const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4}; - const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9}; - Tensor expected_splits_1_1(DT_INT32, TensorShape({6})); - Tensor expected_splits_1_2(DT_INT32, TensorShape({7})); - Tensor expected_splits_2_1(DT_INT32, TensorShape({6})); - Tensor expected_splits_2_2(DT_INT32, TensorShape({6})); - Tensor expected_values_1(DT_INT32, TensorShape({7})); - Tensor expected_values_2(DT_INT32, TensorShape({7})); - - test::FillValues<int>(&expected_splits_1_1, component_split_1_1); - test::FillValues<int>(&expected_splits_1_2, component_split_1_2); - test::FillValues<int>(&expected_splits_2_1, component_split_2_1); - test::FillValues<int>(&expected_splits_2_2, component_split_2_2); - test::FillValues<int>(&expected_values_1, component_values_1); - test::FillValues<int>(&expected_values_2, component_values_2); - - BuildEncodeRaggedTensorGraph<int, int>( + BuildEncodeRaggedTensorGraph<int, int32>( {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}), batched_values, true); TF_ASSERT_OK(RunOpKernel()); @@ -453,31 +302,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { const auto& encoded_list = GetOutput(0)->vec<Variant>(); EXPECT_EQ(encoded_list.size(), 2); - const Variant& encoded_splits_1_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_splits_1_2 = - encoded_list(0).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_values_1 = - encoded_list(0).get<Tensor>()->vec<Variant>()(2); - const Variant& encoded_splits_2_1 = - encoded_list(1).get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_splits_2_2 = - 
encoded_list(1).get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_values_2 = - encoded_list(1).get<Tensor>()->vec<Variant>()(2); - - test::ExpectTensorEqual<int>(*encoded_splits_1_1.get<Tensor>(), - expected_splits_1_1); - test::ExpectTensorEqual<int>(*encoded_splits_1_2.get<Tensor>(), - expected_splits_1_2); - test::ExpectTensorEqual<int>(*encoded_splits_2_1.get<Tensor>(), - expected_splits_2_1); - test::ExpectTensorEqual<int>(*encoded_splits_2_2.get<Tensor>(), - expected_splits_2_2); - test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(), - expected_values_1); - test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(), - expected_values_2); + ExpectRaggedTensorVariantEqual<int, int32>( + CreateVariantFromRagged<int, int32>( + {{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}), + *encoded_list(0).get<RaggedTensorVariant>()); + ExpectRaggedTensorVariantEqual<int, int32>( + CreateVariantFromRagged<int, int32>( + {{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}), + *encoded_list(1).get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) { @@ -491,33 +323,17 @@ TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) { const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - Tensor batched_ragged_splits_1(DT_INT64, TensorShape({5})); - Tensor batched_ragged_splits_2(DT_INT64, TensorShape({8})); - Tensor batched_ragged_values(DT_INT32, TensorShape({15})); - - test::FillValues<int64>(&batched_ragged_splits_1, batched_splits_1); - test::FillValues<int64>(&batched_ragged_splits_2, batched_splits_2); - test::FillValues<int>(&batched_ragged_values, batched_values); - BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2}, TensorShape({15}), batched_values, false); TF_ASSERT_OK(RunOpKernel()); const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()(); - const Variant& encoded_splits_1 = - encoded_scalar.get<Tensor>()->vec<Variant>()(0); - const Variant& encoded_splits_2 = - encoded_scalar.get<Tensor>()->vec<Variant>()(1); - const Variant& encoded_values = - encoded_scalar.get<Tensor>()->vec<Variant>()(2); - test::ExpectTensorEqual<int64>(*encoded_splits_1.get<Tensor>(), - batched_ragged_splits_1); - test::ExpectTensorEqual<int64>(*encoded_splits_2.get<Tensor>(), - batched_ragged_splits_2); - test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), - batched_ragged_values); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({batched_splits_1, batched_splits_2}, + batched_values), + *encoded_scalar.get<RaggedTensorVariant>()); } TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) { @@ -598,17 +414,14 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) { TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) { const std::vector<int> values = {1, 2, 3, 4, 5, 6}; - Tensor expected_values(DT_INT32, TensorShape({6})); - test::FillValues<int>(&expected_values, values); BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false); TF_ASSERT_OK(RunOpKernel()); const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()(); - const Variant& encoded_values = - encoded_scalar.get<Tensor>()->vec<Variant>()(0); - - test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values); + ExpectRaggedTensorVariantEqual<int, int64>( + CreateVariantFromRagged<int, int64>({}, values), + *encoded_scalar.get<RaggedTensorVariant>()); } } // namespace diff --git 
a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc new file mode 100644 index 00000000000..9466313819b --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_variant.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "tensorflow/core/kernels/ragged_tensor_variant.h" + +namespace tensorflow { + +string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; } + +string RaggedTensorVariant::DebugString() const { + return absl::StrCat( + "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()), + ", ragged_rank=", nested_splits_.size(), ", splits_dtype=", + DataTypeString(nested_splits_.empty() ? DT_INVALID + : nested_splits_.back().dtype())); +} + +void RaggedTensorVariant::Encode(VariantTensorData* data) const { + data->set_type_name(TypeName()); + for (const auto& splits : nested_splits_) { + *data->add_tensors() = splits; + } + *data->add_tensors() = values_; +} + +bool RaggedTensorVariant::Decode(const VariantTensorData& data) { + if (data.tensors_size() < 1) { + return false; + } + nested_splits_.assign(data.tensors().begin(), + std::prev(data.tensors().end())); + values_ = data.tensors().back(); + return true; +} + +namespace { + +Status RaggedTensorVariantDeviceCopy( + const RaggedTensorVariant& from, RaggedTensorVariant* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { + TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values())); + // TODO(b/170415165) Should we use `copy` to move splits from device<->host? 
+ *to->mutable_nested_splits() = from.nested_splits(); + return Status::OK(); +} + +} // namespace + +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant, + RaggedTensorVariantZerosLike<CPUDevice>); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION( + ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant, + RaggedTensorVariantBinaryAdd<CPUDevice>); + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant, + "RaggedTensorVariant"); + +#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy) + +REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); +REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); +REGISTER_RAGGED_TENSOR_VARIANT_COPY( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h new file mode 100644 index 00000000000..730758a3e82 --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_variant.h @@ -0,0 +1,110 @@ +#include "tensorflow/core/framework/tensor_key.h" +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include <vector> + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { + +// Class used to store a RaggedTensor as a Variant scalar. +class RaggedTensorVariant { + public: + RaggedTensorVariant() {} + RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits) + : values_(std::move(values)), nested_splits_(nested_splits) {} + + // Variant support methods. + string TypeName() const; + string DebugString() const; + void Encode(VariantTensorData* data) const; + bool Decode(const VariantTensorData& data); + + // The flat_values of the RaggedTensor. + const Tensor& values() const { return values_; } + Tensor* mutable_values() { return &values_; } + void set_values(const Tensor& new_values) { values_ = new_values; } + + // The nested row_splits of the RaggedTensor. 
+ int ragged_rank() const { return nested_splits_.size(); } + const std::vector<Tensor>& nested_splits() const { return nested_splits_; } + std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; } + const Tensor& splits(int i) const { return nested_splits_[i]; } + Tensor* mutable_splits(int i) { return &nested_splits_[i]; } + void set_nested_splits(const std::vector<Tensor>& nested_splits) { + nested_splits_ = nested_splits; + } + void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); } + + private: + Tensor values_; + std::vector<Tensor> nested_splits_; +}; + +template <typename Device> +Status RaggedTensorVariantZerosLike(OpKernelContext* c, + const RaggedTensorVariant& x, + RaggedTensorVariant* y) { + y->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR( + ZerosLikeTensor<Device>(c, x.values(), y->mutable_values())); + return Status::OK(); +} + +template <typename Device> +Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, + const RaggedTensorVariant& x, + const RaggedTensorVariant& y, + RaggedTensorVariant* out) { + if (x.values().dtype() != y.values().dtype()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different dtypes. One is ", + DataTypeString(x.values().dtype()), " and the other is ", + DataTypeString(y.values().dtype())); + } + if (x.ragged_rank() != y.ragged_rank()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different ragged rank. ", "One is ", + x.ragged_rank(), " and the other is ", y.ragged_rank()); + } + for (int i = 0; i < x.ragged_rank(); ++i) { + if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants with different row_splits."); + } + } + out->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(), + out->mutable_values())); + return Status::OK(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc index 44712bf7739..043ff469487 100644 --- a/tensorflow/core/ops/ragged_conversion_ops.cc +++ b/tensorflow/core/ops/ragged_conversion_ops.cc @@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes( Status RaggedTensorToSparseShapeFn(InferenceContext* c); Status RaggedTensorToVariantShapeFn(InferenceContext* c); Status RaggedTensorFromVariantShapeFn(InferenceContext* c); -tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c); +Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c); +Status RaggedTensorToTensorShapeFn(InferenceContext* c); //============================================================================== // Registered Ops @@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant") .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn(RaggedTensorFromVariantShapeFn); +REGISTER_OP("RaggedTensorToVariantGradient") + .Input("encoded_ragged_grad: variant") + .Input("row_splits: Tsplits") + .Input("dense_values_shape: int32") + .Output("dense_values_grad: Tvalues") + .Attr("Tvalues: type") + .Attr("Tsplits: {int32, int64} = DT_INT64") + .SetShapeFn(RaggedTensorToVariantGradientShapeFn); + REGISTER_OP("RaggedTensorToTensor") .Attr("T: type") .Attr("Tindex: {int64, int32}") @@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) { return Status::OK(); } +Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { + ShapeHandle 
shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape)); + c->set_output(0, shape); + return Status::OK(); +} + Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { int64 input_ragged_rank; TF_RETURN_IF_ERROR( diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 309957a76a1..6bd517d7abc 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -510,6 +510,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_array_grad", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/tensorflow/python/ops/ragged/ragged_conversion_ops.py index e8c625ccc73..e915d1ecd61 100644 --- a/tensorflow/python/ops/ragged/ragged_conversion_ops.py +++ b/tensorflow/python/ops/ragged/ragged_conversion_ops.py @@ -143,3 +143,42 @@ def to_sparse(rt_input, name=None): def from_sparse(st_input, name=None): return ragged_tensor.RaggedTensor.from_sparse(st_input, name) + + +@ops.RegisterGradient("RaggedTensorFromVariant") +def _ragged_tensor_from_variant_grad(op, *grads): + """Gradient for RaggedTensorFromVariant op.""" + + variant_rank = op.inputs[0].shape.rank + if variant_rank == 0: + batched_input = False + elif variant_rank == 1: + batched_input = True + elif variant_rank is None: + batched_input = (op.get_attr("output_ragged_rank") > 0) + else: + # TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so + # we can support this. + raise ValueError("Unable to compute gradient: RaggedTensorToVariant " + "can currently only generate 0D or 1D output.") + return [ + gen_ragged_conversion_ops.ragged_tensor_to_variant( + rt_nested_splits=op.outputs[:-1], + rt_dense_values=grads[-1], + batched_input=batched_input) + ] + + +@ops.RegisterGradient("RaggedTensorToVariant") +def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad): + """Gradient for RaggedTensorToVariant op.""" + dense_values = op.inputs[-1] + ragged_rank = len(op.inputs) - 1 + row_splits = 0 if ragged_rank == 0 else op.inputs[0] + values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient( + encoded_ragged_grad=encoded_ragged_grad, + row_splits=row_splits, + dense_values_shape=array_ops.shape(dense_values), + Tvalues=op.inputs[-1].dtype) + result = [None] * ragged_rank + [values_grad] + return result diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 5f713fa0793..800272d0dd9 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2863,9 +2863,6 @@ def _get_optional_partition_dtype(values): return None -ops.no_gradient("RaggedTensorToVariant") - - _SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor) diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index d92cb9cec6c..a38c5527305 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from absl.testing import parameterized import numpy as np from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop 
from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -30,8 +32,15 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_ragged_conversion_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import map_fn +from tensorflow.python.ops import math_grad # pylint: disable=unused-import +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_math_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -1233,19 +1242,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(factory(**kwargs)) + #============================================================================= + # RaggedTensor Variant conversion + #============================================================================= -#============================================================================= -# RaggedTensor Variant conversion -#============================================================================= - - @parameterized.parameters( + @parameterized.named_parameters( { + 'testcase_name': 'Shape_5_none', 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 'ragged_rank': 1 }, { + 'testcase_name': 'Shape_4_none_2', 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 'ragged_rank': 1 }, { + 'testcase_name': 'Shape_1_none_none', 'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]], 'ragged_rank': 2 }) @@ -1432,6 +1443,131 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_ragged_rank=1, input_ragged_rank=1) + def _testRaggedVarientGradient(self, func, x, expected_grad): + x = constant_op.constant(x) + if context.executing_eagerly(): + with backprop.GradientTape() as t: + t.watch(x) + y = func(x) + g = t.gradient(y, x) + else: + y = func(x) + g = gradients_impl.gradients(ys=y, xs=x)[0] + self.assertAllClose(g, expected_grad) + + def testRaggedVariantGradients(self): + def func(x): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + rt2 = rt1 * [[10], [100], [1000]] + v = rt2._to_variant(batched_input=False) + rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) + return rt3.flat_values + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [10., 10., 10., 10., 100., 100., 100., 1000.]) + + def testRaggedVariantGradientsBatched(self): + def func(x): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + rt2 = rt1 * [[10], [100], [1000]] + v = rt2._to_variant(batched_input=True) + rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) + return rt3.flat_values + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [10., 10., 10., 10., 100., 100., 100., 1000.]) + + def testRaggedVariantGradientsBatchedAndSliced(self): + def func(x, i): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 
4, 7, 8]) + rt2 = rt1 * [[10], [100], [1000]] + v_slice = rt2._to_variant(batched_input=True)[i] + return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype, + output_ragged_rank=0) + + self._testRaggedVarientGradient( + functools.partial(func, i=0), + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [10., 10., 10., 10., 0., 0., 0., 0.]) + self._testRaggedVarientGradient( + functools.partial(func, i=1), + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [0., 0., 0., 0., 100., 100., 100., 0.]) + self._testRaggedVarientGradient( + functools.partial(func, i=2), + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [0., 0., 0., 0., 0., 0., 0., 1000.]) + + def testRaggedVariantGradientsRaggedRank0(self): + def func(x): + x2 = x * 2 + v = gen_ragged_conversion_ops.ragged_tensor_to_variant( + [], x2, batched_input=False) + return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0) + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) + + def testRaggedVariantGradientsRaggedRank3(self): + def func(x): + x2 = x * 2 + rt1 = RaggedTensor.from_nested_row_splits( + x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8])) + v = rt1._to_variant(batched_input=False) + rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3) + return rt3.flat_values + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) + + def testRaggedVariantGradientsViaMapFn(self): + rt = RaggedTensor.from_row_splits( + values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8]) + + def func(x): + + def transform_row(row): + return math_ops.sqrt( + math_ops.reduce_mean(math_ops.square(row * x), keepdims=True)) + + return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt)) + + self._testRaggedVarientGradient(func, 3.0, 14.653377) + + def testRaggedVariantGradientsViaMapFnReduce(self): + def func(x): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + return map_fn.map_fn( + math_ops.reduce_max, rt1, + fn_output_signature=tensor_spec.TensorSpec((), x.dtype)) + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) + + def testRaggedVariantGradientsErrors(self): + if context.executing_eagerly(): + return + + rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2]) + v1 = rt._to_variant() + v2 = array_ops.stack([array_ops.stack([v1])]) + y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3) + + with self.assertRaisesRegex( + ValueError, 'Unable to compute gradient: RaggedTensorToVariant ' + 'can currently only generate 0D or 1D output.'): + gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values) + def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg): """Check that two numpy arrays are equal. 
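
Editor's note on the new ragged_tensor_test.py cases above: they exercise the main user-visible effect of these gradient registrations, namely that gradients can now flow through tf.map_fn when its elems argument is a RaggedTensor (map_fn round-trips each row through the variant encoding, so it needs gradients for both conversion ops). The snippet below is only a minimal sketch of that behavior using public APIs; it assumes TF 2.x eager execution on a build that includes this change, and the intermediate names (row_rms, loss) are illustrative.

# Sketch only: assumes TensorFlow 2.x eager execution and a build that
# includes the RaggedTensorToVariant/RaggedTensorFromVariant gradients.
import tensorflow as tf

rt = tf.ragged.constant([[3.0, 1.0, 4.0, 1.0], [5.0, 9.0, 2.0], [6.0]])
x = tf.constant(3.0)

with tf.GradientTape() as tape:
  tape.watch(x)
  # map_fn encodes each ragged row as a variant, maps over the rows, and
  # decodes the results; both conversions now have registered gradients.
  row_rms = tf.map_fn(
      lambda row: tf.sqrt(tf.reduce_mean(tf.square(row * x))),
      rt,
      fn_output_signature=tf.TensorSpec([], tf.float32))
  loss = tf.reduce_sum(row_rms)

grad = tape.gradient(loss, x)
# ~14.65, the value checked by testRaggedVariantGradientsViaMapFn; before this
# change RaggedTensorToVariant was marked non-differentiable.
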
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2efd289c259..96be23b9e50 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3212,6 +3212,10 @@ tf_module { name: "RaggedTensorToVariant" argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "RaggedTensorToVariantGradient" + argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "RandomCrop" argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2efd289c259..96be23b9e50 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3212,6 +3212,10 @@ tf_module { name: "RaggedTensorToVariant" argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "RaggedTensorToVariantGradient" + argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "RandomCrop" argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
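
Editor's note: testRaggedVariantGradientsBatchedAndSliced above also covers decoding a single row after slicing the batched variant; the gradient then flows back only into that row's values. Below is a short sketch of the same behavior, using the private RaggedTensor._to_variant/_from_variant helpers exactly as the tests do (internal API, illustrative only, assumes eager execution).

# Sketch only: relies on the private _to_variant/_from_variant helpers used by
# the new tests; not a supported public API.
import tensorflow as tf

x = tf.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0])

with tf.GradientTape() as tape:
  tape.watch(x)
  rt = tf.RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
  # One RaggedTensorVariant per row; decode only row 1 (the values x[4:7]).
  row1 = tf.RaggedTensor._from_variant(
      rt._to_variant(batched_input=True)[1],
      dtype=rt.dtype, output_ragged_rank=0)
  loss = tf.reduce_sum(row1)

print(tape.gradient(loss, x))
# -> [0. 0. 0. 0. 1. 1. 1. 0.]  (nonzero only for the values in row 1)
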
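
Editor's note on the new C++ op: RaggedTensorToVariantGradient (registered in ragged_conversion_ops.cc and invoked from _ragged_tensor_to_variant_grad) takes the variant-encoded output gradients together with the outermost row_splits and the dense_values_shape, and reassembles the gradient for dense_values. The fragment below is only a conceptual NumPy sketch of that reassembly for the batched, ragged-rank-1 case; the helper name is made up, and the real kernel operates on decoded RaggedTensorVariant values rather than Python lists.

# Conceptual sketch (NumPy), not the kernel implementation.
import numpy as np

def ragged_to_variant_grad_sketch(per_row_value_grads, row_splits,
                                  dense_values_shape):
  """Stitches per-row value gradients back into dense_values_grad.

  per_row_value_grads[i] is the gradient for the flat values of row i,
  i.e. for dense_values[row_splits[i]:row_splits[i + 1]], or None when no
  gradient flowed back into that row (those slices stay zero).
  """
  dense_values_grad = np.zeros(dense_values_shape, dtype=np.float32)
  for i in range(len(row_splits) - 1):
    start, limit = row_splits[i], row_splits[i + 1]
    if per_row_value_grads[i] is not None:
      dense_values_grad[start:limit] = per_row_value_grads[i]
  return dense_values_grad

# Example: rows are values[0:4], values[4:7], values[7:8].
print(ragged_to_variant_grad_sketch(
    [np.full(4, 10.0), None, np.full(1, 1000.0)], [0, 4, 7, 8], (8,)))
# -> [  10.   10.   10.   10.    0.    0.    0. 1000.]
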