Add tf.ragged.cross and tf.ragged.cross_hashed operations. These ops generate feature crosses from a list of input tensors (similar to tf.sparse.cross), and return the result as a RaggedTensor. Inputs may be tf.RaggedTensors, tf.Tensors, or tf.SparseTensors.
PiperOrigin-RevId: 298495411 Change-Id: I2edde721cfc4e236d78ece470ff30cc68cf6b70c
This commit is contained in:
parent
81e2ecdaae
commit
72e7964b6a
49
tensorflow/core/api_def/base_api/api_def_RaggedCross.pbtxt
Normal file
49
tensorflow/core/api_def/base_api/api_def_RaggedCross.pbtxt
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "RaggedCross"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "ragged_values"
|
||||||
|
description: "The values tensor for each RaggedTensor input."
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "ragged_row_splits"
|
||||||
|
description: "The row_splits tensor for each RaggedTensor input."
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "sparse_indices"
|
||||||
|
description: "The indices tensor for each SparseTensor input."
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "sparse_values"
|
||||||
|
description: "The values tensor for each SparseTensor input."
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "sparse_shape"
|
||||||
|
description: "The dense_shape tensor for each SparseTensor input."
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "dense_inputs"
|
||||||
|
description: "The tf.Tensor inputs."
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_values"
|
||||||
|
description: "The `values` for the returned `RaggedTensor`."
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_row_splits"
|
||||||
|
description: "The `row_splits` for the returned `RaggedTensor`."
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "input_order"
|
||||||
|
description: <<END
|
||||||
|
String specifying the tensor type for each input. The `i`th character in
|
||||||
|
this string specifies the type of the `i`th input, and is one of: 'R' (ragged),
|
||||||
|
'D' (dense), or 'S' (sparse). This attr is used to ensure that the crossed
|
||||||
|
values are combined in the order of the inputs from the call to tf.ragged.cross.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: <<END
|
||||||
|
Generates a feature cross from a list of tensors, and returns it as a
|
||||||
|
RaggedTensor. See `tf.ragged.cross` for more details.
|
||||||
|
END
|
||||||
|
}
|
@ -1394,6 +1394,7 @@ tf_kernel_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "ragged_ops",
|
name = "ragged_ops",
|
||||||
deps = [
|
deps = [
|
||||||
|
":ragged_cross_op",
|
||||||
":ragged_gather_op",
|
":ragged_gather_op",
|
||||||
":ragged_range_op",
|
":ragged_range_op",
|
||||||
":ragged_tensor_from_variant_op",
|
":ragged_tensor_from_variant_op",
|
||||||
@ -1548,6 +1549,15 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "ragged_cross_op",
|
||||||
|
srcs = ["ragged_cross_op.cc"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "rnn_ops",
|
name = "rnn_ops",
|
||||||
deps = [
|
deps = [
|
||||||
|
609
tensorflow/core/kernels/ragged_cross_op.cc
Normal file
609
tensorflow/core/kernels/ragged_cross_op.cc
Normal file
@ -0,0 +1,609 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <limits>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
|
#include "tensorflow/core/util/util.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
//==============================================================================
|
||||||
|
// Feature Readers
|
||||||
|
//==============================================================================
|
||||||
|
|
||||||
|
// A `FeatureReader` is used to read the feature values from a single input
|
||||||
|
// tensor. Subclasses are used for reading different tensor types:
|
||||||
|
// * RaggedFeatureReader<value_type, splits_type>
|
||||||
|
// * SparseFeatureReader<value_type>
|
||||||
|
// * DenseFeatureReader<value_type>
|
||||||
|
//
|
||||||
|
// Where value_type is one of: {tstring, int64}; and SplitsType is one of:
|
||||||
|
// {int32, int64}.
|
||||||
|
class FeatureReader {
|
||||||
|
public:
|
||||||
|
// Returns the number of feature values in the specified batch.
|
||||||
|
virtual int64 FeatureCount(int64 batch) const = 0;
|
||||||
|
|
||||||
|
// Copies the value for the specified feature to `out`.
|
||||||
|
virtual void ReadValue(int64 batch, int64 n, uint64* out) const = 0;
|
||||||
|
virtual void ReadValue(int64 batch, int64 n, tstring* out) const = 0;
|
||||||
|
|
||||||
|
virtual ~FeatureReader() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
using FeatureReaders = std::vector<std::unique_ptr<FeatureReader>>;
|
||||||
|
|
||||||
|
// Copies a feature value `src` to a tstring `dst`, using a view if appropriate.
|
||||||
|
void CopyToString(const tstring& src, tstring* dst) {
|
||||||
|
if (src.type() == tstring::SMALL) {
|
||||||
|
*dst = src; // string buffer fits in the tstring object (under ~24 bytes)
|
||||||
|
} else {
|
||||||
|
dst->assign_as_view(src);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void CopyToString(int64 src, tstring* dst) { *dst = std::to_string(src); }
|
||||||
|
|
||||||
|
// Copies a feature value `src` to an int64 fingerprint `dst`.
|
||||||
|
void CopyToFingerprint(const tstring& feature, uint64* dst) {
|
||||||
|
*dst = Fingerprint64(feature);
|
||||||
|
}
|
||||||
|
void CopyToFingerprint(int64 feature, uint64* dst) { *dst = feature; }
|
||||||
|
|
||||||
|
// A FeatureReader that is backed by a ragged tensor.
|
||||||
|
template <typename ValuesType, typename SplitsType>
|
||||||
|
class RaggedFeatureReader : public FeatureReader {
|
||||||
|
public:
|
||||||
|
RaggedFeatureReader(const Tensor& values, const Tensor& row_splits)
|
||||||
|
: values_(values.flat<ValuesType>()),
|
||||||
|
row_splits_(row_splits.flat<SplitsType>()) {}
|
||||||
|
|
||||||
|
int64 FeatureCount(int64 batch) const override {
|
||||||
|
return row_splits_(batch + 1) - row_splits_(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadValue(int64 batch, int64 n, uint64* out) const override {
|
||||||
|
CopyToFingerprint(values_(row_splits_(batch) + n), out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadValue(int64 batch, int64 n, tstring* out) const override {
|
||||||
|
CopyToString(values_(row_splits_(batch) + n), out);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const typename TTypes<ValuesType>::ConstFlat values_;
|
||||||
|
const typename TTypes<SplitsType>::ConstFlat row_splits_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A FeatureReader that is backed by a dense tensor.
|
||||||
|
template <typename ValuesType>
|
||||||
|
class DenseFeatureReader : public FeatureReader {
|
||||||
|
public:
|
||||||
|
explicit DenseFeatureReader(const Tensor& tensor)
|
||||||
|
: values_(tensor.matrix<ValuesType>()),
|
||||||
|
feature_count_(tensor.dim_size(1)) {}
|
||||||
|
|
||||||
|
int64 FeatureCount(int64 batch) const override { return feature_count_; }
|
||||||
|
|
||||||
|
void ReadValue(int64 batch, int64 n, uint64* out) const override {
|
||||||
|
CopyToFingerprint(values_(batch, n), out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadValue(int64 batch, int64 n, tstring* out) const override {
|
||||||
|
CopyToString(values_(batch, n), out);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const typename TTypes<ValuesType>::ConstMatrix values_;
|
||||||
|
const int64 feature_count_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A FeatureReader that is backed by a sparse tensor.
|
||||||
|
template <typename ValuesType>
|
||||||
|
class SparseFeatureReader : public FeatureReader {
|
||||||
|
public:
|
||||||
|
SparseFeatureReader(const Tensor& indices_t, const Tensor& values_t,
|
||||||
|
int64 batch_size)
|
||||||
|
: values_(values_t.flat<ValuesType>()) {
|
||||||
|
row_splits_.reserve(batch_size + 1);
|
||||||
|
row_splits_.push_back(0);
|
||||||
|
auto indices = indices_t.matrix<int64>();
|
||||||
|
int64 num_values = values_.size();
|
||||||
|
int64 i = 0; // value index
|
||||||
|
for (int row = 0; row < batch_size; row++) {
|
||||||
|
while (i < num_values && indices(i, 0) <= row) ++i;
|
||||||
|
row_splits_.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 FeatureCount(int64 batch) const override {
|
||||||
|
return row_splits_[batch + 1] - row_splits_[batch];
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadValue(int64 batch, int64 n, uint64* out) const override {
|
||||||
|
CopyToFingerprint(values_(row_splits_[batch] + n), out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadValue(int64 batch, int64 n, tstring* out) const override {
|
||||||
|
CopyToString(values_(row_splits_[batch] + n), out);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const typename TTypes<ValuesType>::ConstFlat values_;
|
||||||
|
std::vector<int64> row_splits_;
|
||||||
|
};
|
||||||
|
|
||||||
|
//==============================================================================
|
||||||
|
// Output Writers
|
||||||
|
//==============================================================================
|
||||||
|
|
||||||
|
// An `OutputWriter` is used to write the feature crosses to the output values
|
||||||
|
// tensor. Different subclasses are used for writing different output dtypes:
|
||||||
|
// * OutputWriterImpl<tstring, SplitsType> (for tf.ragged.cross)
|
||||||
|
// * OutputWriterImpl<int64, SplitsType> (for tf.ragged.cross_hashed)
|
||||||
|
class OutputWriter {
|
||||||
|
public:
|
||||||
|
virtual void WriteOutputSlice(int64 begin, int64 end) = 0;
|
||||||
|
virtual ~OutputWriter() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename ValuesType, typename SplitsType>
|
||||||
|
class OutputWriterImpl : public OutputWriter {
|
||||||
|
public:
|
||||||
|
using FlatValues = typename TTypes<ValuesType>::Flat;
|
||||||
|
using FlatSplits = typename TTypes<SplitsType>::ConstFlat;
|
||||||
|
|
||||||
|
OutputWriterImpl(const FeatureReaders& features, int64 num_buckets,
|
||||||
|
uint64 hash_key, const Tensor* splits_out,
|
||||||
|
Tensor* values_out)
|
||||||
|
: features_(features),
|
||||||
|
num_buckets_(num_buckets),
|
||||||
|
hash_key_(hash_key),
|
||||||
|
splits_out_(splits_out->flat<SplitsType>()),
|
||||||
|
values_out_(values_out->flat<ValuesType>()) {}
|
||||||
|
|
||||||
|
// Reads features from the specified slice of batch indices, computes
|
||||||
|
// feature crosses for each one, and writes them to values_out_.
|
||||||
|
void WriteOutputSlice(int64 begin, int64 end) override {
|
||||||
|
std::vector<int> combination(features_.size(), 0);
|
||||||
|
for (int64 b = begin; b < end; ++b) {
|
||||||
|
auto row_start = splits_out_(b);
|
||||||
|
auto row_limit = splits_out_(b + 1);
|
||||||
|
for (auto i = row_start; i < row_limit; ++i) {
|
||||||
|
WriteCombination(b, combination, &values_out_(i));
|
||||||
|
NextCombination(b, &combination);
|
||||||
|
}
|
||||||
|
combination.assign(features_.size(), 0); // reset for next batch.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Joins the specified combination of input features into a single string,
|
||||||
|
// and writes it to *out.
|
||||||
|
void WriteCombination(int64 batch_index, const std::vector<int>& combination,
|
||||||
|
tstring* out) {
|
||||||
|
static const auto k_feature_separator = "_X_";
|
||||||
|
gtl::InlinedVector<tstring, 6> cross_vec(features_.size());
|
||||||
|
for (int i = 0; i < combination.size(); ++i) {
|
||||||
|
features_[i]->ReadValue(batch_index, combination[i], &cross_vec[i]);
|
||||||
|
}
|
||||||
|
*out = absl::StrJoin(cross_vec, k_feature_separator);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Joins the specified combination of input features into a single
|
||||||
|
// fingerprint, and writes it to *out.
|
||||||
|
void WriteCombination(int64 batch_index, const std::vector<int>& combination,
|
||||||
|
int64* out) {
|
||||||
|
// Do the fingerprint concatenation on uint64.
|
||||||
|
uint64 hashed_output = hash_key_;
|
||||||
|
for (size_t i = 0; i < combination.size(); ++i) {
|
||||||
|
uint64 hash_i;
|
||||||
|
features_[i]->ReadValue(batch_index, combination[i], &hash_i);
|
||||||
|
hashed_output = FingerprintCat64(hashed_output, hash_i);
|
||||||
|
}
|
||||||
|
// The return value is int64 based on the number of buckets.
|
||||||
|
if (num_buckets_ > 0) {
|
||||||
|
*out = hashed_output % num_buckets_;
|
||||||
|
} else {
|
||||||
|
// To prevent negative output we take modulo to max int64.
|
||||||
|
*out = hashed_output % std::numeric_limits<int64>::max();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updates `combination` to the next combination of input features.
|
||||||
|
void NextCombination(int64 batch_index, std::vector<int>* combination) const {
|
||||||
|
bool carry = true;
|
||||||
|
for (int i = combination->size() - 1; i >= 0; i--) {
|
||||||
|
if (carry) {
|
||||||
|
(*combination)[i] = (*combination)[i] + 1;
|
||||||
|
}
|
||||||
|
if ((*combination)[i] == features_[i]->FeatureCount(batch_index)) {
|
||||||
|
(*combination)[i] = 0;
|
||||||
|
} else {
|
||||||
|
carry = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const FeatureReaders& features_;
|
||||||
|
const int64 num_buckets_;
|
||||||
|
const uint64 hash_key_;
|
||||||
|
FlatSplits splits_out_;
|
||||||
|
FlatValues values_out_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns an appropriate OutputWriter, based on the dtypes of the
|
||||||
|
// given tensors.
|
||||||
|
std::unique_ptr<OutputWriter> MakeOutputWriter(const FeatureReaders& features,
|
||||||
|
int64 num_buckets,
|
||||||
|
uint64 hash_key,
|
||||||
|
const Tensor* splits_out,
|
||||||
|
Tensor* values_out) {
|
||||||
|
if (values_out->dtype() == DT_INT64) {
|
||||||
|
if (splits_out->dtype() == DT_INT64) {
|
||||||
|
return std::make_unique<OutputWriterImpl<int64, int64>>(
|
||||||
|
features, num_buckets, hash_key, splits_out, values_out);
|
||||||
|
} else {
|
||||||
|
return std::make_unique<OutputWriterImpl<int64, int32>>(
|
||||||
|
features, num_buckets, hash_key, splits_out, values_out);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (splits_out->dtype() == DT_INT64) {
|
||||||
|
return std::make_unique<OutputWriterImpl<tstring, int64>>(
|
||||||
|
features, num_buckets, hash_key, splits_out, values_out);
|
||||||
|
} else {
|
||||||
|
return std::make_unique<OutputWriterImpl<tstring, int32>>(
|
||||||
|
features, num_buckets, hash_key, splits_out, values_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//==============================================================================
|
||||||
|
// RaggedCross Kernel
|
||||||
|
//==============================================================================
|
||||||
|
|
||||||
|
template <typename SplitsType>
|
||||||
|
class RaggedCrossOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit RaggedCrossOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
|
||||||
|
// Read signed_hash_key_ as int64 since uint64 attributes are not
|
||||||
|
// supported by REGISTER_OP.
|
||||||
|
int64 signed_hash_key_;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
|
||||||
|
hash_key_ = static_cast<uint64>(signed_hash_key_);
|
||||||
|
|
||||||
|
int num_sparse;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("Nsparse", &num_sparse));
|
||||||
|
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("ragged_values_types",
|
||||||
|
&ragged_values_types_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("ragged_splits_types",
|
||||||
|
&ragged_splits_types_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("sparse_values_types",
|
||||||
|
&sparse_values_types_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("dense_types", &dense_types_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("input_order", &input_order_));
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
ragged_values_types_.size() == ragged_splits_types_.size(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"ragged values and splits must have the same length"));
|
||||||
|
OP_REQUIRES(context, num_sparse == sparse_values_types_.size(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"sparse indices and values must have the same length"));
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
ragged_values_types_.size() + sparse_values_types_.size() +
|
||||||
|
dense_types_.size() ==
|
||||||
|
input_order_.size(),
|
||||||
|
errors::InvalidArgument("Invalid length for input_order"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
OpInputList ragged_values_list;
|
||||||
|
OpInputList ragged_splits_list;
|
||||||
|
OpInputList sparse_indices_list;
|
||||||
|
OpInputList sparse_values_list;
|
||||||
|
OpInputList sparse_shape_list;
|
||||||
|
OpInputList dense_list;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->input_list("ragged_values", &ragged_values_list));
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->input_list("ragged_row_splits", &ragged_splits_list));
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->input_list("sparse_indices", &sparse_indices_list));
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->input_list("sparse_values", &sparse_values_list));
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->input_list("sparse_shape", &sparse_shape_list));
|
||||||
|
OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list));
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
ValidateInput(ragged_values_list, ragged_splits_list,
|
||||||
|
sparse_indices_list, sparse_values_list,
|
||||||
|
sparse_shape_list, dense_list));
|
||||||
|
|
||||||
|
int64 batch_size =
|
||||||
|
CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
|
||||||
|
|
||||||
|
FeatureReaders features;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
BuildFeatureReaders(ragged_values_list, ragged_splits_list,
|
||||||
|
sparse_indices_list, sparse_values_list,
|
||||||
|
dense_list, batch_size, &features));
|
||||||
|
|
||||||
|
Tensor* values_out;
|
||||||
|
Tensor* row_splits_out;
|
||||||
|
OP_REQUIRES_OK(context, BuildOutputTensors(features, batch_size, context,
|
||||||
|
&values_out, &row_splits_out));
|
||||||
|
|
||||||
|
std::unique_ptr<OutputWriter> output_writer = MakeOutputWriter(
|
||||||
|
features, num_buckets_, hash_key_, row_splits_out, values_out);
|
||||||
|
|
||||||
|
auto do_work = [&output_writer](int64 begin, int64 end) {
|
||||||
|
output_writer->WriteOutputSlice(begin, end);
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(edloper): optimize cost_per_batch
|
||||||
|
const int cost_per_batch = 5000 * ragged_values_list.size();
|
||||||
|
auto thread_pool =
|
||||||
|
context->device()->tensorflow_cpu_worker_threads()->workers;
|
||||||
|
thread_pool->ParallelFor(batch_size, cost_per_batch, do_work);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Validates input tensors.
|
||||||
|
Status ValidateInput(const OpInputList& ragged_values_list,
|
||||||
|
const OpInputList& ragged_splits_list,
|
||||||
|
const OpInputList& sparse_indices_list,
|
||||||
|
const OpInputList& sparse_values_list,
|
||||||
|
const OpInputList& sparse_shape_list,
|
||||||
|
const OpInputList& dense_list) {
|
||||||
|
const auto num_ragged = ragged_values_list.size();
|
||||||
|
const auto num_sparse = sparse_indices_list.size();
|
||||||
|
|
||||||
|
// Validate tensor shapes.
|
||||||
|
for (int i = 0; i < num_ragged; ++i) {
|
||||||
|
if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"tf.ragged.cross only supports inputs with rank=2.");
|
||||||
|
}
|
||||||
|
if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) ||
|
||||||
|
(ragged_splits_list[i].NumElements() == 0)) {
|
||||||
|
return errors::InvalidArgument("Invalid RaggedTensor");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < num_sparse; ++i) {
|
||||||
|
if (!TensorShapeUtils::IsMatrix(sparse_indices_list[i].shape()) ||
|
||||||
|
!TensorShapeUtils::IsVector(sparse_values_list[i].shape()) ||
|
||||||
|
!TensorShapeUtils::IsVector(sparse_shape_list[i].shape())) {
|
||||||
|
return errors::InvalidArgument("Invalid SparseTensor ", i);
|
||||||
|
}
|
||||||
|
if (sparse_shape_list[i].NumElements() != 2) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"tf.ragged.cross only supports inputs with rank=2.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < dense_list.size(); ++i) {
|
||||||
|
if (!TensorShapeUtils::IsMatrix(dense_list[i].shape())) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"tf.ragged.cross only supports inputs with rank=2.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that batch sizes are consistent.
|
||||||
|
int64 batch_size =
|
||||||
|
CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
|
||||||
|
for (int i = 0; i < num_ragged; ++i) {
|
||||||
|
if (ragged_splits_list[i].NumElements() - 1 != batch_size) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"inputs must all have the same batch dimension size.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < num_sparse; ++i) {
|
||||||
|
if (sparse_shape_list[i].flat<int64>()(0) != batch_size) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"inputs must all have the same batch dimension size.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < dense_list.size(); ++i) {
|
||||||
|
if (dense_list[i].dim_size(0) != batch_size) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"inputs must all have the same batch dimension size.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the batch size from any input tensor. (We check that all input
|
||||||
|
// tensors have the same batch size in `ValidateInput`).
|
||||||
|
int64 CalculateBatchSize(const OpInputList& ragged_splits_list,
|
||||||
|
const OpInputList& sparse_shape_list,
|
||||||
|
const OpInputList& dense_list) {
|
||||||
|
if (ragged_splits_list.size() > 0) {
|
||||||
|
return ragged_splits_list[0].NumElements() - 1;
|
||||||
|
} else if (dense_list.size() > 0) {
|
||||||
|
return dense_list[0].dim_size(0);
|
||||||
|
} else if (sparse_shape_list.size() > 0) {
|
||||||
|
return sparse_shape_list[0].flat<int64>()(0);
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a feature reader for each input tensor, and store them in `features`.
|
||||||
|
Status BuildFeatureReaders(const OpInputList& ragged_values_list,
|
||||||
|
const OpInputList& ragged_splits_list,
|
||||||
|
const OpInputList& sparse_indices_list,
|
||||||
|
const OpInputList& sparse_values_list,
|
||||||
|
const OpInputList& dense_list, int64 batch_size,
|
||||||
|
FeatureReaders* features) {
|
||||||
|
features->reserve(input_order_.size());
|
||||||
|
|
||||||
|
int next_ragged = 0;
|
||||||
|
int next_sparse = 0;
|
||||||
|
int next_dense = 0;
|
||||||
|
for (char c : input_order_) {
|
||||||
|
if (c == 'R') {
|
||||||
|
TF_RETURN_IF_ERROR(BuildRaggedFeatureReader(
|
||||||
|
ragged_values_list[next_ragged], ragged_splits_list[next_ragged],
|
||||||
|
features));
|
||||||
|
next_ragged++;
|
||||||
|
} else if (c == 'S') {
|
||||||
|
TF_RETURN_IF_ERROR(BuildSparseFeatureReader(
|
||||||
|
sparse_indices_list[next_sparse], sparse_values_list[next_sparse],
|
||||||
|
batch_size, features));
|
||||||
|
next_sparse++;
|
||||||
|
} else if (c == 'D') {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
BuildDenseFeatureReader(dense_list[next_dense++], features));
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument("Unexpected input_order value.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a RaggedReatureReader
|
||||||
|
static Status BuildRaggedFeatureReader(const Tensor& values,
|
||||||
|
const Tensor& splits,
|
||||||
|
FeatureReaders* features) {
|
||||||
|
if (values.dtype() != DT_INT64 && values.dtype() != DT_STRING) {
|
||||||
|
return errors::InvalidArgument("Unexpected dtype for input ",
|
||||||
|
(features->size() + 1), ": ",
|
||||||
|
values.dtype());
|
||||||
|
}
|
||||||
|
if (splits.dtype() != DT_INT64 && splits.dtype() != DT_INT32) {
|
||||||
|
return errors::InvalidArgument("Unexpected row_splits.dtype for input ",
|
||||||
|
(features->size() + 1), ": ",
|
||||||
|
values.dtype());
|
||||||
|
}
|
||||||
|
if (values.dtype() == DT_INT64) {
|
||||||
|
if (splits.dtype() == DT_INT64) {
|
||||||
|
features->emplace_back(
|
||||||
|
new RaggedFeatureReader<int64, int64>(values, splits));
|
||||||
|
} else {
|
||||||
|
features->emplace_back(
|
||||||
|
new RaggedFeatureReader<int64, int32>(values, splits));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (splits.dtype() == DT_INT64) {
|
||||||
|
features->emplace_back(
|
||||||
|
new RaggedFeatureReader<tstring, int64>(values, splits));
|
||||||
|
} else {
|
||||||
|
features->emplace_back(
|
||||||
|
new RaggedFeatureReader<tstring, int32>(values, splits));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a DenseFaggedReatureReader.
|
||||||
|
static Status BuildDenseFeatureReader(const Tensor& values,
|
||||||
|
FeatureReaders* features) {
|
||||||
|
if (values.dtype() == DT_INT64) {
|
||||||
|
features->emplace_back(new DenseFeatureReader<int64>(values));
|
||||||
|
} else if (values.dtype() == DT_STRING) {
|
||||||
|
features->emplace_back(new DenseFeatureReader<tstring>(values));
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument("Unexpected dtype for input ",
|
||||||
|
(features->size() + 1), ": ",
|
||||||
|
values.dtype());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a SparseFaggedReatureReader.
|
||||||
|
static Status BuildSparseFeatureReader(const Tensor& indices,
|
||||||
|
const Tensor& values, int64 batch_size,
|
||||||
|
FeatureReaders* features) {
|
||||||
|
if (values.dtype() == DT_INT64) {
|
||||||
|
features->emplace_back(
|
||||||
|
new SparseFeatureReader<int64>(indices, values, batch_size));
|
||||||
|
} else if (values.dtype() == DT_STRING) {
|
||||||
|
features->emplace_back(
|
||||||
|
new SparseFeatureReader<tstring>(indices, values, batch_size));
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument("Unexpected dtype for input ",
|
||||||
|
(features->size() + 1), ": ",
|
||||||
|
values.dtype());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocates output tensors with proper size, and populates row_splits_out.
|
||||||
|
Status BuildOutputTensors(const FeatureReaders& features, int64 batch_size,
|
||||||
|
OpKernelContext* context, Tensor** values_out,
|
||||||
|
Tensor** row_splits_out) {
|
||||||
|
// Allocate and populate the row_splits output tensor.
|
||||||
|
TF_RETURN_IF_ERROR(context->allocate_output(
|
||||||
|
1, TensorShape({batch_size + 1}), row_splits_out));
|
||||||
|
auto flat_row_splits = (*row_splits_out)->flat<SplitsType>();
|
||||||
|
int64 cross_count_total = 0;
|
||||||
|
flat_row_splits(0) = 0;
|
||||||
|
for (int64 b = 0; b < batch_size; b++) {
|
||||||
|
cross_count_total += CrossCountByBatchIndex(features, b);
|
||||||
|
flat_row_splits(b + 1) = cross_count_total;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the values output tensor.
|
||||||
|
TF_RETURN_IF_ERROR(context->allocate_output(
|
||||||
|
0, TensorShape({cross_count_total}), values_out));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns number of crosses for a given batch_index
|
||||||
|
int64 CrossCountByBatchIndex(const FeatureReaders& features,
|
||||||
|
int batch_index) {
|
||||||
|
int64 cross_count = 1;
|
||||||
|
for (int i = 0; i < features.size(); ++i) {
|
||||||
|
const auto feature_count = features[i]->FeatureCount(batch_index);
|
||||||
|
if (feature_count == 0) return 0;
|
||||||
|
cross_count *= feature_count;
|
||||||
|
}
|
||||||
|
return cross_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 num_buckets_;
|
||||||
|
uint64 hash_key_;
|
||||||
|
std::vector<DataType> ragged_values_types_;
|
||||||
|
std::vector<DataType> ragged_splits_types_;
|
||||||
|
std::vector<DataType> sparse_values_types_;
|
||||||
|
std::vector<DataType> dense_types_;
|
||||||
|
tstring input_order_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
|
||||||
|
.Device(DEVICE_CPU)
|
||||||
|
.TypeConstraint<int32>("out_row_splits_type"),
|
||||||
|
RaggedCrossOp<int32>);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
|
||||||
|
.Device(DEVICE_CPU)
|
||||||
|
.TypeConstraint<int64>("out_row_splits_type"),
|
||||||
|
RaggedCrossOp<int64>);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -41,6 +41,79 @@ REGISTER_OP("RaggedGather")
|
|||||||
.Attr("OUTPUT_RAGGED_RANK: int >= 0")
|
.Attr("OUTPUT_RAGGED_RANK: int >= 0")
|
||||||
.SetShapeFn(RaggedGatherShapeFn);
|
.SetShapeFn(RaggedGatherShapeFn);
|
||||||
|
|
||||||
|
REGISTER_OP("RaggedCross")
|
||||||
|
.Input("ragged_values: ragged_values_types")
|
||||||
|
.Input("ragged_row_splits: ragged_splits_types")
|
||||||
|
.Input("sparse_indices: Nsparse * int64")
|
||||||
|
.Input("sparse_values: sparse_values_types")
|
||||||
|
.Input("sparse_shape: Nsparse * int64")
|
||||||
|
.Input("dense_inputs: dense_types")
|
||||||
|
.Output("output_values: out_values_type")
|
||||||
|
.Output("output_row_splits: out_row_splits_type")
|
||||||
|
.Attr("Nsparse: int >= 0")
|
||||||
|
.Attr("input_order: string")
|
||||||
|
.Attr("hashed_output: bool")
|
||||||
|
.Attr("num_buckets: int >= 0")
|
||||||
|
.Attr("hash_key: int")
|
||||||
|
.Attr("ragged_values_types: list({int64, string}) >= 0")
|
||||||
|
.Attr("ragged_splits_types: list({int32, int64}) >= 0")
|
||||||
|
.Attr("sparse_values_types: list({int64, string}) >= 0")
|
||||||
|
.Attr("dense_types: list({int64, string}) >= 0")
|
||||||
|
.Attr("out_values_type: {int64, string}")
|
||||||
|
.Attr("out_row_splits_type: {int32, int64}")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
std::vector<DataType> ragged_values_types;
|
||||||
|
std::vector<DataType> ragged_splits_types;
|
||||||
|
std::vector<DataType> dense_types;
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->GetAttr("ragged_values_types", &ragged_values_types));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->GetAttr("ragged_splits_types", &ragged_splits_types));
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
|
||||||
|
|
||||||
|
int num_ragged = ragged_values_types.size();
|
||||||
|
if (num_ragged != ragged_splits_types.size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Parameters `values` and `row_splits` must be the same length");
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_sparse;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
|
||||||
|
|
||||||
|
ShapeHandle out_values = c->UnknownShapeOfRank(1);
|
||||||
|
ShapeHandle out_splits = c->UnknownShapeOfRank(1);
|
||||||
|
|
||||||
|
// Merge the shapes of row_splits from ragged inputs. (This is one plus
|
||||||
|
// the batch size.)
|
||||||
|
int ragged_splits_start = num_ragged;
|
||||||
|
for (int i = 0; i < ragged_splits_types.size(); ++i) {
|
||||||
|
ShapeHandle row_splits = c->input(i + ragged_splits_start);
|
||||||
|
if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"inputs must all have the same batch dimension size.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the batch size of each dense input into out_splits.
|
||||||
|
int dense_start = num_ragged * 2 + num_sparse * 3;
|
||||||
|
for (int i = 0; i < dense_types.size(); ++i) {
|
||||||
|
ShapeHandle dense_input = c->input(i + dense_start);
|
||||||
|
int64 batch_size = c->Value(c->Dim(dense_input, 0));
|
||||||
|
if (batch_size != InferenceContext::kUnknownDim) {
|
||||||
|
ShapeHandle row_splits = c->Vector(batch_size + 1);
|
||||||
|
if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"inputs must all have the same batch dimension size.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c->set_output(0, out_values);
|
||||||
|
c->set_output(1, out_splits);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
//==============================================================================
|
//==============================================================================
|
||||||
// Shape Functions
|
// Shape Functions
|
||||||
//==============================================================================
|
//==============================================================================
|
||||||
|
@ -1123,3 +1123,17 @@ py_test(
|
|||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "ragged_cross_op_test",
|
||||||
|
srcs = ["ragged_cross_op_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":ragged_array_ops",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import gen_ragged_array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import sort_ops
|
from tensorflow.python.ops import sort_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||||
@ -32,7 +34,6 @@ from tensorflow.python.ops.ragged import ragged_util
|
|||||||
from tensorflow.python.ops.ragged import segment_id_ops
|
from tensorflow.python.ops.ragged import segment_id_ops
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Masking
|
# Masking
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
@ -107,7 +108,8 @@ def boolean_mask(data, mask, name=None):
|
|||||||
if ragged_tensor.is_ragged(mask):
|
if ragged_tensor.is_ragged(mask):
|
||||||
if not ragged_tensor.is_ragged(data):
|
if not ragged_tensor.is_ragged(data):
|
||||||
data = ragged_tensor.RaggedTensor.from_tensor(
|
data = ragged_tensor.RaggedTensor.from_tensor(
|
||||||
data, ragged_rank=mask.ragged_rank,
|
data,
|
||||||
|
ragged_rank=mask.ragged_rank,
|
||||||
row_splits_dtype=mask.row_splits.dtype)
|
row_splits_dtype=mask.row_splits.dtype)
|
||||||
# Check that mask.nested_row_splits is a prefix of
|
# Check that mask.nested_row_splits is a prefix of
|
||||||
# data.nested_row_splits.
|
# data.nested_row_splits.
|
||||||
@ -160,15 +162,15 @@ def boolean_mask(data, mask, name=None):
|
|||||||
segment_mask = array_ops.gather(mask, segment_ids)
|
segment_mask = array_ops.gather(mask, segment_ids)
|
||||||
masked_values = boolean_mask(data.values, segment_mask)
|
masked_values = boolean_mask(data.values, segment_mask)
|
||||||
|
|
||||||
return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
|
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
masked_splits,
|
masked_values, masked_splits, validate=False)
|
||||||
validate=False)
|
|
||||||
|
|
||||||
# If mask is non-ragged and has rank>1, then convert it to be ragged,
|
# If mask is non-ragged and has rank>1, then convert it to be ragged,
|
||||||
# with a ragged rank matching data.
|
# with a ragged rank matching data.
|
||||||
if ragged_tensor.is_ragged(data):
|
if ragged_tensor.is_ragged(data):
|
||||||
mask = ragged_tensor.RaggedTensor.from_tensor(
|
mask = ragged_tensor.RaggedTensor.from_tensor(
|
||||||
mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
|
mask,
|
||||||
|
ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
|
||||||
row_splits_dtype=data.row_splits.dtype)
|
row_splits_dtype=data.row_splits.dtype)
|
||||||
return boolean_mask(data, mask)
|
return boolean_mask(data, mask)
|
||||||
|
|
||||||
@ -182,8 +184,8 @@ def boolean_mask(data, mask, name=None):
|
|||||||
# number of values it contains. Then flatten that to get a list of
|
# number of values it contains. Then flatten that to get a list of
|
||||||
# cell lengths, and convert it to splits. Finally, combine the splits
|
# cell lengths, and convert it to splits. Finally, combine the splits
|
||||||
# and values to get the innermost ragged tensor.
|
# and values to get the innermost ragged tensor.
|
||||||
masked_lengths = math_ops.count_nonzero(mask, axis=-1,
|
masked_lengths = math_ops.count_nonzero(
|
||||||
dtype=row_splits_dtype)
|
mask, axis=-1, dtype=row_splits_dtype)
|
||||||
flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
|
flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
|
||||||
masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
|
masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
|
||||||
masked_values, flattened_masked_lengths, validate=False)
|
masked_values, flattened_masked_lengths, validate=False)
|
||||||
@ -340,8 +342,7 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
|
|||||||
for src_axis in range(ragged_rank):
|
for src_axis in range(ragged_rank):
|
||||||
for dst_axis in range(src_axis + 1, ragged_rank - 1):
|
for dst_axis in range(src_axis + 1, ragged_rank - 1):
|
||||||
projected_splits[src_axis][dst_axis] = array_ops.gather(
|
projected_splits[src_axis][dst_axis] = array_ops.gather(
|
||||||
nested_splits[dst_axis],
|
nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1])
|
||||||
projected_splits[src_axis][dst_axis - 1])
|
|
||||||
|
|
||||||
# For each ragged dimension: nested_splits[axis] -> result_splits[axis].
|
# For each ragged dimension: nested_splits[axis] -> result_splits[axis].
|
||||||
result_splits = []
|
result_splits = []
|
||||||
@ -400,8 +401,7 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
|||||||
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input: The potentially tensor that should be expanded with a new
|
input: The potentially tensor that should be expanded with a new dimension.
|
||||||
dimension.
|
|
||||||
axis: An integer constant indicating where the new dimension should be
|
axis: An integer constant indicating where the new dimension should be
|
||||||
inserted.
|
inserted.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
@ -556,10 +556,10 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
|
|||||||
Args:
|
Args:
|
||||||
data: A `Tensor` or `RaggedTensor` containing the values to stack.
|
data: A `Tensor` or `RaggedTensor` containing the values to stack.
|
||||||
partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
|
partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
|
||||||
partition that each slice of `data` should be added to.
|
partition that each slice of `data` should be added to. `partitions.shape`
|
||||||
`partitions.shape` must be a prefix of `data.shape`. Values must be
|
must be a prefix of `data.shape`. Values must be greater than or equal to
|
||||||
greater than or equal to zero, and less than `num_partitions`.
|
zero, and less than `num_partitions`. `partitions` is not required to be
|
||||||
`partitions` is not required to be sorted.
|
sorted.
|
||||||
num_partitions: An `int32` or `int64` scalar specifying the number of
|
num_partitions: An `int32` or `int64` scalar specifying the number of
|
||||||
partitions to output. This determines the number of rows in `output`.
|
partitions to output. This determines the number of rows in `output`.
|
||||||
name: A name prefix for the returned tensor (optional).
|
name: A name prefix for the returned tensor (optional).
|
||||||
@ -650,8 +650,8 @@ def reverse(tensor, axis, name=None):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: A 'RaggedTensor' to reverse.
|
tensor: A 'RaggedTensor' to reverse.
|
||||||
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
|
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of
|
||||||
of the axes to reverse.
|
the axes to reverse.
|
||||||
name: A name prefix for the returned tensor (optional).
|
name: A name prefix for the returned tensor (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -687,3 +687,137 @@ def reverse(tensor, axis, name=None):
|
|||||||
slices[dim] = slice(None, None, -1)
|
slices[dim] = slice(None, None, -1)
|
||||||
|
|
||||||
return tensor[tuple(slices)]
|
return tensor[tuple(slices)]
|
||||||
|
|
||||||
|
|
||||||
|
#===============================================================================
|
||||||
|
# Cross
|
||||||
|
#===============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('ragged.cross')
|
||||||
|
def cross(inputs, name=None):
|
||||||
|
"""Generates feature cross from a list of tensors.
|
||||||
|
|
||||||
|
The input tensors must have `rank=2`, and must all have the same number of
|
||||||
|
rows. The result is a `RaggedTensor` with the same number of rows as the
|
||||||
|
inputs, where `result[row]` contains a list of all combinations of values
|
||||||
|
formed by taking a single value from each input's corresponding row
|
||||||
|
(`inputs[i][row]`). Values are combined by joining their strings with '_X_'.
|
||||||
|
E.g.:
|
||||||
|
|
||||||
|
>>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
|
||||||
|
... tf.ragged.constant([['d'], ['e']]),
|
||||||
|
... tf.ragged.constant([['f'], ['g']])])
|
||||||
|
<tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
|
||||||
|
name: Optional name for the op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2D `RaggedTensor` of type `string`.
|
||||||
|
"""
|
||||||
|
return _cross_internal(inputs=inputs, hashed_output=False, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('ragged.cross_hashed')
|
||||||
|
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
|
||||||
|
"""Generates hashed feature cross from a list of tensors.
|
||||||
|
|
||||||
|
The input tensors must have `rank=2`, and must all have the same number of
|
||||||
|
rows. The result is a `RaggedTensor` with the same number of rows as the
|
||||||
|
inputs, where `result[row]` contains a list of all combinations of values
|
||||||
|
formed by taking a single value from each input's corresponding row
|
||||||
|
(`inputs[i][row]`). Values are combined by hashing together their
|
||||||
|
fingerprints. E.g.:
|
||||||
|
|
||||||
|
>>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
|
||||||
|
... tf.ragged.constant([['d'], ['e']]),
|
||||||
|
... tf.ragged.constant([['f'], ['g']])],
|
||||||
|
... num_buckets=100)
|
||||||
|
<tf.RaggedTensor [[78], [66, 74]]>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
|
||||||
|
num_buckets: A non-negative `int` that used to bucket the hashed values. If
|
||||||
|
`num_buckets != 0`, then `output = hashed_value % num_buckets`.
|
||||||
|
hash_key: Integer hash_key that will be used by the `FingerprintCat64`
|
||||||
|
function. If not given, a default key is used.
|
||||||
|
name: Optional name for the op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2D `RaggedTensor` of type `int64`.
|
||||||
|
"""
|
||||||
|
return _cross_internal(
|
||||||
|
inputs=inputs,
|
||||||
|
hashed_output=True,
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
hash_key=hash_key,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE
|
||||||
|
|
||||||
|
|
||||||
|
def _cross_internal(inputs,
|
||||||
|
hashed_output=False,
|
||||||
|
num_buckets=0,
|
||||||
|
hash_key=None,
|
||||||
|
name=None):
|
||||||
|
"""Generates feature cross from a list of ragged and dense tensors."""
|
||||||
|
if not isinstance(inputs, (tuple, list)):
|
||||||
|
raise TypeError('Inputs must be a list')
|
||||||
|
|
||||||
|
if hash_key is None:
|
||||||
|
hash_key = _DEFAULT_CROSS_HASH_KEY
|
||||||
|
|
||||||
|
ragged_inputs = []
|
||||||
|
sparse_inputs = []
|
||||||
|
dense_inputs = []
|
||||||
|
input_order = []
|
||||||
|
with ops.name_scope(name, 'RaggedCross', inputs):
|
||||||
|
for i, t in enumerate(inputs):
|
||||||
|
if sparse_tensor.is_sparse(t):
|
||||||
|
t = sparse_tensor.SparseTensor.from_value(t)
|
||||||
|
else:
|
||||||
|
t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
|
||||||
|
if t.dtype.is_integer:
|
||||||
|
t = math_ops.cast(t, dtypes.int64)
|
||||||
|
elif t.dtype != dtypes.string:
|
||||||
|
raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
|
||||||
|
if isinstance(t, ragged_tensor.RaggedTensor):
|
||||||
|
if t.ragged_rank != 1:
|
||||||
|
raise ValueError('tf.ragged.cross only supports inputs with rank=2')
|
||||||
|
ragged_inputs.append(t)
|
||||||
|
input_order.append('R')
|
||||||
|
elif isinstance(t, sparse_tensor.SparseTensor):
|
||||||
|
sparse_inputs.append(t)
|
||||||
|
input_order.append('S')
|
||||||
|
else:
|
||||||
|
dense_inputs.append(t)
|
||||||
|
input_order.append('D')
|
||||||
|
|
||||||
|
out_values_type = dtypes.int64 if hashed_output else dtypes.string
|
||||||
|
if ragged_inputs and all(
|
||||||
|
t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
|
||||||
|
out_row_splits_type = dtypes.int32
|
||||||
|
else:
|
||||||
|
out_row_splits_type = dtypes.int64
|
||||||
|
|
||||||
|
values_out, splits_out = gen_ragged_array_ops.ragged_cross(
|
||||||
|
ragged_values=[rt.values for rt in ragged_inputs],
|
||||||
|
ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
|
||||||
|
sparse_indices=[st.indices for st in sparse_inputs],
|
||||||
|
sparse_values=[st.values for st in sparse_inputs],
|
||||||
|
sparse_shape=[st.dense_shape for st in sparse_inputs],
|
||||||
|
dense_inputs=dense_inputs,
|
||||||
|
input_order=''.join(input_order),
|
||||||
|
hashed_output=hashed_output,
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
hash_key=hash_key,
|
||||||
|
out_values_type=out_values_type,
|
||||||
|
out_row_splits_type=out_row_splits_type,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
return ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values_out, splits_out, validate=False)
|
||||||
|
396
tensorflow/python/ops/ragged/ragged_cross_op_test.py
Normal file
396
tensorflow/python/ops/ragged/ragged_cross_op_test.py
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf.ragged.cross and tf.ragged.cross_hashed."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
ragged_const = ragged_factory_ops.constant_value
|
||||||
|
dense_const = np.array
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_const(matrix):
|
||||||
|
indices = []
|
||||||
|
values = []
|
||||||
|
for i, row in enumerate(matrix):
|
||||||
|
for j, val in enumerate(row):
|
||||||
|
indices.append([i, j])
|
||||||
|
values.append(val)
|
||||||
|
shape = [len(matrix), max(len(row) for row in matrix)] if matrix else [0, 0]
|
||||||
|
if not values:
|
||||||
|
indices = np.zeros([0, 2], dtype=np.int64)
|
||||||
|
values = np.zeros([0], dtype=np.int64)
|
||||||
|
return sparse_tensor.SparseTensorValue(indices, values, shape)
|
||||||
|
|
||||||
|
|
||||||
|
class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
dict(
|
||||||
|
testcase_name='NoInputs',
|
||||||
|
inputs=[],
|
||||||
|
expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
|
||||||
|
dict(
|
||||||
|
testcase_name='OneInput_RaggedStr',
|
||||||
|
inputs=[ragged_const([['a', 'b'], [], ['c']])],
|
||||||
|
expected=ragged_const([[b'a', b'b'], [], [b'c']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='OneInput_RaggedInt',
|
||||||
|
inputs=[ragged_const([[1, 2, 3], [4, 5]])],
|
||||||
|
expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='OneInput_DenseInt',
|
||||||
|
inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
|
||||||
|
expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='OneInput_SparseStr',
|
||||||
|
inputs=[sparse_const([['a', 'b'], [], ['c']])],
|
||||||
|
expected=ragged_const([[b'a', b'b'], [], [b'c']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_RaggedStr_RaggedStr',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['a', 'b'], [], ['c']]),
|
||||||
|
ragged_const([['d', 'e'], ['f'], ['g']])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
|
||||||
|
[b'c_X_g']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_RaggedInt_RaggedInt',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[1, 2], [], [3]]),
|
||||||
|
ragged_const([[4, 5, 6], [], [7]])
|
||||||
|
],
|
||||||
|
expected=ragged_const(
|
||||||
|
[[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
|
||||||
|
[b'3_X_7']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_RaggedStr_RaggedInt',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['a', 'b'], [], ['c']]),
|
||||||
|
ragged_const([['1', '2'], ['3'], ['4']])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
|
||||||
|
[b'c_X_4']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_SparseStr_SparseStr',
|
||||||
|
inputs=[
|
||||||
|
sparse_const([['a', 'b'], [], ['c']]),
|
||||||
|
sparse_const([['d', 'e'], ['f'], ['g']])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
|
||||||
|
[b'c_X_g']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_DenseInt_DenseInt',
|
||||||
|
inputs=[dense_const([[1, 2], [3, 4]]),
|
||||||
|
dense_const([[5, 6], [7, 8]])],
|
||||||
|
expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
|
||||||
|
[b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_DenseInt_DenseStr',
|
||||||
|
inputs=[
|
||||||
|
dense_const([[1, 2], [3, 4]]),
|
||||||
|
dense_const([[b'5', b'6'], [b'7', b'8']])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
|
||||||
|
[b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoInputs_RaggedInt_DenseInt',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[], [], [1, 2], [3]]),
|
||||||
|
dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[], [],
|
||||||
|
[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
|
||||||
|
[b'3_X_7', b'3_X_8']])),
|
||||||
|
dict(
|
||||||
|
# This test exercises `input_order`.
|
||||||
|
testcase_name='TwoInputs_DenseInt_RaggedStr',
|
||||||
|
inputs=[
|
||||||
|
dense_const([[1, 2], [3, 4], [5, 6]]),
|
||||||
|
ragged_const([['d', 'e'], ['f'], ['g']])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
|
||||||
|
[b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
|
||||||
|
matches_sparse_cross=False # sparse doesn't preserve input order.
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
# This test exercises `input_order`.
|
||||||
|
testcase_name='TwoInputs_SparseInt_RaggedStr',
|
||||||
|
inputs=[
|
||||||
|
sparse_const([[1, 2], [3, 4], [5, 6]]),
|
||||||
|
ragged_const([['d', 'e'], ['f'], ['g']])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
|
||||||
|
[b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
|
||||||
|
matches_sparse_cross=False # sparse doesn't preserve input order.
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[11], [12, 13], [], [14, 15]]),
|
||||||
|
ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
|
||||||
|
ragged_const([[31], [32, 33], [34, 35], [36, 37]])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
|
||||||
|
[
|
||||||
|
b'12_X_23_X_32', b'12_X_23_X_33',
|
||||||
|
b'13_X_23_X_32', b'13_X_23_X_33'
|
||||||
|
], [],
|
||||||
|
[
|
||||||
|
b'14_X_26_X_36', b'14_X_26_X_37',
|
||||||
|
b'14_X_27_X_36', b'14_X_27_X_37',
|
||||||
|
b'15_X_26_X_36', b'15_X_26_X_37',
|
||||||
|
b'15_X_27_X_36', b'15_X_27_X_37'
|
||||||
|
]])),
|
||||||
|
dict(
|
||||||
|
testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[11], [12, 13], [], [14, 15]]),
|
||||||
|
sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
|
||||||
|
dense_const([[31], [32], [33], [34]])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
|
||||||
|
[
|
||||||
|
b'12_X_23_X_32',
|
||||||
|
b'13_X_23_X_32',
|
||||||
|
], [],
|
||||||
|
[
|
||||||
|
b'14_X_26_X_34',
|
||||||
|
b'14_X_27_X_34',
|
||||||
|
b'15_X_26_X_34',
|
||||||
|
b'15_X_27_X_34',
|
||||||
|
]])),
|
||||||
|
dict(
|
||||||
|
testcase_name='FiveInputs',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[1]]),
|
||||||
|
dense_const([[2]]),
|
||||||
|
ragged_const([[3]]),
|
||||||
|
sparse_const([[4]]),
|
||||||
|
ragged_const([[5]])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
|
||||||
|
matches_sparse_cross=False # sparse doesn't preserve input order.
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='Permutation_3x3x3',
|
||||||
|
inputs=[[['11', '12', '13']], [['21', '22', '23']],
|
||||||
|
[['31', '32', '33']]],
|
||||||
|
expected=[[
|
||||||
|
b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
|
||||||
|
b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
|
||||||
|
b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
|
||||||
|
b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
|
||||||
|
b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
|
||||||
|
b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
|
||||||
|
b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
|
||||||
|
b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
|
||||||
|
b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
|
||||||
|
]]),
|
||||||
|
dict(
|
||||||
|
testcase_name='BatchSizeZero',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([], ragged_rank=1, dtype=dtypes.int32),
|
||||||
|
sparse_const([]),
|
||||||
|
np.zeros([0, 3], dtype=np.int32),
|
||||||
|
],
|
||||||
|
expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
|
||||||
|
dict(
|
||||||
|
testcase_name='ThreeInputs_OneEmpty',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[1, 2]]),
|
||||||
|
ragged_const([[]], dtype=dtypes.int32),
|
||||||
|
ragged_const([[3, 4]])
|
||||||
|
],
|
||||||
|
expected=ragged_const([[]], dtype=dtypes.string)),
|
||||||
|
dict(
|
||||||
|
testcase_name='ThreeInputs_AllEmpty',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([[]], dtype=dtypes.int64),
|
||||||
|
ragged_const([[]], dtype=dtypes.string),
|
||||||
|
ragged_const([[]], dtype=dtypes.int32)
|
||||||
|
],
|
||||||
|
expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
|
||||||
|
dict(
|
||||||
|
testcase_name='HashedZeroBucketsDefaultKey',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['batch1-FC1-F1']]),
|
||||||
|
ragged_const([['batch1-FC2-F1']]),
|
||||||
|
ragged_const([['batch1-FC3-F1']])
|
||||||
|
],
|
||||||
|
expected_hashed=ragged_const([[1971693436396284976]])),
|
||||||
|
dict(
|
||||||
|
testcase_name='Hashed100BucketsDefaultKey',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['batch1-FC1-F1']]),
|
||||||
|
ragged_const([['batch1-FC2-F1']]),
|
||||||
|
ragged_const([['batch1-FC3-F1']])
|
||||||
|
],
|
||||||
|
num_buckets=100,
|
||||||
|
expected_hashed=ragged_const([[83]])),
|
||||||
|
dict(
|
||||||
|
testcase_name='HashedZeroBucketsCustomKey',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['batch1-FC1-F1']]),
|
||||||
|
ragged_const([['batch1-FC2-F1']]),
|
||||||
|
ragged_const([['batch1-FC3-F1']])
|
||||||
|
],
|
||||||
|
hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
|
||||||
|
expected_hashed=ragged_const([[4847552627144134031]])),
|
||||||
|
dict(
|
||||||
|
testcase_name='Hashed100BucketsCustomKey',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['batch1-FC1-F1']]),
|
||||||
|
ragged_const([['batch1-FC2-F1']]),
|
||||||
|
ragged_const([['batch1-FC3-F1']])
|
||||||
|
],
|
||||||
|
num_buckets=100,
|
||||||
|
hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
|
||||||
|
expected_hashed=ragged_const([[31]])),
|
||||||
|
dict(
|
||||||
|
testcase_name='HashedZeroKey',
|
||||||
|
inputs=[
|
||||||
|
ragged_const([['batch1-FC1-F1']]),
|
||||||
|
ragged_const([['batch1-FC2-F1']]),
|
||||||
|
ragged_const([['batch1-FC3-F1']])
|
||||||
|
],
|
||||||
|
hash_key=0,
|
||||||
|
expected_hashed=ragged_const([[9077905385164735582]]),
|
||||||
|
matches_sparse_cross=False # sparse treats hash_key=0 as None.
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='UInt64',
|
||||||
|
inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
|
||||||
|
expected=ragged_const([[b'-1']])),
|
||||||
|
])
|
||||||
|
def testRaggedCross(self,
|
||||||
|
inputs,
|
||||||
|
num_buckets=0,
|
||||||
|
hash_key=None,
|
||||||
|
expected=None,
|
||||||
|
expected_hashed=None,
|
||||||
|
matches_sparse_cross=True):
|
||||||
|
ragged_cross = ragged_array_ops.cross(inputs)
|
||||||
|
ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
|
||||||
|
hash_key)
|
||||||
|
|
||||||
|
if expected is not None:
|
||||||
|
self.assertAllEqual(ragged_cross, expected)
|
||||||
|
if expected_hashed is not None:
|
||||||
|
self.assertAllEqual(ragged_cross_hashed, expected_hashed)
|
||||||
|
|
||||||
|
if matches_sparse_cross:
|
||||||
|
# Check that ragged.cross & sparse.cross match.
|
||||||
|
sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
|
||||||
|
sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
|
||||||
|
self.assertAllEqual(ragged_cross,
|
||||||
|
ragged_tensor.RaggedTensor.from_sparse(sparse_cross))
|
||||||
|
|
||||||
|
# Check that ragged.cross_hashed & sparse.cross_hashed match.
|
||||||
|
sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
|
||||||
|
sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
|
||||||
|
sparse_inputs, num_buckets, hash_key)
|
||||||
|
self.assertAllEqual(
|
||||||
|
ragged_cross_hashed,
|
||||||
|
ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))
|
||||||
|
|
||||||
|
def testRaggedCrossLargeBatch(self):
|
||||||
|
batch_size = 5000
|
||||||
|
inputs = [
|
||||||
|
ragged_const([[1, 2, 3]] * batch_size),
|
||||||
|
ragged_const([[b'4']] * batch_size),
|
||||||
|
dense_const([[5]] * batch_size),
|
||||||
|
sparse_const([[6, 7]] * batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
expected = [[
|
||||||
|
b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
|
||||||
|
b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
|
||||||
|
]] * batch_size
|
||||||
|
|
||||||
|
ragged_cross = ragged_array_ops.cross(inputs)
|
||||||
|
|
||||||
|
# Note: we don't use assertAllEqual here because if they don't match,
|
||||||
|
# then the code in assertAllEqual that tries to build the error message
|
||||||
|
# is very slow, causing the test to timeout.
|
||||||
|
# pylint: disable=g-generic-assert
|
||||||
|
self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
dict(
|
||||||
|
testcase_name='BadDType',
|
||||||
|
inputs=[ragged_const([[1.1], [2.2, 3.3]])],
|
||||||
|
message=r'Unexpected dtype for inputs\[0\]'),
|
||||||
|
dict(
|
||||||
|
testcase_name='StaticBatchSizeMismatch1',
|
||||||
|
inputs=[ragged_const([[1]]),
|
||||||
|
ragged_const([[2], [3]])],
|
||||||
|
exception=(ValueError, errors.InvalidArgumentError),
|
||||||
|
message='inputs must all have the same batch dimension size'),
|
||||||
|
dict(
|
||||||
|
testcase_name='StaticBatchSizeMismatch2',
|
||||||
|
inputs=[ragged_const([[1]]),
|
||||||
|
dense_const([[2], [3]])],
|
||||||
|
exception=(ValueError, errors.InvalidArgumentError),
|
||||||
|
message='inputs must all have the same batch dimension size'),
|
||||||
|
])
|
||||||
|
def testStaticError(self, inputs, exception=ValueError, message=None):
|
||||||
|
with self.assertRaisesRegexp(exception, message):
|
||||||
|
ragged_array_ops.cross(inputs)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
dict(
|
||||||
|
testcase_name='3DRaggedTensor',
|
||||||
|
inputs=[ragged_const([[[1]]], ragged_rank=1)],
|
||||||
|
message='tf.ragged.cross only supports inputs with rank=2'),
|
||||||
|
dict(
|
||||||
|
testcase_name='3DDenseTensor',
|
||||||
|
inputs=[dense_const([[[1]]])],
|
||||||
|
message='tf.ragged.cross only supports inputs with rank=2'),
|
||||||
|
])
|
||||||
|
def testRuntimeError(self,
|
||||||
|
inputs,
|
||||||
|
exception=errors.InvalidArgumentError,
|
||||||
|
message=None):
|
||||||
|
with self.assertRaisesRegexp(exception, message):
|
||||||
|
self.evaluate(ragged_array_ops.cross(inputs))
|
||||||
|
|
||||||
|
def _ragged_to_sparse(self, t):
|
||||||
|
if ragged_tensor.is_ragged(t):
|
||||||
|
return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
|
||||||
|
elif sparse_tensor.is_sparse(t):
|
||||||
|
return sparse_tensor.SparseTensor.from_value(t)
|
||||||
|
else:
|
||||||
|
return ops.convert_to_tensor(t)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
googletest.main()
|
@ -16,6 +16,14 @@ tf_module {
|
|||||||
name: "constant_value"
|
name: "constant_value"
|
||||||
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'int64\'], "
|
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'int64\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "cross"
|
||||||
|
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "cross_hashed"
|
||||||
|
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "map_flat_values"
|
name: "map_flat_values"
|
||||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||||
|
@ -3020,6 +3020,10 @@ tf_module {
|
|||||||
name: "RGBToHSV"
|
name: "RGBToHSV"
|
||||||
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "RaggedCross"
|
||||||
|
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedGather"
|
name: "RaggedGather"
|
||||||
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -8,6 +8,14 @@ tf_module {
|
|||||||
name: "constant"
|
name: "constant"
|
||||||
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
|
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "cross"
|
||||||
|
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "cross_hashed"
|
||||||
|
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "map_flat_values"
|
name: "map_flat_values"
|
||||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||||
|
@ -3020,6 +3020,10 @@ tf_module {
|
|||||||
name: "RGBToHSV"
|
name: "RGBToHSV"
|
||||||
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "RaggedCross"
|
||||||
|
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedGather"
|
name: "RaggedGather"
|
||||||
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'params_nested_splits\', \'params_dense_values\', \'indices\', \'OUTPUT_RAGGED_RANK\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user