Add tf.ragged.cross and tf.ragged.cross_hashed operations. These ops generate feature crosses from a list of input tensors (similar to tf.sparse.cross), and return the result as a RaggedTensor. Inputs may be tf.RaggedTensors, tf.Tensors, or tf.SparseTensors.
PiperOrigin-RevId: 298495411 Change-Id: I2edde721cfc4e236d78ece470ff30cc68cf6b70c
This commit is contained in:
parent
81e2ecdaae
commit
72e7964b6a
49
tensorflow/core/api_def/base_api/api_def_RaggedCross.pbtxt
Normal file
49
tensorflow/core/api_def/base_api/api_def_RaggedCross.pbtxt
Normal file
@ -0,0 +1,49 @@
|
||||
op {
|
||||
graph_op_name: "RaggedCross"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "ragged_values"
|
||||
description: "The values tensor for each RaggedTensor input."
|
||||
}
|
||||
in_arg {
|
||||
name: "ragged_row_splits"
|
||||
description: "The row_splits tensor for each RaggedTensor input."
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_indices"
|
||||
description: "The indices tensor for each SparseTensor input."
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_values"
|
||||
description: "The values tensor for each SparseTensor input."
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_shape"
|
||||
description: "The dense_shape tensor for each SparseTensor input."
|
||||
}
|
||||
in_arg {
|
||||
name: "dense_inputs"
|
||||
description: "The tf.Tensor inputs."
|
||||
}
|
||||
out_arg {
|
||||
name: "output_values"
|
||||
description: "The `values` for the returned `RaggedTensor`."
|
||||
}
|
||||
out_arg {
|
||||
name: "output_row_splits"
|
||||
description: "The `row_splits` for the returned `RaggedTensor`."
|
||||
}
|
||||
attr {
|
||||
name: "input_order"
|
||||
description: <<END
|
||||
String specifying the tensor type for each input. The `i`th character in
|
||||
this string specifies the type of the `i`th input, and is one of: 'R' (ragged),
|
||||
'D' (dense), or 'S' (sparse). This attr is used to ensure that the crossed
|
||||
values are combined in the order of the inputs from the call to tf.ragged.cross.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Generates a feature cross from a list of tensors, and returns it as a
|
||||
RaggedTensor. See `tf.ragged.cross` for more details.
|
||||
END
|
||||
}
|
@ -1394,6 +1394,7 @@ tf_kernel_library(
|
||||
cc_library(
|
||||
name = "ragged_ops",
|
||||
deps = [
|
||||
":ragged_cross_op",
|
||||
":ragged_gather_op",
|
||||
":ragged_range_op",
|
||||
":ragged_tensor_from_variant_op",
|
||||
@ -1548,6 +1549,15 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "ragged_cross_op",
|
||||
srcs = ["ragged_cross_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "rnn_ops",
|
||||
deps = [
|
||||
|
609
tensorflow/core/kernels/ragged_cross_op.cc
Normal file
609
tensorflow/core/kernels/ragged_cross_op.cc
Normal file
@ -0,0 +1,609 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
//==============================================================================
|
||||
// Feature Readers
|
||||
//==============================================================================
|
||||
|
||||
// A `FeatureReader` is used to read the feature values from a single input
|
||||
// tensor. Subclasses are used for reading different tensor types:
|
||||
// * RaggedFeatureReader<value_type, splits_type>
|
||||
// * SparseFeatureReader<value_type>
|
||||
// * DenseFeatureReader<value_type>
|
||||
//
|
||||
// Where value_type is one of: {tstring, int64}; and SplitsType is one of:
|
||||
// {int32, int64}.
|
||||
class FeatureReader {
|
||||
public:
|
||||
// Returns the number of feature values in the specified batch.
|
||||
virtual int64 FeatureCount(int64 batch) const = 0;
|
||||
|
||||
// Copies the value for the specified feature to `out`.
|
||||
virtual void ReadValue(int64 batch, int64 n, uint64* out) const = 0;
|
||||
virtual void ReadValue(int64 batch, int64 n, tstring* out) const = 0;
|
||||
|
||||
virtual ~FeatureReader() {}
|
||||
};
|
||||
|
||||
using FeatureReaders = std::vector<std::unique_ptr<FeatureReader>>;
|
||||
|
||||
// Copies a feature value `src` to a tstring `dst`, using a view if appropriate.
|
||||
void CopyToString(const tstring& src, tstring* dst) {
|
||||
if (src.type() == tstring::SMALL) {
|
||||
*dst = src; // string buffer fits in the tstring object (under ~24 bytes)
|
||||
} else {
|
||||
dst->assign_as_view(src);
|
||||
}
|
||||
}
|
||||
void CopyToString(int64 src, tstring* dst) { *dst = std::to_string(src); }
|
||||
|
||||
// Copies a feature value `src` to an int64 fingerprint `dst`.
|
||||
void CopyToFingerprint(const tstring& feature, uint64* dst) {
|
||||
*dst = Fingerprint64(feature);
|
||||
}
|
||||
void CopyToFingerprint(int64 feature, uint64* dst) { *dst = feature; }
|
||||
|
||||
// A FeatureReader that is backed by a ragged tensor.
|
||||
template <typename ValuesType, typename SplitsType>
|
||||
class RaggedFeatureReader : public FeatureReader {
|
||||
public:
|
||||
RaggedFeatureReader(const Tensor& values, const Tensor& row_splits)
|
||||
: values_(values.flat<ValuesType>()),
|
||||
row_splits_(row_splits.flat<SplitsType>()) {}
|
||||
|
||||
int64 FeatureCount(int64 batch) const override {
|
||||
return row_splits_(batch + 1) - row_splits_(batch);
|
||||
}
|
||||
|
||||
void ReadValue(int64 batch, int64 n, uint64* out) const override {
|
||||
CopyToFingerprint(values_(row_splits_(batch) + n), out);
|
||||
}
|
||||
|
||||
void ReadValue(int64 batch, int64 n, tstring* out) const override {
|
||||
CopyToString(values_(row_splits_(batch) + n), out);
|
||||
}
|
||||
|
||||
private:
|
||||
const typename TTypes<ValuesType>::ConstFlat values_;
|
||||
const typename TTypes<SplitsType>::ConstFlat row_splits_;
|
||||
};
|
||||
|
||||
// A FeatureReader that is backed by a dense tensor.
|
||||
template <typename ValuesType>
|
||||
class DenseFeatureReader : public FeatureReader {
|
||||
public:
|
||||
explicit DenseFeatureReader(const Tensor& tensor)
|
||||
: values_(tensor.matrix<ValuesType>()),
|
||||
feature_count_(tensor.dim_size(1)) {}
|
||||
|
||||
int64 FeatureCount(int64 batch) const override { return feature_count_; }
|
||||
|
||||
void ReadValue(int64 batch, int64 n, uint64* out) const override {
|
||||
CopyToFingerprint(values_(batch, n), out);
|
||||
}
|
||||
|
||||
void ReadValue(int64 batch, int64 n, tstring* out) const override {
|
||||
CopyToString(values_(batch, n), out);
|
||||
}
|
||||
|
||||
private:
|
||||
const typename TTypes<ValuesType>::ConstMatrix values_;
|
||||
const int64 feature_count_;
|
||||
};
|
||||
|
||||
// A FeatureReader that is backed by a sparse tensor.
|
||||
template <typename ValuesType>
|
||||
class SparseFeatureReader : public FeatureReader {
|
||||
public:
|
||||
SparseFeatureReader(const Tensor& indices_t, const Tensor& values_t,
|
||||
int64 batch_size)
|
||||
: values_(values_t.flat<ValuesType>()) {
|
||||
row_splits_.reserve(batch_size + 1);
|
||||
row_splits_.push_back(0);
|
||||
auto indices = indices_t.matrix<int64>();
|
||||
int64 num_values = values_.size();
|
||||
int64 i = 0; // value index
|
||||
for (int row = 0; row < batch_size; row++) {
|
||||
while (i < num_values && indices(i, 0) <= row) ++i;
|
||||
row_splits_.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
int64 FeatureCount(int64 batch) const override {
|
||||
return row_splits_[batch + 1] - row_splits_[batch];
|
||||
}
|
||||
|
||||
void ReadValue(int64 batch, int64 n, uint64* out) const override {
|
||||
CopyToFingerprint(values_(row_splits_[batch] + n), out);
|
||||
}
|
||||
|
||||
void ReadValue(int64 batch, int64 n, tstring* out) const override {
|
||||
CopyToString(values_(row_splits_[batch] + n), out);
|
||||
}
|
||||
|
||||
private:
|
||||
const typename TTypes<ValuesType>::ConstFlat values_;
|
||||
std::vector<int64> row_splits_;
|
||||
};
|
||||
|
||||
//==============================================================================
|
||||
// Output Writers
|
||||
//==============================================================================
|
||||
|
||||
// An `OutputWriter` is used to write the feature crosses to the output values
|
||||
// tensor. Different subclasses are used for writing different output dtypes:
|
||||
// * OutputWriterImpl<tstring, SplitsType> (for tf.ragged.cross)
|
||||
// * OutputWriterImpl<int64, SplitsType> (for tf.ragged.cross_hashed)
|
||||
class OutputWriter {
|
||||
public:
|
||||
virtual void WriteOutputSlice(int64 begin, int64 end) = 0;
|
||||
virtual ~OutputWriter() {}
|
||||
};
|
||||
|
||||
template <typename ValuesType, typename SplitsType>
|
||||
class OutputWriterImpl : public OutputWriter {
|
||||
public:
|
||||
using FlatValues = typename TTypes<ValuesType>::Flat;
|
||||
using FlatSplits = typename TTypes<SplitsType>::ConstFlat;
|
||||
|
||||
OutputWriterImpl(const FeatureReaders& features, int64 num_buckets,
|
||||
uint64 hash_key, const Tensor* splits_out,
|
||||
Tensor* values_out)
|
||||
: features_(features),
|
||||
num_buckets_(num_buckets),
|
||||
hash_key_(hash_key),
|
||||
splits_out_(splits_out->flat<SplitsType>()),
|
||||
values_out_(values_out->flat<ValuesType>()) {}
|
||||
|
||||
// Reads features from the specified slice of batch indices, computes
|
||||
// feature crosses for each one, and writes them to values_out_.
|
||||
void WriteOutputSlice(int64 begin, int64 end) override {
|
||||
std::vector<int> combination(features_.size(), 0);
|
||||
for (int64 b = begin; b < end; ++b) {
|
||||
auto row_start = splits_out_(b);
|
||||
auto row_limit = splits_out_(b + 1);
|
||||
for (auto i = row_start; i < row_limit; ++i) {
|
||||
WriteCombination(b, combination, &values_out_(i));
|
||||
NextCombination(b, &combination);
|
||||
}
|
||||
combination.assign(features_.size(), 0); // reset for next batch.
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Joins the specified combination of input features into a single string,
|
||||
// and writes it to *out.
|
||||
void WriteCombination(int64 batch_index, const std::vector<int>& combination,
|
||||
tstring* out) {
|
||||
static const auto k_feature_separator = "_X_";
|
||||
gtl::InlinedVector<tstring, 6> cross_vec(features_.size());
|
||||
for (int i = 0; i < combination.size(); ++i) {
|
||||
features_[i]->ReadValue(batch_index, combination[i], &cross_vec[i]);
|
||||
}
|
||||
*out = absl::StrJoin(cross_vec, k_feature_separator);
|
||||
}
|
||||
|
||||
// Joins the specified combination of input features into a single
|
||||
// fingerprint, and writes it to *out.
|
||||
void WriteCombination(int64 batch_index, const std::vector<int>& combination,
|
||||
int64* out) {
|
||||
// Do the fingerprint concatenation on uint64.
|
||||
uint64 hashed_output = hash_key_;
|
||||
for (size_t i = 0; i < combination.size(); ++i) {
|
||||
uint64 hash_i;
|
||||
features_[i]->ReadValue(batch_index, combination[i], &hash_i);
|
||||
hashed_output = FingerprintCat64(hashed_output, hash_i);
|
||||
}
|
||||
// The return value is int64 based on the number of buckets.
|
||||
if (num_buckets_ > 0) {
|
||||
*out = hashed_output % num_buckets_;
|
||||
} else {
|
||||
// To prevent negative output we take modulo to max int64.
|
||||
*out = hashed_output % std::numeric_limits<int64>::max();
|
||||
}
|
||||
}
|
||||
|
||||
// Updates `combination` to the next combination of input features.
|
||||
void NextCombination(int64 batch_index, std::vector<int>* combination) const {
|
||||
bool carry = true;
|
||||
for (int i = combination->size() - 1; i >= 0; i--) {
|
||||
if (carry) {
|
||||
(*combination)[i] = (*combination)[i] + 1;
|
||||
}
|
||||
if ((*combination)[i] == features_[i]->FeatureCount(batch_index)) {
|
||||
(*combination)[i] = 0;
|
||||
} else {
|
||||
carry = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const FeatureReaders& features_;
|
||||
const int64 num_buckets_;
|
||||
const uint64 hash_key_;
|
||||
FlatSplits splits_out_;
|
||||
FlatValues values_out_;
|
||||
};
|
||||
|
||||
// Returns an appropriate OutputWriter, based on the dtypes of the
|
||||
// given tensors.
|
||||
std::unique_ptr<OutputWriter> MakeOutputWriter(const FeatureReaders& features,
|
||||
int64 num_buckets,
|
||||
uint64 hash_key,
|
||||
const Tensor* splits_out,
|
||||
Tensor* values_out) {
|
||||
if (values_out->dtype() == DT_INT64) {
|
||||
if (splits_out->dtype() == DT_INT64) {
|
||||
return std::make_unique<OutputWriterImpl<int64, int64>>(
|
||||
features, num_buckets, hash_key, splits_out, values_out);
|
||||
} else {
|
||||
return std::make_unique<OutputWriterImpl<int64, int32>>(
|
||||
features, num_buckets, hash_key, splits_out, values_out);
|
||||
}
|
||||
} else {
|
||||
if (splits_out->dtype() == DT_INT64) {
|
||||
return std::make_unique<OutputWriterImpl<tstring, int64>>(
|
||||
features, num_buckets, hash_key, splits_out, values_out);
|
||||
} else {
|
||||
return std::make_unique<OutputWriterImpl<tstring, int32>>(
|
||||
features, num_buckets, hash_key, splits_out, values_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//==============================================================================
|
||||
// RaggedCross Kernel
|
||||
//==============================================================================
|
||||
|
||||
template <typename SplitsType>
|
||||
class RaggedCrossOp : public OpKernel {
|
||||
public:
|
||||
explicit RaggedCrossOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
|
||||
// Read signed_hash_key_ as int64 since uint64 attributes are not
|
||||
// supported by REGISTER_OP.
|
||||
int64 signed_hash_key_;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
|
||||
hash_key_ = static_cast<uint64>(signed_hash_key_);
|
||||
|
||||
int num_sparse;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("Nsparse", &num_sparse));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("ragged_values_types",
|
||||
&ragged_values_types_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("ragged_splits_types",
|
||||
&ragged_splits_types_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("sparse_values_types",
|
||||
&sparse_values_types_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dense_types", &dense_types_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("input_order", &input_order_));
|
||||
OP_REQUIRES(context,
|
||||
ragged_values_types_.size() == ragged_splits_types_.size(),
|
||||
errors::InvalidArgument(
|
||||
"ragged values and splits must have the same length"));
|
||||
OP_REQUIRES(context, num_sparse == sparse_values_types_.size(),
|
||||
errors::InvalidArgument(
|
||||
"sparse indices and values must have the same length"));
|
||||
OP_REQUIRES(context,
|
||||
ragged_values_types_.size() + sparse_values_types_.size() +
|
||||
dense_types_.size() ==
|
||||
input_order_.size(),
|
||||
errors::InvalidArgument("Invalid length for input_order"));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
OpInputList ragged_values_list;
|
||||
OpInputList ragged_splits_list;
|
||||
OpInputList sparse_indices_list;
|
||||
OpInputList sparse_values_list;
|
||||
OpInputList sparse_shape_list;
|
||||
OpInputList dense_list;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("ragged_values", &ragged_values_list));
|
||||
OP_REQUIRES_OK(
|
||||
context, context->input_list("ragged_row_splits", &ragged_splits_list));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("sparse_indices", &sparse_indices_list));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("sparse_values", &sparse_values_list));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("sparse_shape", &sparse_shape_list));
|
||||
OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list));
|
||||
OP_REQUIRES_OK(context,
|
||||
ValidateInput(ragged_values_list, ragged_splits_list,
|
||||
sparse_indices_list, sparse_values_list,
|
||||
sparse_shape_list, dense_list));
|
||||
|
||||
int64 batch_size =
|
||||
CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
|
||||
|
||||
FeatureReaders features;
|
||||
OP_REQUIRES_OK(context,
|
||||
BuildFeatureReaders(ragged_values_list, ragged_splits_list,
|
||||
sparse_indices_list, sparse_values_list,
|
||||
dense_list, batch_size, &features));
|
||||
|
||||
Tensor* values_out;
|
||||
Tensor* row_splits_out;
|
||||
OP_REQUIRES_OK(context, BuildOutputTensors(features, batch_size, context,
|
||||
&values_out, &row_splits_out));
|
||||
|
||||
std::unique_ptr<OutputWriter> output_writer = MakeOutputWriter(
|
||||
features, num_buckets_, hash_key_, row_splits_out, values_out);
|
||||
|
||||
auto do_work = [&output_writer](int64 begin, int64 end) {
|
||||
output_writer->WriteOutputSlice(begin, end);
|
||||
};
|
||||
|
||||
// TODO(edloper): optimize cost_per_batch
|
||||
const int cost_per_batch = 5000 * ragged_values_list.size();
|
||||
auto thread_pool =
|
||||
context->device()->tensorflow_cpu_worker_threads()->workers;
|
||||
thread_pool->ParallelFor(batch_size, cost_per_batch, do_work);
|
||||
}
|
||||
|
||||
private:
|
||||
// Validates input tensors.
|
||||
Status ValidateInput(const OpInputList& ragged_values_list,
|
||||
const OpInputList& ragged_splits_list,
|
||||
const OpInputList& sparse_indices_list,
|
||||
const OpInputList& sparse_values_list,
|
||||
const OpInputList& sparse_shape_list,
|
||||
const OpInputList& dense_list) {
|
||||
const auto num_ragged = ragged_values_list.size();
|
||||
const auto num_sparse = sparse_indices_list.size();
|
||||
|
||||
// Validate tensor shapes.
|
||||
for (int i = 0; i < num_ragged; ++i) {
|
||||
if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"tf.ragged.cross only supports inputs with rank=2.");
|
||||
}
|
||||
if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) ||
|
||||
(ragged_splits_list[i].NumElements() == 0)) {
|
||||
return errors::InvalidArgument("Invalid RaggedTensor");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_sparse; ++i) {
|
||||
if (!TensorShapeUtils::IsMatrix(sparse_indices_list[i].shape()) ||
|
||||
!TensorShapeUtils::IsVector(sparse_values_list[i].shape()) ||
|
||||
!TensorShapeUtils::IsVector(sparse_shape_list[i].shape())) {
|
||||
return errors::InvalidArgument("Invalid SparseTensor ", i);
|
||||
}
|
||||
if (sparse_shape_list[i].NumElements() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"tf.ragged.cross only supports inputs with rank=2.");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < dense_list.size(); ++i) {
|
||||
if (!TensorShapeUtils::IsMatrix(dense_list[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"tf.ragged.cross only supports inputs with rank=2.");
|
||||
}
|
||||
}
|
||||
|
||||
// Check that batch sizes are consistent.
|
||||
int64 batch_size =
|
||||
CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
|
||||
for (int i = 0; i < num_ragged; ++i) {
|
||||
if (ragged_splits_list[i].NumElements() - 1 != batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"inputs must all have the same batch dimension size.");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_sparse; ++i) {
|
||||
if (sparse_shape_list[i].flat<int64>()(0) != batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"inputs must all have the same batch dimension size.");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < dense_list.size(); ++i) {
|
||||
if (dense_list[i].dim_size(0) != batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"inputs must all have the same batch dimension size.");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Calculate the batch size from any input tensor. (We check that all input
|
||||
// tensors have the same batch size in `ValidateInput`).
|
||||
int64 CalculateBatchSize(const OpInputList& ragged_splits_list,
|
||||
const OpInputList& sparse_shape_list,
|
||||
const OpInputList& dense_list) {
|
||||
if (ragged_splits_list.size() > 0) {
|
||||
return ragged_splits_list[0].NumElements() - 1;
|
||||
} else if (dense_list.size() > 0) {
|
||||
return dense_list[0].dim_size(0);
|
||||
} else if (sparse_shape_list.size() > 0) {
|
||||
return sparse_shape_list[0].flat<int64>()(0);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Build a feature reader for each input tensor, and store them in `features`.
|
||||
Status BuildFeatureReaders(const OpInputList& ragged_values_list,
|
||||
const OpInputList& ragged_splits_list,
|
||||
const OpInputList& sparse_indices_list,
|
||||
const OpInputList& sparse_values_list,
|
||||
const OpInputList& dense_list, int64 batch_size,
|
||||
FeatureReaders* features) {
|
||||
features->reserve(input_order_.size());
|
||||
|
||||
int next_ragged = 0;
|
||||
int next_sparse = 0;
|
||||
int next_dense = 0;
|
||||
for (char c : input_order_) {
|
||||
if (c == 'R') {
|
||||
TF_RETURN_IF_ERROR(BuildRaggedFeatureReader(
|
||||
ragged_values_list[next_ragged], ragged_splits_list[next_ragged],
|
||||
features));
|
||||
next_ragged++;
|
||||
} else if (c == 'S') {
|
||||
TF_RETURN_IF_ERROR(BuildSparseFeatureReader(
|
||||
sparse_indices_list[next_sparse], sparse_values_list[next_sparse],
|
||||
batch_size, features));
|
||||
next_sparse++;
|
||||
} else if (c == 'D') {
|
||||
TF_RETURN_IF_ERROR(
|
||||
BuildDenseFeatureReader(dense_list[next_dense++], features));
|
||||
} else {
|
||||
return errors::InvalidArgument("Unexpected input_order value.");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Builds a RaggedReatureReader
|
||||
static Status BuildRaggedFeatureReader(const Tensor& values,
|
||||
const Tensor& splits,
|
||||
FeatureReaders* features) {
|
||||
if (values.dtype() != DT_INT64 && values.dtype() != DT_STRING) {
|
||||
return errors::InvalidArgument("Unexpected dtype for input ",
|
||||
(features->size() + 1), ": ",
|
||||
values.dtype());
|
||||
}
|
||||
if (splits.dtype() != DT_INT64 && splits.dtype() != DT_INT32) {
|
||||
return errors::InvalidArgument("Unexpected row_splits.dtype for input ",
|
||||
(features->size() + 1), ": ",
|
||||
values.dtype());
|
||||
}
|
||||
if (values.dtype() == DT_INT64) {
|
||||
if (splits.dtype() == DT_INT64) {
|
||||
features->emplace_back(
|
||||
new RaggedFeatureReader<int64, int64>(values, splits));
|
||||
} else {
|
||||
features->emplace_back(
|
||||
new RaggedFeatureReader<int64, int32>(values, splits));
|
||||
}
|
||||
} else {
|
||||
if (splits.dtype() == DT_INT64) {
|
||||
features->emplace_back(
|
||||
new RaggedFeatureReader<tstring, int64>(values, splits));
|
||||
} else {
|
||||
features->emplace_back(
|
||||
new RaggedFeatureReader<tstring, int32>(values, splits));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Builds a DenseFaggedReatureReader.
|
||||
static Status BuildDenseFeatureReader(const Tensor& values,
|
||||
FeatureReaders* features) {
|
||||
if (values.dtype() == DT_INT64) {
|
||||
features->emplace_back(new DenseFeatureReader<int64>(values));
|
||||
} else if (values.dtype() == DT_STRING) {
|
||||
features->emplace_back(new DenseFeatureReader<tstring>(values));
|
||||
} else {
|
||||
return errors::InvalidArgument("Unexpected dtype for input ",
|
||||
(features->size() + 1), ": ",
|
||||
values.dtype());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Builds a SparseFaggedReatureReader.
|
||||
static Status BuildSparseFeatureReader(const Tensor& indices,
|
||||
const Tensor& values, int64 batch_size,
|
||||
FeatureReaders* features) {
|
||||
if (values.dtype() == DT_INT64) {
|
||||
features->emplace_back(
|
||||
new SparseFeatureReader<int64>(indices, values, batch_size));
|
||||
} else if (values.dtype() == DT_STRING) {
|
||||
features->emplace_back(
|
||||
new SparseFeatureReader<tstring>(indices, values, batch_size));
|
||||
} else {
|
||||
return errors::InvalidArgument("Unexpected dtype for input ",
|
||||
(features->size() + 1), ": ",
|
||||
values.dtype());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Allocates output tensors with proper size, and populates row_splits_out.
|
||||
Status BuildOutputTensors(const FeatureReaders& features, int64 batch_size,
|
||||
OpKernelContext* context, Tensor** values_out,
|
||||
Tensor** row_splits_out) {
|
||||
// Allocate and populate the row_splits output tensor.
|
||||
TF_RETURN_IF_ERROR(context->allocate_output(
|
||||
1, TensorShape({batch_size + 1}), row_splits_out));
|
||||
auto flat_row_splits = (*row_splits_out)->flat<SplitsType>();
|
||||
int64 cross_count_total = 0;
|
||||
flat_row_splits(0) = 0;
|
||||
for (int64 b = 0; b < batch_size; b++) {
|
||||
cross_count_total += CrossCountByBatchIndex(features, b);
|
||||
flat_row_splits(b + 1) = cross_count_total;
|
||||
}
|
||||
|
||||
// Allocate the values output tensor.
|
||||
TF_RETURN_IF_ERROR(context->allocate_output(
|
||||
0, TensorShape({cross_count_total}), values_out));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns number of crosses for a given batch_index
|
||||
int64 CrossCountByBatchIndex(const FeatureReaders& features,
|
||||
int batch_index) {
|
||||
int64 cross_count = 1;
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
const auto feature_count = features[i]->FeatureCount(batch_index);
|
||||
if (feature_count == 0) return 0;
|
||||
cross_count *= feature_count;
|
||||
}
|
||||
return cross_count;
|
||||
}
|
||||
|
||||
int64 num_buckets_;
|
||||
uint64 hash_key_;
|
||||
std::vector<DataType> ragged_values_types_;
|
||||
std::vector<DataType> ragged_splits_types_;
|
||||
std::vector<DataType> sparse_values_types_;
|
||||
std::vector<DataType> dense_types_;
|
||||
tstring input_order_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<int32>("out_row_splits_type"),
|
||||
RaggedCrossOp<int32>);
|
||||
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<int64>("out_row_splits_type"),
|
||||
RaggedCrossOp<int64>);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -41,6 +41,79 @@ REGISTER_OP("RaggedGather")
|
||||
.Attr("OUTPUT_RAGGED_RANK: int >= 0")
|
||||
.SetShapeFn(RaggedGatherShapeFn);
|
||||
|
||||
REGISTER_OP("RaggedCross")
|
||||
.Input("ragged_values: ragged_values_types")
|
||||
.Input("ragged_row_splits: ragged_splits_types")
|
||||
.Input("sparse_indices: Nsparse * int64")
|
||||
.Input("sparse_values: sparse_values_types")
|
||||
.Input("sparse_shape: Nsparse * int64")
|
||||
.Input("dense_inputs: dense_types")
|
||||
.Output("output_values: out_values_type")
|
||||
.Output("output_row_splits: out_row_splits_type")
|
||||
.Attr("Nsparse: int >= 0")
|
||||
.Attr("input_order: string")
|
||||
.Attr("hashed_output: bool")
|
||||
.Attr("num_buckets: int >= 0")
|
||||
.Attr("hash_key: int")
|
||||
.Attr("ragged_values_types: list({int64, string}) >= 0")
|
||||
.Attr("ragged_splits_types: list({int32, int64}) >= 0")
|
||||
.Attr("sparse_values_types: list({int64, string}) >= 0")
|
||||
.Attr("dense_types: list({int64, string}) >= 0")
|
||||
.Attr("out_values_type: {int64, string}")
|
||||
.Attr("out_row_splits_type: {int32, int64}")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
std::vector<DataType> ragged_values_types;
|
||||
std::vector<DataType> ragged_splits_types;
|
||||
std::vector<DataType> dense_types;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->GetAttr("ragged_values_types", &ragged_values_types));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->GetAttr("ragged_splits_types", &ragged_splits_types));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
|
||||
|
||||
int num_ragged = ragged_values_types.size();
|
||||
if (num_ragged != ragged_splits_types.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Parameters `values` and `row_splits` must be the same length");
|
||||
}
|
||||
|
||||
int num_sparse;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
|
||||
|
||||
ShapeHandle out_values = c->UnknownShapeOfRank(1);
|
||||
ShapeHandle out_splits = c->UnknownShapeOfRank(1);
|
||||
|
||||
// Merge the shapes of row_splits from ragged inputs. (This is one plus
|
||||
// the batch size.)
|
||||
int ragged_splits_start = num_ragged;
|
||||
for (int i = 0; i < ragged_splits_types.size(); ++i) {
|
||||
ShapeHandle row_splits = c->input(i + ragged_splits_start);
|
||||
if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"inputs must all have the same batch dimension size.");
|
||||
}
|
||||
}
|
||||
|
||||
// Merge the batch size of each dense input into out_splits.
|
||||
int dense_start = num_ragged * 2 + num_sparse * 3;
|
||||
for (int i = 0; i < dense_types.size(); ++i) {
|
||||
ShapeHandle dense_input = c->input(i + dense_start);
|
||||
int64 batch_size = c->Value(c->Dim(dense_input, 0));
|
||||
if (batch_size != InferenceContext::kUnknownDim) {
|
||||
ShapeHandle row_splits = c->Vector(batch_size + 1);
|
||||
if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"inputs must all have the same batch dimension size.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c->set_output(0, out_values);
|
||||
c->set_output(1, out_splits);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
//==============================================================================
|
||||
// Shape Functions
|
||||
//==============================================================================
|
||||
|
@ -1123,3 +1123,17 @@ py_test(
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ragged_cross_op_test",
|
||||
srcs = ["ragged_cross_op_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import gen_ragged_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sort_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
@ -32,7 +34,6 @@ from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged import segment_id_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# Masking
|
||||
#===============================================================================
|
||||
@ -107,7 +108,8 @@ def boolean_mask(data, mask, name=None):
|
||||
if ragged_tensor.is_ragged(mask):
|
||||
if not ragged_tensor.is_ragged(data):
|
||||
data = ragged_tensor.RaggedTensor.from_tensor(
|
||||
data, ragged_rank=mask.ragged_rank,
|
||||
data,
|
||||
ragged_rank=mask.ragged_rank,
|
||||
row_splits_dtype=mask.row_splits.dtype)
|
||||
# Check that mask.nested_row_splits is a prefix of
|
||||
# data.nested_row_splits.
|
||||
@ -160,15 +162,15 @@ def boolean_mask(data, mask, name=None):
|
||||
segment_mask = array_ops.gather(mask, segment_ids)
|
||||
masked_values = boolean_mask(data.values, segment_mask)
|
||||
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
|
||||
masked_splits,
|
||||
validate=False)
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||
masked_values, masked_splits, validate=False)
|
||||
|
||||
# If mask is non-ragged and has rank>1, then convert it to be ragged,
|
||||
# with a ragged rank matching data.
|
||||
if ragged_tensor.is_ragged(data):
|
||||
mask = ragged_tensor.RaggedTensor.from_tensor(
|
||||
mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
|
||||
mask,
|
||||
ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
|
||||
row_splits_dtype=data.row_splits.dtype)
|
||||
return boolean_mask(data, mask)
|
||||
|
||||
@ -182,8 +184,8 @@ def boolean_mask(data, mask, name=None):
|
||||
# number of values it contains. Then flatten that to get a list of
|
||||
# cell lengths, and convert it to splits. Finally, combine the splits
|
||||
# and values to get the innermost ragged tensor.
|
||||
masked_lengths = math_ops.count_nonzero(mask, axis=-1,
|
||||
dtype=row_splits_dtype)
|
||||
masked_lengths = math_ops.count_nonzero(
|
||||
mask, axis=-1, dtype=row_splits_dtype)
|
||||
flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
|
||||
masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
|
||||
masked_values, flattened_masked_lengths, validate=False)
|
||||
@ -340,8 +342,7 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
|
||||
for src_axis in range(ragged_rank):
|
||||
for dst_axis in range(src_axis + 1, ragged_rank - 1):
|
||||
projected_splits[src_axis][dst_axis] = array_ops.gather(
|
||||
nested_splits[dst_axis],
|
||||
projected_splits[src_axis][dst_axis - 1])
|
||||
nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1])
|
||||
|
||||
# For each ragged dimension: nested_splits[axis] -> result_splits[axis].
|
||||
result_splits = []
|
||||
@ -400,8 +401,7 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
||||
|
||||
Args:
|
||||
input: The potentially tensor that should be expanded with a new
|
||||
dimension.
|
||||
input: The potentially tensor that should be expanded with a new dimension.
|
||||
axis: An integer constant indicating where the new dimension should be
|
||||
inserted.
|
||||
name: A name for the operation (optional).
|
||||
@ -556,10 +556,10 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
|
||||
Args:
|
||||
data: A `Tensor` or `RaggedTensor` containing the values to stack.
|
||||
partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
|
||||
partition that each slice of `data` should be added to.
|
||||
`partitions.shape` must be a prefix of `data.shape`. Values must be
|
||||
greater than or equal to zero, and less than `num_partitions`.
|
||||
`partitions` is not required to be sorted.
|
||||
partition that each slice of `data` should be added to. `partitions.shape`
|
||||
must be a prefix of `data.shape`. Values must be greater than or equal to
|
||||
zero, and less than `num_partitions`. `partitions` is not required to be
|
||||
sorted.
|
||||
num_partitions: An `int32` or `int64` scalar specifying the number of
|
||||
partitions to output. This determines the number of rows in `output`.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
@ -650,8 +650,8 @@ def reverse(tensor, axis, name=None):
|
||||
|
||||
Args:
|
||||
tensor: A 'RaggedTensor' to reverse.
|
||||
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
|
||||
of the axes to reverse.
|
||||
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of
|
||||
the axes to reverse.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
@ -687,3 +687,137 @@ def reverse(tensor, axis, name=None):
|
||||
slices[dim] = slice(None, None, -1)
|
||||
|
||||
return tensor[tuple(slices)]
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# Cross
|
||||
#===============================================================================
|
||||
|
||||
|
||||
@tf_export('ragged.cross')
|
||||
def cross(inputs, name=None):
|
||||
"""Generates feature cross from a list of tensors.
|
||||
|
||||
The input tensors must have `rank=2`, and must all have the same number of
|
||||
rows. The result is a `RaggedTensor` with the same number of rows as the
|
||||
inputs, where `result[row]` contains a list of all combinations of values
|
||||
formed by taking a single value from each input's corresponding row
|
||||
(`inputs[i][row]`). Values are combined by joining their strings with '_X_'.
|
||||
E.g.:
|
||||
|
||||
>>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
|
||||
... tf.ragged.constant([['d'], ['e']]),
|
||||
... tf.ragged.constant([['f'], ['g']])])
|
||||
<tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>
|
||||
|
||||
Args:
|
||||
inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
|
||||
name: Optional name for the op.
|
||||
|
||||
Returns:
|
||||
A 2D `RaggedTensor` of type `string`.
|
||||
"""
|
||||
return _cross_internal(inputs=inputs, hashed_output=False, name=name)
|
||||
|
||||
|
||||
@tf_export('ragged.cross_hashed')
|
||||
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
|
||||
"""Generates hashed feature cross from a list of tensors.
|
||||
|
||||
The input tensors must have `rank=2`, and must all have the same number of
|
||||
rows. The result is a `RaggedTensor` with the same number of rows as the
|
||||
inputs, where `result[row]` contains a list of all combinations of values
|
||||
formed by taking a single value from each input's corresponding row
|
||||
(`inputs[i][row]`). Values are combined by hashing together their
|
||||
fingerprints. E.g.:
|
||||
|
||||
>>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
|
||||
... tf.ragged.constant([['d'], ['e']]),
|
||||
... tf.ragged.constant([['f'], ['g']])],
|
||||
... num_buckets=100)
|
||||
<tf.RaggedTensor [[78], [66, 74]]>
|
||||
|
||||
Args:
|
||||
inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
|
||||
num_buckets: A non-negative `int` that used to bucket the hashed values. If
|
||||
`num_buckets != 0`, then `output = hashed_value % num_buckets`.
|
||||
hash_key: Integer hash_key that will be used by the `FingerprintCat64`
|
||||
function. If not given, a default key is used.
|
||||
name: Optional name for the op.
|
||||
|
||||
Returns:
|
||||
A 2D `RaggedTensor` of type `int64`.
|
||||
"""
|
||||
return _cross_internal(
|
||||
inputs=inputs,
|
||||
hashed_output=True,
|
||||
num_buckets=num_buckets,
|
||||
hash_key=hash_key,
|
||||
name=name)
|
||||
|
||||
|
||||
_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE
|
||||
|
||||
|
||||
def _cross_internal(inputs,
|
||||
hashed_output=False,
|
||||
num_buckets=0,
|
||||
hash_key=None,
|
||||
name=None):
|
||||
"""Generates feature cross from a list of ragged and dense tensors."""
|
||||
if not isinstance(inputs, (tuple, list)):
|
||||
raise TypeError('Inputs must be a list')
|
||||
|
||||
if hash_key is None:
|
||||
hash_key = _DEFAULT_CROSS_HASH_KEY
|
||||
|
||||
ragged_inputs = []
|
||||
sparse_inputs = []
|
||||
dense_inputs = []
|
||||
input_order = []
|
||||
with ops.name_scope(name, 'RaggedCross', inputs):
|
||||
for i, t in enumerate(inputs):
|
||||
if sparse_tensor.is_sparse(t):
|
||||
t = sparse_tensor.SparseTensor.from_value(t)
|
||||
else:
|
||||
t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
|
||||
if t.dtype.is_integer:
|
||||
t = math_ops.cast(t, dtypes.int64)
|
||||
elif t.dtype != dtypes.string:
|
||||
raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
|
||||
if isinstance(t, ragged_tensor.RaggedTensor):
|
||||
if t.ragged_rank != 1:
|
||||
raise ValueError('tf.ragged.cross only supports inputs with rank=2')
|
||||
ragged_inputs.append(t)
|
||||
input_order.append('R')
|
||||
elif isinstance(t, sparse_tensor.SparseTensor):
|
||||
sparse_inputs.append(t)
|
||||
input_order.append('S')
|
||||
else:
|
||||
dense_inputs.append(t)
|
||||
input_order.append('D')
|
||||
|
||||
out_values_type = dtypes.int64 if hashed_output else dtypes.string
|
||||
if ragged_inputs and all(
|
||||
t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
|
||||
out_row_splits_type = dtypes.int32
|
||||
else:
|
||||
out_row_splits_type = dtypes.int64
|
||||
|
||||
values_out, splits_out = gen_ragged_array_ops.ragged_cross(
|
||||
ragged_values=[rt.values for rt in ragged_inputs],
|
||||
ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
|
||||
sparse_indices=[st.indices for st in sparse_inputs],
|
||||
sparse_values=[st.values for st in sparse_inputs],
|
||||
sparse_shape=[st.dense_shape for st in sparse_inputs],
|
||||
dense_inputs=dense_inputs,
|
||||
input_order=''.join(input_order),
|
||||
hashed_output=hashed_output,
|
||||
num_buckets=num_buckets,
|
||||
hash_key=hash_key,
|
||||
out_values_type=out_values_type,
|
||||
out_row_splits_type=out_row_splits_type,
|
||||
name=name)
|
||||
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values_out, splits_out, validate=False)
|
||||
|
396
tensorflow/python/ops/ragged/ragged_cross_op_test.py
Normal file
396
tensorflow/python/ops/ragged/ragged_cross_op_test.py
Normal file
@ -0,0 +1,396 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.ragged.cross and tf.ragged.cross_hashed."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
ragged_const = ragged_factory_ops.constant_value
|
||||
dense_const = np.array
|
||||
|
||||
|
||||
def sparse_const(matrix):
|
||||
indices = []
|
||||
values = []
|
||||
for i, row in enumerate(matrix):
|
||||
for j, val in enumerate(row):
|
||||
indices.append([i, j])
|
||||
values.append(val)
|
||||
shape = [len(matrix), max(len(row) for row in matrix)] if matrix else [0, 0]
|
||||
if not values:
|
||||
indices = np.zeros([0, 2], dtype=np.int64)
|
||||
values = np.zeros([0], dtype=np.int64)
|
||||
return sparse_tensor.SparseTensorValue(indices, values, shape)
|
||||
|
||||
|
||||
class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name='NoInputs',
|
||||
inputs=[],
|
||||
expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
|
||||
dict(
|
||||
testcase_name='OneInput_RaggedStr',
|
||||
inputs=[ragged_const([['a', 'b'], [], ['c']])],
|
||||
expected=ragged_const([[b'a', b'b'], [], [b'c']])),
|
||||
dict(
|
||||
testcase_name='OneInput_RaggedInt',
|
||||
inputs=[ragged_const([[1, 2, 3], [4, 5]])],
|
||||
expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
|
||||
dict(
|
||||
testcase_name='OneInput_DenseInt',
|
||||
inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
|
||||
expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
|
||||
dict(
|
||||
testcase_name='OneInput_SparseStr',
|
||||
inputs=[sparse_const([['a', 'b'], [], ['c']])],
|
||||
expected=ragged_const([[b'a', b'b'], [], [b'c']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_RaggedStr_RaggedStr',
|
||||
inputs=[
|
||||
ragged_const([['a', 'b'], [], ['c']]),
|
||||
ragged_const([['d', 'e'], ['f'], ['g']])
|
||||
],
|
||||
expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
|
||||
[b'c_X_g']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_RaggedInt_RaggedInt',
|
||||
inputs=[
|
||||
ragged_const([[1, 2], [], [3]]),
|
||||
ragged_const([[4, 5, 6], [], [7]])
|
||||
],
|
||||
expected=ragged_const(
|
||||
[[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
|
||||
[b'3_X_7']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_RaggedStr_RaggedInt',
|
||||
inputs=[
|
||||
ragged_const([['a', 'b'], [], ['c']]),
|
||||
ragged_const([['1', '2'], ['3'], ['4']])
|
||||
],
|
||||
expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
|
||||
[b'c_X_4']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_SparseStr_SparseStr',
|
||||
inputs=[
|
||||
sparse_const([['a', 'b'], [], ['c']]),
|
||||
sparse_const([['d', 'e'], ['f'], ['g']])
|
||||
],
|
||||
expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
|
||||
[b'c_X_g']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_DenseInt_DenseInt',
|
||||
inputs=[dense_const([[1, 2], [3, 4]]),
|
||||
dense_const([[5, 6], [7, 8]])],
|
||||
expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
|
||||
[b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_DenseInt_DenseStr',
|
||||
inputs=[
|
||||
dense_const([[1, 2], [3, 4]]),
|
||||
dense_const([[b'5', b'6'], [b'7', b'8']])
|
||||
],
|
||||
expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
|
||||
[b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
|
||||
dict(
|
||||
testcase_name='TwoInputs_RaggedInt_DenseInt',
|
||||
inputs=[
|
||||
ragged_const([[], [], [1, 2], [3]]),
|
||||
dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
|
||||
],
|
||||
expected=ragged_const([[], [],
|
||||
[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
|
||||
[b'3_X_7', b'3_X_8']])),
|
||||
dict(
|
||||
# This test exercises `input_order`.
|
||||
testcase_name='TwoInputs_DenseInt_RaggedStr',
|
||||
inputs=[
|
||||
dense_const([[1, 2], [3, 4], [5, 6]]),
|
||||
ragged_const([['d', 'e'], ['f'], ['g']])
|
||||
],
|
||||
expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
|
||||
[b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
|
||||
matches_sparse_cross=False # sparse doesn't preserve input order.
|
||||
),
|
||||
dict(
|
||||
# This test exercises `input_order`.
|
||||
testcase_name='TwoInputs_SparseInt_RaggedStr',
|
||||
inputs=[
|
||||
sparse_const([[1, 2], [3, 4], [5, 6]]),
|
||||
ragged_const([['d', 'e'], ['f'], ['g']])
|
||||
],
|
||||
expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
|
||||
[b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
|
||||
matches_sparse_cross=False # sparse doesn't preserve input order.
|
||||
),
|
||||
dict(
|
||||
testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
|
||||
inputs=[
|
||||
ragged_const([[11], [12, 13], [], [14, 15]]),
|
||||
ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
|
||||
ragged_const([[31], [32, 33], [34, 35], [36, 37]])
|
||||
],
|
||||
expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
|
||||
[
|
||||
b'12_X_23_X_32', b'12_X_23_X_33',
|
||||
b'13_X_23_X_32', b'13_X_23_X_33'
|
||||
], [],
|
||||
[
|
||||
b'14_X_26_X_36', b'14_X_26_X_37',
|
||||
b'14_X_27_X_36', b'14_X_27_X_37',
|
||||
b'15_X_26_X_36', b'15_X_26_X_37',
|
||||
b'15_X_27_X_36', b'15_X_27_X_37'
|
||||
]])),
|
||||
dict(
|
||||
testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
|
||||
inputs=[
|
||||
ragged_const([[11], [12, 13], [], [14, 15]]),
|
||||
sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
|
||||
dense_const([[31], [32], [33], [34]])
|
||||
],
|
||||
expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
|
||||
[
|
||||
b'12_X_23_X_32',
|
||||
b'13_X_23_X_32',
|
||||
], [],
|
||||
[
|
||||
b'14_X_26_X_34',
|
||||
b'14_X_27_X_34',
|
||||
b'15_X_26_X_34',
|
||||
b'15_X_27_X_34',
|
||||
]])),
|
||||
dict(
|
||||
testcase_name='FiveInputs',
|
||||
inputs=[
|
||||
ragged_const([[1]]),
|
||||
dense_const([[2]]),
|
||||
ragged_const([[3]]),
|
||||
sparse_const([[4]]),
|
||||
ragged_const([[5]])
|
||||
],
|
||||
expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
|
||||
matches_sparse_cross=False # sparse doesn't preserve input order.
|
||||
),
|
||||
dict(
|
||||
testcase_name='Permutation_3x3x3',
|
||||
inputs=[[['11', '12', '13']], [['21', '22', '23']],
|
||||
[['31', '32', '33']]],
|
||||
expected=[[
|
||||
b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
|
||||
b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
|
||||
b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
|
||||
b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
|
||||
b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
|
||||
b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
|
||||
b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
|
||||
b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
|
||||
b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
|
||||
]]),
|
||||
dict(
|
||||
testcase_name='BatchSizeZero',
|
||||
inputs=[
|
||||
ragged_const([], ragged_rank=1, dtype=dtypes.int32),
|
||||
sparse_const([]),
|
||||
np.zeros([0, 3], dtype=np.int32),
|
||||
],
|
||||
expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
|
||||
dict(
|
||||
testcase_name='ThreeInputs_OneEmpty',
|
||||
inputs=[
|
||||
ragged_const([[1, 2]]),
|
||||
ragged_const([[]], dtype=dtypes.int32),
|
||||
ragged_const([[3, 4]])
|
||||
],
|
||||
expected=ragged_const([[]], dtype=dtypes.string)),
|
||||
dict(
|
||||
testcase_name='ThreeInputs_AllEmpty',
|
||||
inputs=[
|
||||
ragged_const([[]], dtype=dtypes.int64),
|
||||
ragged_const([[]], dtype=dtypes.string),
|
||||
ragged_const([[]], dtype=dtypes.int32)
|
||||
],
|
||||
expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
|
||||
dict(
|
||||
testcase_name='HashedZeroBucketsDefaultKey',
|
||||
inputs=[
|
||||
ragged_const([['batch1-FC1-F1']]),
|
||||
ragged_const([['batch1-FC2-F1']]),
|
||||
ragged_const([['batch1-FC3-F1']])
|
||||
],
|
||||
expected_hashed=ragged_const([[1971693436396284976]])),
|
||||
dict(
|
||||
testcase_name='Hashed100BucketsDefaultKey',
|
||||
inputs=[
|
||||
ragged_const([['batch1-FC1-F1']]),
|
||||
ragged_const([['batch1-FC2-F1']]),
|
||||
ragged_const([['batch1-FC3-F1']])
|
||||
],
|
||||
num_buckets=100,
|
||||
expected_hashed=ragged_const([[83]])),
|
||||
dict(
|
||||
testcase_name='HashedZeroBucketsCustomKey',
|
||||
inputs=[
|
||||
ragged_const([['batch1-FC1-F1']]),
|
||||
ragged_const([['batch1-FC2-F1']]),
|
||||
ragged_const([['batch1-FC3-F1']])
|
||||
],
|
||||
hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
|
||||
expected_hashed=ragged_const([[4847552627144134031]])),
|
||||
dict(
|
||||
testcase_name='Hashed100BucketsCustomKey',
|
||||
inputs=[
|
||||
ragged_const([['batch1-FC1-F1']]),
|
||||
ragged_const([['batch1-FC2-F1']]),
|
||||
ragged_const([['batch1-FC3-F1']])
|
||||
],
|
||||
num_buckets=100,
|
||||
hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
|
||||
expected_hashed=ragged_const([[31]])),
|
||||
dict(
|
||||
testcase_name='HashedZeroKey',
|
||||
inputs=[
|
||||
ragged_const([['batch1-FC1-F1']]),
|
||||
ragged_const([['batch1-FC2-F1']]),
|
||||
ragged_const([['batch1-FC3-F1']])
|
||||
],
|
||||
hash_key=0,
|
||||
expected_hashed=ragged_const([[9077905385164735582]]),
|
||||
matches_sparse_cross=False # sparse treats hash_key=0 as None.
|
||||
),
|
||||
dict(
|
||||
testcase_name='UInt64',
|
||||
inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
|
||||
expected=ragged_const([[b'-1']])),
|
||||
])
|
||||
def testRaggedCross(self,
|
||||
inputs,
|
||||
num_buckets=0,
|
||||
hash_key=None,
|
||||
expected=None,
|
||||
expected_hashed=None,
|
||||
matches_sparse_cross=True):
|
||||
ragged_cross = ragged_array_ops.cross(inputs)
|
||||
ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
|
||||
hash_key)
|
||||
|
||||
if expected is not None:
|
||||
self.assertAllEqual(ragged_cross, expected)
|
||||
if expected_hashed is not None:
|
||||
self.assertAllEqual(ragged_cross_hashed, expected_hashed)
|
||||
|
||||
if matches_sparse_cross:
|
||||
# Check that ragged.cross & sparse.cross match.
|
||||
sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
|
||||
sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
|
||||
self.assertAllEqual(ragged_cross,
|
||||
ragged_tensor.RaggedTensor.from_sparse(sparse_cross))
|
||||
|
||||
# Check that ragged.cross_hashed & sparse.cross_hashed match.
|
||||
sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
|
||||
sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
|
||||
sparse_inputs, num_buckets, hash_key)
|
||||
self.assertAllEqual(
|
||||
ragged_cross_hashed,
|
||||
ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))
|
||||
|
||||
def testRaggedCrossLargeBatch(self):
|
||||
batch_size = 5000
|
||||
inputs = [
|
||||
ragged_const([[1, 2, 3]] * batch_size),
|
||||
ragged_const([[b'4']] * batch_size),
|
||||
dense_const([[5]] * batch_size),
|
||||
sparse_const([[6, 7]] * batch_size)
|
||||
]
|
||||
|
||||
expected = [[
|
||||
b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
|
||||
b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
|
||||
]] * batch_size
|
||||
|
||||
ragged_cross = ragged_array_ops.cross(inputs)
|
||||
|
||||
# Note: we don't use assertAllEqual here because if they don't match,
|
||||
# then the code in assertAllEqual that tries to build the error message
|
||||
# is very slow, causing the test to timeout.
|
||||
# pylint: disable=g-generic-assert
|
||||
self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name='BadDType',
|
||||
inputs=[ragged_const([[1.1], [2.2, 3.3]])],
|
||||
message=r'Unexpected dtype for inputs\[0\]'),
|
||||
dict(
|
||||
testcase_name='StaticBatchSizeMismatch1',
|
||||
inputs=[ragged_const([[1]]),
|
||||
ragged_const([[2], [3]])],
|
||||
exception=(ValueError, errors.InvalidArgumentError),
|
||||
message='inputs must all have the same batch dimension size'),
|
||||
dict(
|
||||
testcase_name='StaticBatchSizeMismatch2',
|
||||
inputs=[ragged_const([[1]]),
|
||||
dense_const([[2], [3]])],
|
||||
exception=(ValueError, errors.InvalidArgumentError),
|
||||
message='inputs must all have the same batch dimension size'),
|
||||
])
|
||||
def testStaticError(self, inputs, exception=ValueError, message=None):
|
||||
with self.assertRaisesRegexp(exception, message):
|
||||
ragged_array_ops.cross(inputs)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name='3DRaggedTensor',
|
||||
inputs=[ragged_const([[[1]]], ragged_rank=1)],
|
||||
message='tf.ragged.cross only supports inputs with rank=2'),
|
||||
dict(
|
||||
testcase_name='3DDenseTensor',
|
||||
inputs=[dense_const([[[1]]])],
|
||||
message='tf.ragged.cross only supports inputs with rank=2'),
|
||||
])
|
||||
def testRuntimeError(self,
|
||||
inputs,
|
||||
exception=errors.InvalidArgumentError,
|
||||
message=None):
|
||||
with self.assertRaisesRegexp(exception, message):
|
||||
self.evaluate(ragged_array_ops.cross(inputs))
|
||||
|
||||
def _ragged_to_sparse(self, t):
|
||||
if ragged_tensor.is_ragged(t):
|
||||
return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
|
||||
elif sparse_tensor.is_sparse(t):
|
||||
return sparse_tensor.SparseTensor.from_value(t)
|
||||
else:
|
||||
return ops.convert_to_tensor(t)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
@ -16,6 +16,14 @@ tf_module {
|
||||
name: "constant_value"
|
||||
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'int64\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cross"
|
||||
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cross_hashed"
|
||||
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_flat_values"
|
||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
|
@ -3020,6 +3020,10 @@ tf_module {
|
||||
name: "RGBToHSV"
|
||||
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedCross"
|
||||
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedGather"
|
||||
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -8,6 +8,14 @@ tf_module {
|
||||
name: "constant"
|
||||
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "cross"
|
||||
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cross_hashed"
|
||||
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_flat_values"
|
||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
|
@ -3020,6 +3020,10 @@ tf_module {
|
||||
name: "RGBToHSV"
|
||||
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedCross"
|
||||
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedGather"
|
||||
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user