Add tf.ragged.cross and tf.ragged.cross_hashed operations. These ops generate feature crosses from a list of input tensors (similar to tf.sparse.cross), and return the result as a RaggedTensor. Inputs may be tf.RaggedTensors, tf.Tensors, or tf.SparseTensors.

PiperOrigin-RevId: 298495411
Change-Id: I2edde721cfc4e236d78ece470ff30cc68cf6b70c
This commit is contained in:
Edward Loper 2020-03-02 18:10:08 -08:00 committed by TensorFlower Gardener
parent 81e2ecdaae
commit 72e7964b6a
11 changed files with 1327 additions and 18 deletions

View File

@ -0,0 +1,49 @@
op {
graph_op_name: "RaggedCross"
visibility: HIDDEN
in_arg {
name: "ragged_values"
description: "The values tensor for each RaggedTensor input."
}
in_arg {
name: "ragged_row_splits"
description: "The row_splits tensor for each RaggedTensor input."
}
in_arg {
name: "sparse_indices"
description: "The indices tensor for each SparseTensor input."
}
in_arg {
name: "sparse_values"
description: "The values tensor for each SparseTensor input."
}
in_arg {
name: "sparse_shape"
description: "The dense_shape tensor for each SparseTensor input."
}
in_arg {
name: "dense_inputs"
description: "The tf.Tensor inputs."
}
out_arg {
name: "output_values"
description: "The `values` for the returned `RaggedTensor`."
}
out_arg {
name: "output_row_splits"
description: "The `row_splits` for the returned `RaggedTensor`."
}
attr {
name: "input_order"
description: <<END
String specifying the tensor type for each input. The `i`th character in
this string specifies the type of the `i`th input, and is one of: 'R' (ragged),
'D' (dense), or 'S' (sparse). This attr is used to ensure that the crossed
values are combined in the order of the inputs from the call to tf.ragged.cross.
END
}
summary: <<END
Generates a feature cross from a list of tensors, and returns it as a
RaggedTensor. See `tf.ragged.cross` for more details.
END
}

View File

@ -1394,6 +1394,7 @@ tf_kernel_library(
cc_library(
name = "ragged_ops",
deps = [
":ragged_cross_op",
":ragged_gather_op",
":ragged_range_op",
":ragged_tensor_from_variant_op",
@ -1548,6 +1549,15 @@ tf_cc_test(
],
)
tf_kernel_library(
name = "ragged_cross_op",
srcs = ["ragged_cross_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library(
name = "rnn_ops",
deps = [

View File

@ -0,0 +1,609 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/util.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace {
//==============================================================================
// Feature Readers
//==============================================================================
// A `FeatureReader` is used to read the feature values from a single input
// tensor. Subclasses are used for reading different tensor types:
// * RaggedFeatureReader<value_type, splits_type>
// * SparseFeatureReader<value_type>
// * DenseFeatureReader<value_type>
//
// Where value_type is one of: {tstring, int64}; and SplitsType is one of:
// {int32, int64}.
class FeatureReader {
public:
// Returns the number of feature values in the specified batch.
virtual int64 FeatureCount(int64 batch) const = 0;
// Copies the value for the specified feature to `out`.
virtual void ReadValue(int64 batch, int64 n, uint64* out) const = 0;
virtual void ReadValue(int64 batch, int64 n, tstring* out) const = 0;
virtual ~FeatureReader() {}
};
using FeatureReaders = std::vector<std::unique_ptr<FeatureReader>>;
// Copies a feature value `src` to a tstring `dst`, using a view if appropriate.
void CopyToString(const tstring& src, tstring* dst) {
if (src.type() == tstring::SMALL) {
*dst = src; // string buffer fits in the tstring object (under ~24 bytes)
} else {
dst->assign_as_view(src);
}
}
void CopyToString(int64 src, tstring* dst) { *dst = std::to_string(src); }
// Copies a feature value `src` to an int64 fingerprint `dst`.
void CopyToFingerprint(const tstring& feature, uint64* dst) {
*dst = Fingerprint64(feature);
}
void CopyToFingerprint(int64 feature, uint64* dst) { *dst = feature; }
// A FeatureReader that is backed by a ragged tensor.
template <typename ValuesType, typename SplitsType>
class RaggedFeatureReader : public FeatureReader {
public:
RaggedFeatureReader(const Tensor& values, const Tensor& row_splits)
: values_(values.flat<ValuesType>()),
row_splits_(row_splits.flat<SplitsType>()) {}
int64 FeatureCount(int64 batch) const override {
return row_splits_(batch + 1) - row_splits_(batch);
}
void ReadValue(int64 batch, int64 n, uint64* out) const override {
CopyToFingerprint(values_(row_splits_(batch) + n), out);
}
void ReadValue(int64 batch, int64 n, tstring* out) const override {
CopyToString(values_(row_splits_(batch) + n), out);
}
private:
const typename TTypes<ValuesType>::ConstFlat values_;
const typename TTypes<SplitsType>::ConstFlat row_splits_;
};
// A FeatureReader that is backed by a dense tensor.
template <typename ValuesType>
class DenseFeatureReader : public FeatureReader {
public:
explicit DenseFeatureReader(const Tensor& tensor)
: values_(tensor.matrix<ValuesType>()),
feature_count_(tensor.dim_size(1)) {}
int64 FeatureCount(int64 batch) const override { return feature_count_; }
void ReadValue(int64 batch, int64 n, uint64* out) const override {
CopyToFingerprint(values_(batch, n), out);
}
void ReadValue(int64 batch, int64 n, tstring* out) const override {
CopyToString(values_(batch, n), out);
}
private:
const typename TTypes<ValuesType>::ConstMatrix values_;
const int64 feature_count_;
};
// A FeatureReader that is backed by a sparse tensor.
template <typename ValuesType>
class SparseFeatureReader : public FeatureReader {
public:
SparseFeatureReader(const Tensor& indices_t, const Tensor& values_t,
int64 batch_size)
: values_(values_t.flat<ValuesType>()) {
row_splits_.reserve(batch_size + 1);
row_splits_.push_back(0);
auto indices = indices_t.matrix<int64>();
int64 num_values = values_.size();
int64 i = 0; // value index
for (int row = 0; row < batch_size; row++) {
while (i < num_values && indices(i, 0) <= row) ++i;
row_splits_.push_back(i);
}
}
int64 FeatureCount(int64 batch) const override {
return row_splits_[batch + 1] - row_splits_[batch];
}
void ReadValue(int64 batch, int64 n, uint64* out) const override {
CopyToFingerprint(values_(row_splits_[batch] + n), out);
}
void ReadValue(int64 batch, int64 n, tstring* out) const override {
CopyToString(values_(row_splits_[batch] + n), out);
}
private:
const typename TTypes<ValuesType>::ConstFlat values_;
std::vector<int64> row_splits_;
};
//==============================================================================
// Output Writers
//==============================================================================
// An `OutputWriter` is used to write the feature crosses to the output values
// tensor. Different subclasses are used for writing different output dtypes:
// * OutputWriterImpl<tstring, SplitsType> (for tf.ragged.cross)
// * OutputWriterImpl<int64, SplitsType> (for tf.ragged.cross_hashed)
class OutputWriter {
public:
virtual void WriteOutputSlice(int64 begin, int64 end) = 0;
virtual ~OutputWriter() {}
};
template <typename ValuesType, typename SplitsType>
class OutputWriterImpl : public OutputWriter {
public:
using FlatValues = typename TTypes<ValuesType>::Flat;
using FlatSplits = typename TTypes<SplitsType>::ConstFlat;
OutputWriterImpl(const FeatureReaders& features, int64 num_buckets,
uint64 hash_key, const Tensor* splits_out,
Tensor* values_out)
: features_(features),
num_buckets_(num_buckets),
hash_key_(hash_key),
splits_out_(splits_out->flat<SplitsType>()),
values_out_(values_out->flat<ValuesType>()) {}
// Reads features from the specified slice of batch indices, computes
// feature crosses for each one, and writes them to values_out_.
void WriteOutputSlice(int64 begin, int64 end) override {
std::vector<int> combination(features_.size(), 0);
for (int64 b = begin; b < end; ++b) {
auto row_start = splits_out_(b);
auto row_limit = splits_out_(b + 1);
for (auto i = row_start; i < row_limit; ++i) {
WriteCombination(b, combination, &values_out_(i));
NextCombination(b, &combination);
}
combination.assign(features_.size(), 0); // reset for next batch.
}
}
private:
// Joins the specified combination of input features into a single string,
// and writes it to *out.
void WriteCombination(int64 batch_index, const std::vector<int>& combination,
tstring* out) {
static const auto k_feature_separator = "_X_";
gtl::InlinedVector<tstring, 6> cross_vec(features_.size());
for (int i = 0; i < combination.size(); ++i) {
features_[i]->ReadValue(batch_index, combination[i], &cross_vec[i]);
}
*out = absl::StrJoin(cross_vec, k_feature_separator);
}
// Joins the specified combination of input features into a single
// fingerprint, and writes it to *out.
void WriteCombination(int64 batch_index, const std::vector<int>& combination,
int64* out) {
// Do the fingerprint concatenation on uint64.
uint64 hashed_output = hash_key_;
for (size_t i = 0; i < combination.size(); ++i) {
uint64 hash_i;
features_[i]->ReadValue(batch_index, combination[i], &hash_i);
hashed_output = FingerprintCat64(hashed_output, hash_i);
}
// The return value is int64 based on the number of buckets.
if (num_buckets_ > 0) {
*out = hashed_output % num_buckets_;
} else {
// To prevent negative output we take modulo to max int64.
*out = hashed_output % std::numeric_limits<int64>::max();
}
}
// Updates `combination` to the next combination of input features.
void NextCombination(int64 batch_index, std::vector<int>* combination) const {
bool carry = true;
for (int i = combination->size() - 1; i >= 0; i--) {
if (carry) {
(*combination)[i] = (*combination)[i] + 1;
}
if ((*combination)[i] == features_[i]->FeatureCount(batch_index)) {
(*combination)[i] = 0;
} else {
carry = false;
break;
}
}
}
const FeatureReaders& features_;
const int64 num_buckets_;
const uint64 hash_key_;
FlatSplits splits_out_;
FlatValues values_out_;
};
// Returns an appropriate OutputWriter, based on the dtypes of the
// given tensors.
std::unique_ptr<OutputWriter> MakeOutputWriter(const FeatureReaders& features,
int64 num_buckets,
uint64 hash_key,
const Tensor* splits_out,
Tensor* values_out) {
if (values_out->dtype() == DT_INT64) {
if (splits_out->dtype() == DT_INT64) {
return std::make_unique<OutputWriterImpl<int64, int64>>(
features, num_buckets, hash_key, splits_out, values_out);
} else {
return std::make_unique<OutputWriterImpl<int64, int32>>(
features, num_buckets, hash_key, splits_out, values_out);
}
} else {
if (splits_out->dtype() == DT_INT64) {
return std::make_unique<OutputWriterImpl<tstring, int64>>(
features, num_buckets, hash_key, splits_out, values_out);
} else {
return std::make_unique<OutputWriterImpl<tstring, int32>>(
features, num_buckets, hash_key, splits_out, values_out);
}
}
}
//==============================================================================
// RaggedCross Kernel
//==============================================================================
template <typename SplitsType>
class RaggedCrossOp : public OpKernel {
public:
explicit RaggedCrossOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
// Read signed_hash_key_ as int64 since uint64 attributes are not
// supported by REGISTER_OP.
int64 signed_hash_key_;
OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
hash_key_ = static_cast<uint64>(signed_hash_key_);
int num_sparse;
OP_REQUIRES_OK(context, context->GetAttr("Nsparse", &num_sparse));
OP_REQUIRES_OK(context, context->GetAttr("ragged_values_types",
&ragged_values_types_));
OP_REQUIRES_OK(context, context->GetAttr("ragged_splits_types",
&ragged_splits_types_));
OP_REQUIRES_OK(context, context->GetAttr("sparse_values_types",
&sparse_values_types_));
OP_REQUIRES_OK(context, context->GetAttr("dense_types", &dense_types_));
OP_REQUIRES_OK(context, context->GetAttr("input_order", &input_order_));
OP_REQUIRES(context,
ragged_values_types_.size() == ragged_splits_types_.size(),
errors::InvalidArgument(
"ragged values and splits must have the same length"));
OP_REQUIRES(context, num_sparse == sparse_values_types_.size(),
errors::InvalidArgument(
"sparse indices and values must have the same length"));
OP_REQUIRES(context,
ragged_values_types_.size() + sparse_values_types_.size() +
dense_types_.size() ==
input_order_.size(),
errors::InvalidArgument("Invalid length for input_order"));
}
void Compute(OpKernelContext* context) override {
OpInputList ragged_values_list;
OpInputList ragged_splits_list;
OpInputList sparse_indices_list;
OpInputList sparse_values_list;
OpInputList sparse_shape_list;
OpInputList dense_list;
OP_REQUIRES_OK(context,
context->input_list("ragged_values", &ragged_values_list));
OP_REQUIRES_OK(
context, context->input_list("ragged_row_splits", &ragged_splits_list));
OP_REQUIRES_OK(context,
context->input_list("sparse_indices", &sparse_indices_list));
OP_REQUIRES_OK(context,
context->input_list("sparse_values", &sparse_values_list));
OP_REQUIRES_OK(context,
context->input_list("sparse_shape", &sparse_shape_list));
OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list));
OP_REQUIRES_OK(context,
ValidateInput(ragged_values_list, ragged_splits_list,
sparse_indices_list, sparse_values_list,
sparse_shape_list, dense_list));
int64 batch_size =
CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
FeatureReaders features;
OP_REQUIRES_OK(context,
BuildFeatureReaders(ragged_values_list, ragged_splits_list,
sparse_indices_list, sparse_values_list,
dense_list, batch_size, &features));
Tensor* values_out;
Tensor* row_splits_out;
OP_REQUIRES_OK(context, BuildOutputTensors(features, batch_size, context,
&values_out, &row_splits_out));
std::unique_ptr<OutputWriter> output_writer = MakeOutputWriter(
features, num_buckets_, hash_key_, row_splits_out, values_out);
auto do_work = [&output_writer](int64 begin, int64 end) {
output_writer->WriteOutputSlice(begin, end);
};
// TODO(edloper): optimize cost_per_batch
const int cost_per_batch = 5000 * ragged_values_list.size();
auto thread_pool =
context->device()->tensorflow_cpu_worker_threads()->workers;
thread_pool->ParallelFor(batch_size, cost_per_batch, do_work);
}
private:
// Validates input tensors.
Status ValidateInput(const OpInputList& ragged_values_list,
const OpInputList& ragged_splits_list,
const OpInputList& sparse_indices_list,
const OpInputList& sparse_values_list,
const OpInputList& sparse_shape_list,
const OpInputList& dense_list) {
const auto num_ragged = ragged_values_list.size();
const auto num_sparse = sparse_indices_list.size();
// Validate tensor shapes.
for (int i = 0; i < num_ragged; ++i) {
if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) {
return errors::InvalidArgument(
"tf.ragged.cross only supports inputs with rank=2.");
}
if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) ||
(ragged_splits_list[i].NumElements() == 0)) {
return errors::InvalidArgument("Invalid RaggedTensor");
}
}
for (int i = 0; i < num_sparse; ++i) {
if (!TensorShapeUtils::IsMatrix(sparse_indices_list[i].shape()) ||
!TensorShapeUtils::IsVector(sparse_values_list[i].shape()) ||
!TensorShapeUtils::IsVector(sparse_shape_list[i].shape())) {
return errors::InvalidArgument("Invalid SparseTensor ", i);
}
if (sparse_shape_list[i].NumElements() != 2) {
return errors::InvalidArgument(
"tf.ragged.cross only supports inputs with rank=2.");
}
}
for (int i = 0; i < dense_list.size(); ++i) {
if (!TensorShapeUtils::IsMatrix(dense_list[i].shape())) {
return errors::InvalidArgument(
"tf.ragged.cross only supports inputs with rank=2.");
}
}
// Check that batch sizes are consistent.
int64 batch_size =
CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
for (int i = 0; i < num_ragged; ++i) {
if (ragged_splits_list[i].NumElements() - 1 != batch_size) {
return errors::InvalidArgument(
"inputs must all have the same batch dimension size.");
}
}
for (int i = 0; i < num_sparse; ++i) {
if (sparse_shape_list[i].flat<int64>()(0) != batch_size) {
return errors::InvalidArgument(
"inputs must all have the same batch dimension size.");
}
}
for (int i = 0; i < dense_list.size(); ++i) {
if (dense_list[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"inputs must all have the same batch dimension size.");
}
}
return Status::OK();
}
// Calculate the batch size from any input tensor. (We check that all input
// tensors have the same batch size in `ValidateInput`).
int64 CalculateBatchSize(const OpInputList& ragged_splits_list,
const OpInputList& sparse_shape_list,
const OpInputList& dense_list) {
if (ragged_splits_list.size() > 0) {
return ragged_splits_list[0].NumElements() - 1;
} else if (dense_list.size() > 0) {
return dense_list[0].dim_size(0);
} else if (sparse_shape_list.size() > 0) {
return sparse_shape_list[0].flat<int64>()(0);
} else {
return 0;
}
}
// Build a feature reader for each input tensor, and store them in `features`.
Status BuildFeatureReaders(const OpInputList& ragged_values_list,
const OpInputList& ragged_splits_list,
const OpInputList& sparse_indices_list,
const OpInputList& sparse_values_list,
const OpInputList& dense_list, int64 batch_size,
FeatureReaders* features) {
features->reserve(input_order_.size());
int next_ragged = 0;
int next_sparse = 0;
int next_dense = 0;
for (char c : input_order_) {
if (c == 'R') {
TF_RETURN_IF_ERROR(BuildRaggedFeatureReader(
ragged_values_list[next_ragged], ragged_splits_list[next_ragged],
features));
next_ragged++;
} else if (c == 'S') {
TF_RETURN_IF_ERROR(BuildSparseFeatureReader(
sparse_indices_list[next_sparse], sparse_values_list[next_sparse],
batch_size, features));
next_sparse++;
} else if (c == 'D') {
TF_RETURN_IF_ERROR(
BuildDenseFeatureReader(dense_list[next_dense++], features));
} else {
return errors::InvalidArgument("Unexpected input_order value.");
}
}
return Status::OK();
}
// Builds a RaggedReatureReader
static Status BuildRaggedFeatureReader(const Tensor& values,
const Tensor& splits,
FeatureReaders* features) {
if (values.dtype() != DT_INT64 && values.dtype() != DT_STRING) {
return errors::InvalidArgument("Unexpected dtype for input ",
(features->size() + 1), ": ",
values.dtype());
}
if (splits.dtype() != DT_INT64 && splits.dtype() != DT_INT32) {
return errors::InvalidArgument("Unexpected row_splits.dtype for input ",
(features->size() + 1), ": ",
values.dtype());
}
if (values.dtype() == DT_INT64) {
if (splits.dtype() == DT_INT64) {
features->emplace_back(
new RaggedFeatureReader<int64, int64>(values, splits));
} else {
features->emplace_back(
new RaggedFeatureReader<int64, int32>(values, splits));
}
} else {
if (splits.dtype() == DT_INT64) {
features->emplace_back(
new RaggedFeatureReader<tstring, int64>(values, splits));
} else {
features->emplace_back(
new RaggedFeatureReader<tstring, int32>(values, splits));
}
}
return Status::OK();
}
// Builds a DenseFaggedReatureReader.
static Status BuildDenseFeatureReader(const Tensor& values,
FeatureReaders* features) {
if (values.dtype() == DT_INT64) {
features->emplace_back(new DenseFeatureReader<int64>(values));
} else if (values.dtype() == DT_STRING) {
features->emplace_back(new DenseFeatureReader<tstring>(values));
} else {
return errors::InvalidArgument("Unexpected dtype for input ",
(features->size() + 1), ": ",
values.dtype());
}
return Status::OK();
}
// Builds a SparseFaggedReatureReader.
static Status BuildSparseFeatureReader(const Tensor& indices,
const Tensor& values, int64 batch_size,
FeatureReaders* features) {
if (values.dtype() == DT_INT64) {
features->emplace_back(
new SparseFeatureReader<int64>(indices, values, batch_size));
} else if (values.dtype() == DT_STRING) {
features->emplace_back(
new SparseFeatureReader<tstring>(indices, values, batch_size));
} else {
return errors::InvalidArgument("Unexpected dtype for input ",
(features->size() + 1), ": ",
values.dtype());
}
return Status::OK();
}
// Allocates output tensors with proper size, and populates row_splits_out.
Status BuildOutputTensors(const FeatureReaders& features, int64 batch_size,
OpKernelContext* context, Tensor** values_out,
Tensor** row_splits_out) {
// Allocate and populate the row_splits output tensor.
TF_RETURN_IF_ERROR(context->allocate_output(
1, TensorShape({batch_size + 1}), row_splits_out));
auto flat_row_splits = (*row_splits_out)->flat<SplitsType>();
int64 cross_count_total = 0;
flat_row_splits(0) = 0;
for (int64 b = 0; b < batch_size; b++) {
cross_count_total += CrossCountByBatchIndex(features, b);
flat_row_splits(b + 1) = cross_count_total;
}
// Allocate the values output tensor.
TF_RETURN_IF_ERROR(context->allocate_output(
0, TensorShape({cross_count_total}), values_out));
return Status::OK();
}
// Returns number of crosses for a given batch_index
int64 CrossCountByBatchIndex(const FeatureReaders& features,
int batch_index) {
int64 cross_count = 1;
for (int i = 0; i < features.size(); ++i) {
const auto feature_count = features[i]->FeatureCount(batch_index);
if (feature_count == 0) return 0;
cross_count *= feature_count;
}
return cross_count;
}
int64 num_buckets_;
uint64 hash_key_;
std::vector<DataType> ragged_values_types_;
std::vector<DataType> ragged_splits_types_;
std::vector<DataType> sparse_values_types_;
std::vector<DataType> dense_types_;
tstring input_order_;
};
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("out_row_splits_type"),
RaggedCrossOp<int32>);
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
.Device(DEVICE_CPU)
.TypeConstraint<int64>("out_row_splits_type"),
RaggedCrossOp<int64>);
} // namespace
} // namespace tensorflow

View File

@ -41,6 +41,79 @@ REGISTER_OP("RaggedGather")
.Attr("OUTPUT_RAGGED_RANK: int >= 0")
.SetShapeFn(RaggedGatherShapeFn);
REGISTER_OP("RaggedCross")
.Input("ragged_values: ragged_values_types")
.Input("ragged_row_splits: ragged_splits_types")
.Input("sparse_indices: Nsparse * int64")
.Input("sparse_values: sparse_values_types")
.Input("sparse_shape: Nsparse * int64")
.Input("dense_inputs: dense_types")
.Output("output_values: out_values_type")
.Output("output_row_splits: out_row_splits_type")
.Attr("Nsparse: int >= 0")
.Attr("input_order: string")
.Attr("hashed_output: bool")
.Attr("num_buckets: int >= 0")
.Attr("hash_key: int")
.Attr("ragged_values_types: list({int64, string}) >= 0")
.Attr("ragged_splits_types: list({int32, int64}) >= 0")
.Attr("sparse_values_types: list({int64, string}) >= 0")
.Attr("dense_types: list({int64, string}) >= 0")
.Attr("out_values_type: {int64, string}")
.Attr("out_row_splits_type: {int32, int64}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<DataType> ragged_values_types;
std::vector<DataType> ragged_splits_types;
std::vector<DataType> dense_types;
TF_RETURN_IF_ERROR(
c->GetAttr("ragged_values_types", &ragged_values_types));
TF_RETURN_IF_ERROR(
c->GetAttr("ragged_splits_types", &ragged_splits_types));
TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
int num_ragged = ragged_values_types.size();
if (num_ragged != ragged_splits_types.size()) {
return errors::InvalidArgument(
"Parameters `values` and `row_splits` must be the same length");
}
int num_sparse;
TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
ShapeHandle out_values = c->UnknownShapeOfRank(1);
ShapeHandle out_splits = c->UnknownShapeOfRank(1);
// Merge the shapes of row_splits from ragged inputs. (This is one plus
// the batch size.)
int ragged_splits_start = num_ragged;
for (int i = 0; i < ragged_splits_types.size(); ++i) {
ShapeHandle row_splits = c->input(i + ragged_splits_start);
if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
return errors::InvalidArgument(
"inputs must all have the same batch dimension size.");
}
}
// Merge the batch size of each dense input into out_splits.
int dense_start = num_ragged * 2 + num_sparse * 3;
for (int i = 0; i < dense_types.size(); ++i) {
ShapeHandle dense_input = c->input(i + dense_start);
int64 batch_size = c->Value(c->Dim(dense_input, 0));
if (batch_size != InferenceContext::kUnknownDim) {
ShapeHandle row_splits = c->Vector(batch_size + 1);
if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
return errors::InvalidArgument(
"inputs must all have the same batch dimension size.");
}
}
}
c->set_output(0, out_values);
c->set_output(1, out_splits);
return Status::OK();
});
//==============================================================================
// Shape Functions
//==============================================================================

View File

@ -1123,3 +1123,17 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "ragged_cross_op_test",
srcs = ["ragged_cross_op_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -20,9 +20,11 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
@ -32,7 +34,6 @@ from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util.tf_export import tf_export
#===============================================================================
# Masking
#===============================================================================
@ -107,7 +108,8 @@ def boolean_mask(data, mask, name=None):
if ragged_tensor.is_ragged(mask):
if not ragged_tensor.is_ragged(data):
data = ragged_tensor.RaggedTensor.from_tensor(
data, ragged_rank=mask.ragged_rank,
data,
ragged_rank=mask.ragged_rank,
row_splits_dtype=mask.row_splits.dtype)
# Check that mask.nested_row_splits is a prefix of
# data.nested_row_splits.
@ -160,15 +162,15 @@ def boolean_mask(data, mask, name=None):
segment_mask = array_ops.gather(mask, segment_ids)
masked_values = boolean_mask(data.values, segment_mask)
return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
masked_splits,
validate=False)
return ragged_tensor.RaggedTensor.from_row_splits(
masked_values, masked_splits, validate=False)
# If mask is non-ragged and has rank>1, then convert it to be ragged,
# with a ragged rank matching data.
if ragged_tensor.is_ragged(data):
mask = ragged_tensor.RaggedTensor.from_tensor(
mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
mask,
ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
row_splits_dtype=data.row_splits.dtype)
return boolean_mask(data, mask)
@ -182,8 +184,8 @@ def boolean_mask(data, mask, name=None):
# number of values it contains. Then flatten that to get a list of
# cell lengths, and convert it to splits. Finally, combine the splits
# and values to get the innermost ragged tensor.
masked_lengths = math_ops.count_nonzero(mask, axis=-1,
dtype=row_splits_dtype)
masked_lengths = math_ops.count_nonzero(
mask, axis=-1, dtype=row_splits_dtype)
flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
masked_values, flattened_masked_lengths, validate=False)
@ -340,8 +342,7 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
for src_axis in range(ragged_rank):
for dst_axis in range(src_axis + 1, ragged_rank - 1):
projected_splits[src_axis][dst_axis] = array_ops.gather(
nested_splits[dst_axis],
projected_splits[src_axis][dst_axis - 1])
nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1])
# For each ragged dimension: nested_splits[axis] -> result_splits[axis].
result_splits = []
@ -400,8 +401,7 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
Args:
input: The potentially tensor that should be expanded with a new
dimension.
input: The potentially tensor that should be expanded with a new dimension.
axis: An integer constant indicating where the new dimension should be
inserted.
name: A name for the operation (optional).
@ -556,10 +556,10 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
Args:
data: A `Tensor` or `RaggedTensor` containing the values to stack.
partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
partition that each slice of `data` should be added to.
`partitions.shape` must be a prefix of `data.shape`. Values must be
greater than or equal to zero, and less than `num_partitions`.
`partitions` is not required to be sorted.
partition that each slice of `data` should be added to. `partitions.shape`
must be a prefix of `data.shape`. Values must be greater than or equal to
zero, and less than `num_partitions`. `partitions` is not required to be
sorted.
num_partitions: An `int32` or `int64` scalar specifying the number of
partitions to output. This determines the number of rows in `output`.
name: A name prefix for the returned tensor (optional).
@ -650,8 +650,8 @@ def reverse(tensor, axis, name=None):
Args:
tensor: A 'RaggedTensor' to reverse.
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
of the axes to reverse.
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of
the axes to reverse.
name: A name prefix for the returned tensor (optional).
Returns:
@ -687,3 +687,137 @@ def reverse(tensor, axis, name=None):
slices[dim] = slice(None, None, -1)
return tensor[tuple(slices)]
#===============================================================================
# Cross
#===============================================================================
@tf_export('ragged.cross')
def cross(inputs, name=None):
"""Generates feature cross from a list of tensors.
The input tensors must have `rank=2`, and must all have the same number of
rows. The result is a `RaggedTensor` with the same number of rows as the
inputs, where `result[row]` contains a list of all combinations of values
formed by taking a single value from each input's corresponding row
(`inputs[i][row]`). Values are combined by joining their strings with '_X_'.
E.g.:
>>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
... tf.ragged.constant([['d'], ['e']]),
... tf.ragged.constant([['f'], ['g']])])
<tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>
Args:
inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
name: Optional name for the op.
Returns:
A 2D `RaggedTensor` of type `string`.
"""
return _cross_internal(inputs=inputs, hashed_output=False, name=name)
@tf_export('ragged.cross_hashed')
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
"""Generates hashed feature cross from a list of tensors.
The input tensors must have `rank=2`, and must all have the same number of
rows. The result is a `RaggedTensor` with the same number of rows as the
inputs, where `result[row]` contains a list of all combinations of values
formed by taking a single value from each input's corresponding row
(`inputs[i][row]`). Values are combined by hashing together their
fingerprints. E.g.:
>>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
... tf.ragged.constant([['d'], ['e']]),
... tf.ragged.constant([['f'], ['g']])],
... num_buckets=100)
<tf.RaggedTensor [[78], [66, 74]]>
Args:
inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
num_buckets: A non-negative `int` that used to bucket the hashed values. If
`num_buckets != 0`, then `output = hashed_value % num_buckets`.
hash_key: Integer hash_key that will be used by the `FingerprintCat64`
function. If not given, a default key is used.
name: Optional name for the op.
Returns:
A 2D `RaggedTensor` of type `int64`.
"""
return _cross_internal(
inputs=inputs,
hashed_output=True,
num_buckets=num_buckets,
hash_key=hash_key,
name=name)
_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE
def _cross_internal(inputs,
hashed_output=False,
num_buckets=0,
hash_key=None,
name=None):
"""Generates feature cross from a list of ragged and dense tensors."""
if not isinstance(inputs, (tuple, list)):
raise TypeError('Inputs must be a list')
if hash_key is None:
hash_key = _DEFAULT_CROSS_HASH_KEY
ragged_inputs = []
sparse_inputs = []
dense_inputs = []
input_order = []
with ops.name_scope(name, 'RaggedCross', inputs):
for i, t in enumerate(inputs):
if sparse_tensor.is_sparse(t):
t = sparse_tensor.SparseTensor.from_value(t)
else:
t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
if t.dtype.is_integer:
t = math_ops.cast(t, dtypes.int64)
elif t.dtype != dtypes.string:
raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
if isinstance(t, ragged_tensor.RaggedTensor):
if t.ragged_rank != 1:
raise ValueError('tf.ragged.cross only supports inputs with rank=2')
ragged_inputs.append(t)
input_order.append('R')
elif isinstance(t, sparse_tensor.SparseTensor):
sparse_inputs.append(t)
input_order.append('S')
else:
dense_inputs.append(t)
input_order.append('D')
out_values_type = dtypes.int64 if hashed_output else dtypes.string
if ragged_inputs and all(
t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
out_row_splits_type = dtypes.int32
else:
out_row_splits_type = dtypes.int64
values_out, splits_out = gen_ragged_array_ops.ragged_cross(
ragged_values=[rt.values for rt in ragged_inputs],
ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
sparse_indices=[st.indices for st in sparse_inputs],
sparse_values=[st.values for st in sparse_inputs],
sparse_shape=[st.dense_shape for st in sparse_inputs],
dense_inputs=dense_inputs,
input_order=''.join(input_order),
hashed_output=hashed_output,
num_buckets=num_buckets,
hash_key=hash_key,
out_values_type=out_values_type,
out_row_splits_type=out_row_splits_type,
name=name)
return ragged_tensor.RaggedTensor.from_row_splits(
values_out, splits_out, validate=False)

View File

@ -0,0 +1,396 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.ragged.cross and tf.ragged.cross_hashed."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
ragged_const = ragged_factory_ops.constant_value
dense_const = np.array
def sparse_const(matrix):
indices = []
values = []
for i, row in enumerate(matrix):
for j, val in enumerate(row):
indices.append([i, j])
values.append(val)
shape = [len(matrix), max(len(row) for row in matrix)] if matrix else [0, 0]
if not values:
indices = np.zeros([0, 2], dtype=np.int64)
values = np.zeros([0], dtype=np.int64)
return sparse_tensor.SparseTensorValue(indices, values, shape)
class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters([
dict(
testcase_name='NoInputs',
inputs=[],
expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
dict(
testcase_name='OneInput_RaggedStr',
inputs=[ragged_const([['a', 'b'], [], ['c']])],
expected=ragged_const([[b'a', b'b'], [], [b'c']])),
dict(
testcase_name='OneInput_RaggedInt',
inputs=[ragged_const([[1, 2, 3], [4, 5]])],
expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
dict(
testcase_name='OneInput_DenseInt',
inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
dict(
testcase_name='OneInput_SparseStr',
inputs=[sparse_const([['a', 'b'], [], ['c']])],
expected=ragged_const([[b'a', b'b'], [], [b'c']])),
dict(
testcase_name='TwoInputs_RaggedStr_RaggedStr',
inputs=[
ragged_const([['a', 'b'], [], ['c']]),
ragged_const([['d', 'e'], ['f'], ['g']])
],
expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
[b'c_X_g']])),
dict(
testcase_name='TwoInputs_RaggedInt_RaggedInt',
inputs=[
ragged_const([[1, 2], [], [3]]),
ragged_const([[4, 5, 6], [], [7]])
],
expected=ragged_const(
[[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
[b'3_X_7']])),
dict(
testcase_name='TwoInputs_RaggedStr_RaggedInt',
inputs=[
ragged_const([['a', 'b'], [], ['c']]),
ragged_const([['1', '2'], ['3'], ['4']])
],
expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
[b'c_X_4']])),
dict(
testcase_name='TwoInputs_SparseStr_SparseStr',
inputs=[
sparse_const([['a', 'b'], [], ['c']]),
sparse_const([['d', 'e'], ['f'], ['g']])
],
expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
[b'c_X_g']])),
dict(
testcase_name='TwoInputs_DenseInt_DenseInt',
inputs=[dense_const([[1, 2], [3, 4]]),
dense_const([[5, 6], [7, 8]])],
expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
[b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
dict(
testcase_name='TwoInputs_DenseInt_DenseStr',
inputs=[
dense_const([[1, 2], [3, 4]]),
dense_const([[b'5', b'6'], [b'7', b'8']])
],
expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
[b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
dict(
testcase_name='TwoInputs_RaggedInt_DenseInt',
inputs=[
ragged_const([[], [], [1, 2], [3]]),
dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
],
expected=ragged_const([[], [],
[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
[b'3_X_7', b'3_X_8']])),
dict(
# This test exercises `input_order`.
testcase_name='TwoInputs_DenseInt_RaggedStr',
inputs=[
dense_const([[1, 2], [3, 4], [5, 6]]),
ragged_const([['d', 'e'], ['f'], ['g']])
],
expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
[b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
matches_sparse_cross=False # sparse doesn't preserve input order.
),
dict(
# This test exercises `input_order`.
testcase_name='TwoInputs_SparseInt_RaggedStr',
inputs=[
sparse_const([[1, 2], [3, 4], [5, 6]]),
ragged_const([['d', 'e'], ['f'], ['g']])
],
expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
[b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
matches_sparse_cross=False # sparse doesn't preserve input order.
),
dict(
testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
inputs=[
ragged_const([[11], [12, 13], [], [14, 15]]),
ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
ragged_const([[31], [32, 33], [34, 35], [36, 37]])
],
expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
[
b'12_X_23_X_32', b'12_X_23_X_33',
b'13_X_23_X_32', b'13_X_23_X_33'
], [],
[
b'14_X_26_X_36', b'14_X_26_X_37',
b'14_X_27_X_36', b'14_X_27_X_37',
b'15_X_26_X_36', b'15_X_26_X_37',
b'15_X_27_X_36', b'15_X_27_X_37'
]])),
dict(
testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
inputs=[
ragged_const([[11], [12, 13], [], [14, 15]]),
sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
dense_const([[31], [32], [33], [34]])
],
expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
[
b'12_X_23_X_32',
b'13_X_23_X_32',
], [],
[
b'14_X_26_X_34',
b'14_X_27_X_34',
b'15_X_26_X_34',
b'15_X_27_X_34',
]])),
dict(
testcase_name='FiveInputs',
inputs=[
ragged_const([[1]]),
dense_const([[2]]),
ragged_const([[3]]),
sparse_const([[4]]),
ragged_const([[5]])
],
expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
matches_sparse_cross=False # sparse doesn't preserve input order.
),
dict(
testcase_name='Permutation_3x3x3',
inputs=[[['11', '12', '13']], [['21', '22', '23']],
[['31', '32', '33']]],
expected=[[
b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
]]),
dict(
testcase_name='BatchSizeZero',
inputs=[
ragged_const([], ragged_rank=1, dtype=dtypes.int32),
sparse_const([]),
np.zeros([0, 3], dtype=np.int32),
],
expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
dict(
testcase_name='ThreeInputs_OneEmpty',
inputs=[
ragged_const([[1, 2]]),
ragged_const([[]], dtype=dtypes.int32),
ragged_const([[3, 4]])
],
expected=ragged_const([[]], dtype=dtypes.string)),
dict(
testcase_name='ThreeInputs_AllEmpty',
inputs=[
ragged_const([[]], dtype=dtypes.int64),
ragged_const([[]], dtype=dtypes.string),
ragged_const([[]], dtype=dtypes.int32)
],
expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
dict(
testcase_name='HashedZeroBucketsDefaultKey',
inputs=[
ragged_const([['batch1-FC1-F1']]),
ragged_const([['batch1-FC2-F1']]),
ragged_const([['batch1-FC3-F1']])
],
expected_hashed=ragged_const([[1971693436396284976]])),
dict(
testcase_name='Hashed100BucketsDefaultKey',
inputs=[
ragged_const([['batch1-FC1-F1']]),
ragged_const([['batch1-FC2-F1']]),
ragged_const([['batch1-FC3-F1']])
],
num_buckets=100,
expected_hashed=ragged_const([[83]])),
dict(
testcase_name='HashedZeroBucketsCustomKey',
inputs=[
ragged_const([['batch1-FC1-F1']]),
ragged_const([['batch1-FC2-F1']]),
ragged_const([['batch1-FC3-F1']])
],
hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
expected_hashed=ragged_const([[4847552627144134031]])),
dict(
testcase_name='Hashed100BucketsCustomKey',
inputs=[
ragged_const([['batch1-FC1-F1']]),
ragged_const([['batch1-FC2-F1']]),
ragged_const([['batch1-FC3-F1']])
],
num_buckets=100,
hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
expected_hashed=ragged_const([[31]])),
dict(
testcase_name='HashedZeroKey',
inputs=[
ragged_const([['batch1-FC1-F1']]),
ragged_const([['batch1-FC2-F1']]),
ragged_const([['batch1-FC3-F1']])
],
hash_key=0,
expected_hashed=ragged_const([[9077905385164735582]]),
matches_sparse_cross=False # sparse treats hash_key=0 as None.
),
dict(
testcase_name='UInt64',
inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
expected=ragged_const([[b'-1']])),
])
def testRaggedCross(self,
inputs,
num_buckets=0,
hash_key=None,
expected=None,
expected_hashed=None,
matches_sparse_cross=True):
ragged_cross = ragged_array_ops.cross(inputs)
ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
hash_key)
if expected is not None:
self.assertAllEqual(ragged_cross, expected)
if expected_hashed is not None:
self.assertAllEqual(ragged_cross_hashed, expected_hashed)
if matches_sparse_cross:
# Check that ragged.cross & sparse.cross match.
sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
self.assertAllEqual(ragged_cross,
ragged_tensor.RaggedTensor.from_sparse(sparse_cross))
# Check that ragged.cross_hashed & sparse.cross_hashed match.
sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
sparse_inputs, num_buckets, hash_key)
self.assertAllEqual(
ragged_cross_hashed,
ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))
def testRaggedCrossLargeBatch(self):
batch_size = 5000
inputs = [
ragged_const([[1, 2, 3]] * batch_size),
ragged_const([[b'4']] * batch_size),
dense_const([[5]] * batch_size),
sparse_const([[6, 7]] * batch_size)
]
expected = [[
b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
]] * batch_size
ragged_cross = ragged_array_ops.cross(inputs)
# Note: we don't use assertAllEqual here because if they don't match,
# then the code in assertAllEqual that tries to build the error message
# is very slow, causing the test to timeout.
# pylint: disable=g-generic-assert
self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)
@parameterized.named_parameters([
dict(
testcase_name='BadDType',
inputs=[ragged_const([[1.1], [2.2, 3.3]])],
message=r'Unexpected dtype for inputs\[0\]'),
dict(
testcase_name='StaticBatchSizeMismatch1',
inputs=[ragged_const([[1]]),
ragged_const([[2], [3]])],
exception=(ValueError, errors.InvalidArgumentError),
message='inputs must all have the same batch dimension size'),
dict(
testcase_name='StaticBatchSizeMismatch2',
inputs=[ragged_const([[1]]),
dense_const([[2], [3]])],
exception=(ValueError, errors.InvalidArgumentError),
message='inputs must all have the same batch dimension size'),
])
def testStaticError(self, inputs, exception=ValueError, message=None):
with self.assertRaisesRegexp(exception, message):
ragged_array_ops.cross(inputs)
@parameterized.named_parameters([
dict(
testcase_name='3DRaggedTensor',
inputs=[ragged_const([[[1]]], ragged_rank=1)],
message='tf.ragged.cross only supports inputs with rank=2'),
dict(
testcase_name='3DDenseTensor',
inputs=[dense_const([[[1]]])],
message='tf.ragged.cross only supports inputs with rank=2'),
])
def testRuntimeError(self,
inputs,
exception=errors.InvalidArgumentError,
message=None):
with self.assertRaisesRegexp(exception, message):
self.evaluate(ragged_array_ops.cross(inputs))
def _ragged_to_sparse(self, t):
if ragged_tensor.is_ragged(t):
return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
elif sparse_tensor.is_sparse(t):
return sparse_tensor.SparseTensor.from_value(t)
else:
return ops.convert_to_tensor(t)
if __name__ == '__main__':
googletest.main()

View File

@ -16,6 +16,14 @@ tf_module {
name: "constant_value"
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'int64\'], "
}
member_method {
name: "cross"
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "cross_hashed"
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
}
member_method {
name: "map_flat_values"
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"

View File

@ -3020,6 +3020,10 @@ tf_module {
name: "RGBToHSV"
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedCross"
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedGather"
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -8,6 +8,14 @@ tf_module {
name: "constant"
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "cross"
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "cross_hashed"
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
}
member_method {
name: "map_flat_values"
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"

View File

@ -3020,6 +3020,10 @@ tf_module {
name: "RGBToHSV"
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedCross"
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedGather"
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "