Update RaggedTensors to support int32 row_splits.

PiperOrigin-RevId: 245157497
Author: Edward Loper (2019-04-24 18:40:49 -07:00); committed by TensorFlower Gardener
parent 02bd711c79
commit c45be92834
41 changed files with 778 additions and 404 deletions
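
Note (not part of the diff): the user-visible effect of this change, sketched in Python against the row_splits_dtype arguments added below to tf.ragged.constant and RaggedTensor.from_tensor:

    import tensorflow as tf

    # Row partitions may now be int32; the Tsplits attr defaults to int64,
    # so existing code is unchanged.
    rt = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
    assert rt.row_splits.dtype == tf.int32

    dense = tf.constant([[1, 2], [3, 4]])
    rt2 = tf.RaggedTensor.from_tensor(dense, row_splits_dtype=tf.int32)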

View File

@@ -30,10 +30,11 @@ namespace {
 // For each slice in `(start, limit)` in `value_slices`, append
 // `params_dense_values_in[start:limit] to `values_out`. `value_size` indicates
 // the number of scalars contained in each value params_dense_values_in[i].
-template <typename VALUE_TYPE>
-void WriteValueSlices(const Tensor& params_dense_values_in,
-                      const std::vector<std::pair<int64, int64>>& value_slices,
-                      int64 value_size, Tensor* values_out) {
+template <typename VALUE_TYPE, typename SPLITS_TYPE>
+void WriteValueSlices(
+    const Tensor& params_dense_values_in,
+    const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+    SPLITS_TYPE value_size, Tensor* values_out) {
   const auto& params_dense_values =
       params_dense_values_in.flat_outer_dims<VALUE_TYPE, 2>();
   auto values = values_out->flat_outer_dims<VALUE_TYPE, 2>();
@@ -50,7 +51,7 @@ void WriteValueSlices(const Tensor& params_dense_values_in,
 }  // namespace
-template <typename INDEX_TYPE>
+template <typename INDEX_TYPE, typename SPLITS_TYPE>
 class RaggedGatherOpBase : public OpKernel {
  public:
   using OpKernel::OpKernel;
@@ -66,18 +67,18 @@ class RaggedGatherOpBase : public OpKernel {
         context->input(params_nested_splits_in.size() + 1);
     DCHECK_GT(params_nested_splits_in.size(), 0);  // Enforced by REGISTER_OP.
-    int64 num_params = params_nested_splits_in[0].dim_size(0) - 1;
+    SPLITS_TYPE num_params = params_nested_splits_in[0].dim_size(0) - 1;
     OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params));
     OP_REQUIRES(context, params_dense_values_in.dims() > 0,
                 errors::InvalidArgument("params.rank must be nonzero"));
-    int64 num_params_dense_values = params_dense_values_in.dim_size(0);
+    SPLITS_TYPE num_params_dense_values = params_dense_values_in.dim_size(0);
     // Calculate the `splits`, and store the value slices that we need to
     // copy in `value_slices`.
-    std::vector<std::pair<int64, int64>> value_slices;
-    int64 num_values = 0;
-    std::vector<std::vector<int64>> out_splits;
+    std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>> value_slices;
+    SPLITS_TYPE num_values = 0;
+    std::vector<std::vector<SPLITS_TYPE>> out_splits;
     OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in,
                                        num_params_dense_values, &out_splits,
                                        &value_slices, &num_values));
@@ -90,12 +91,14 @@ class RaggedGatherOpBase : public OpKernel {
   }
  private:
+  using ConstFlatType = typename TTypes<SPLITS_TYPE>::ConstFlat;
   // Check if any indices are out-of-bounds.
   ::tensorflow::Status ValidateIndices(const Tensor& indices_in,
-                                       int64 num_params) {
+                                       SPLITS_TYPE num_params) {
     const auto& indices = indices_in.flat<INDEX_TYPE>();
-    for (int64 i = 0; i < indices.size(); ++i) {
-      int64 index = indices(i);
+    for (SPLITS_TYPE i = 0; i < indices.size(); ++i) {
+      SPLITS_TYPE index = indices(i);
       if (index < 0 || index >= num_params) {
         return errors::InvalidArgument(
             "indices", SliceDebugString(indices_in.shape(), i), " = ", index,
@@ -111,9 +114,10 @@ class RaggedGatherOpBase : public OpKernel {
   // we need for allocating the output values tensor) is stored in `num_values`.
   ::tensorflow::Status MakeSplits(
       const Tensor& indices_in, const OpInputList& params_nested_splits_in,
-      int64 num_params_dense_values,
-      std::vector<std::vector<int64>>* out_splits,
-      std::vector<std::pair<int64, int64>>* value_slices, int64* num_values) {
+      SPLITS_TYPE num_params_dense_values,
+      std::vector<std::vector<SPLITS_TYPE>>* out_splits,
+      std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>* value_slices,
+      SPLITS_TYPE* num_values) {
     *num_values = 0;
     value_slices->clear();
@@ -122,10 +126,10 @@ class RaggedGatherOpBase : public OpKernel {
     // Get Eigen tensors.
     const auto& indices = indices_in.flat<INDEX_TYPE>();
-    std::vector<TTypes<int64>::ConstFlat> params_nested_splits;
+    std::vector<ConstFlatType> params_nested_splits;
     params_nested_splits.reserve(params_nested_splits_in.size());
     for (const auto& splits_in : params_nested_splits_in) {
-      params_nested_splits.push_back(splits_in.flat<int64>());
+      params_nested_splits.push_back(splits_in.flat<SPLITS_TYPE>());
     }
     TF_RETURN_IF_ERROR(
@@ -165,7 +169,7 @@ class RaggedGatherOpBase : public OpKernel {
       const auto& splits = params_nested_splits[dim];
       int out_dim = dim + indices_in.dims() - 1;
       if (out_dim >= 0) {
-        int64 delta = out_splits->at(out_dim).back() - splits(start);
+        SPLITS_TYPE delta = out_splits->at(out_dim).back() - splits(start);
         for (int j = start; j < limit; ++j) {
           out_splits->at(out_dim).push_back(splits(j + 1) + delta);
         }
@@ -182,14 +186,14 @@ class RaggedGatherOpBase : public OpKernel {
   }
   ::tensorflow::Status ValidateSplits(
-      const std::vector<TTypes<int64>::ConstFlat>& params_nested_splits,
-      int64 num_params_dense_values) {
+      const std::vector<ConstFlatType>& params_nested_splits,
+      SPLITS_TYPE num_params_dense_values) {
     // Validate
     for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
       const auto& splits = params_nested_splits[dim];
-      int64 last_split = (dim == params_nested_splits.size() - 1)
+      SPLITS_TYPE last_split = (dim == params_nested_splits.size() - 1)
                              ? num_params_dense_values
                              : params_nested_splits[dim + 1].size();
       if (splits.size() == 0) {
         return errors::InvalidArgument("Ragged splits may not be empty");
       }
@@ -210,17 +214,17 @@ class RaggedGatherOpBase : public OpKernel {
   }
   ::tensorflow::Status WriteSplits(
-      const std::vector<std::vector<int64>>& out_splits,
+      const std::vector<std::vector<SPLITS_TYPE>>& out_splits,
       OpKernelContext* context) {
     OpOutputList splits_out;
     TF_RETURN_IF_ERROR(
         context->output_list("output_nested_splits", &splits_out));
     for (int i = 0; i < out_splits.size(); ++i) {
       Tensor* splits;
-      int64 num_splits = out_splits[i].size();
+      SPLITS_TYPE num_splits = out_splits[i].size();
       TF_RETURN_IF_ERROR(
           splits_out.allocate(i, TensorShape({num_splits}), &splits));
-      auto splits_flat = splits->flat<int64>();
+      auto splits_flat = splits->flat<SPLITS_TYPE>();
       std::copy_n(out_splits[i].data(), out_splits[i].size(),
                   splits_flat.data());
     }
@@ -229,15 +233,16 @@ class RaggedGatherOpBase : public OpKernel {
   ::tensorflow::Status WriteValues(
       const Tensor& params_dense_values_in,
-      const std::vector<std::pair<int64, int64>>& value_slices,
-      int values_index, int64 num_values, OpKernelContext* context) const {
+      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+      int values_index, SPLITS_TYPE num_values,
+      OpKernelContext* context) const {
     Tensor* values_out = nullptr;
     TensorShape values_shape = params_dense_values_in.shape();
     values_shape.set_dim(0, num_values);
     TF_RETURN_IF_ERROR(
         context->allocate_output(values_index, values_shape, &values_out));
-    const int64 num_elements = params_dense_values_in.NumElements();
-    const int64 value_size =
+    const SPLITS_TYPE num_elements = params_dense_values_in.NumElements();
+    const SPLITS_TYPE value_size =
         num_elements == 0 ? 0
                           : (num_elements / params_dense_values_in.dim_size(0));
     CallWriteValueSlices(params_dense_values_in, value_slices, value_size,
@@ -253,34 +258,39 @@ class RaggedGatherOpBase : public OpKernel {
   // which cuts the binary size of this op from ~300k to <90k.
   virtual void CallWriteValueSlices(
       const Tensor& params_dense_values_in,
-      const std::vector<std::pair<int64, int64>>& value_slices,
-      int64 value_size, Tensor* values_out) const = 0;
+      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+      SPLITS_TYPE value_size, Tensor* values_out) const = 0;
 };
-template <typename INDEX_TYPE, typename VALUE_TYPE>
-class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE> {
+template <typename INDEX_TYPE, typename VALUE_TYPE, typename SPLITS_TYPE>
+class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE> {
  public:
-  using RaggedGatherOpBase<INDEX_TYPE>::RaggedGatherOpBase;
+  using RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE>::RaggedGatherOpBase;
  private:
   void CallWriteValueSlices(
       const Tensor& params_dense_values_in,
-      const std::vector<std::pair<int64, int64>>& value_slices,
-      int64 value_size, Tensor* values_out) const override {
+      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+      SPLITS_TYPE value_size, Tensor* values_out) const override {
     WriteValueSlices<VALUE_TYPE>(params_dense_values_in, value_slices,
                                  value_size, values_out);
   }
 };
-#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type)   \
-  REGISTER_KERNEL_BUILDER(Name("RaggedGather")                        \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<index_type>("Tindices") \
-                              .TypeConstraint<value_type>("Tvalues"), \
-                          RaggedGatherOp<index_type, value_type>);
-#define REGISTER_CPU_KERNEL(value_type)                  \
-  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type) \
-  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type)
+#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type, \
+                                            splits_type)            \
+  REGISTER_KERNEL_BUILDER(                                          \
+      Name("RaggedGather")                                          \
+          .Device(DEVICE_CPU)                                       \
+          .TypeConstraint<index_type>("Tindices")                   \
+          .TypeConstraint<value_type>("Tvalues")                    \
+          .TypeConstraint<splits_type>("Tsplits"),                  \
+      RaggedGatherOp<index_type, value_type, splits_type>);
+#define REGISTER_CPU_KERNEL(value_type)                         \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int32) \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int32) \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int64) \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int64)
 TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
 TF_CALL_string(REGISTER_CPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);

View File

@@ -26,7 +26,7 @@ namespace tensorflow {
 using errors::InvalidArgument;
-template <typename T>
+template <typename T, typename SPLITS_TYPE>
 class RaggedRangeOp : public OpKernel {
  public:
   using OpKernel::OpKernel;
@@ -60,7 +60,7 @@ class RaggedRangeOp : public OpKernel {
                 InvalidArgument("starts, limits, and deltas must have the "
                                 "same shape"));
     }
-    int64 nrows = in_sizes.empty() ? 1 : in_sizes[0];
+    SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes[0];
     const auto& starts = starts_in.flat<T>();
     const auto& limits = limits_in.flat<T>();
@@ -71,7 +71,7 @@ class RaggedRangeOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({nrows + 1}),
                                             &rt_nested_splits_out));
-    auto rt_nested_splits = rt_nested_splits_out->flat<int64>();
+    auto rt_nested_splits = rt_nested_splits_out->flat<SPLITS_TYPE>();
     rt_nested_splits(0) = 0;
     for (int row = 0; row < nrows; ++row) {
       T start = broadcast_starts ? starts(0) : starts(row);
@@ -81,7 +81,7 @@ class RaggedRangeOp : public OpKernel {
       rt_nested_splits(row + 1) =
           rt_nested_splits(row) + RangeSize(start, limit, delta);
     }
-    int64 nvals = rt_nested_splits(nrows);
+    SPLITS_TYPE nvals = rt_nested_splits(nrows);
     // Construct the rt_dense_values tensor.
     Tensor* rt_dense_values_out = nullptr;
@@ -90,10 +90,10 @@ class RaggedRangeOp : public OpKernel {
     auto rt_dense_values = rt_dense_values_out->flat<T>();
     int value_index = 0;
     for (int row = 0; row < nrows; ++row) {
-      int64 row_size = rt_nested_splits(row + 1) - rt_nested_splits(row);
+      SPLITS_TYPE row_size = rt_nested_splits(row + 1) - rt_nested_splits(row);
       T value = broadcast_starts ? starts(0) : starts(row);
       T delta = broadcast_deltas ? deltas(0) : deltas(row);
-      for (int64 i = 0; i < row_size; ++i) {
+      for (SPLITS_TYPE i = 0; i < row_size; ++i) {
         rt_dense_values(value_index++) = T(value);
         value += delta;
       }
@@ -102,7 +102,7 @@ class RaggedRangeOp : public OpKernel {
  private:
   // Returns the number of elements in the specified range.
-  int64 RangeSize(T start, T limit, T delta) {
+  SPLITS_TYPE RangeSize(T start, T limit, T delta) {
     if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
       return 0;
     }
@@ -114,10 +114,17 @@ class RaggedRangeOp : public OpKernel {
   }
 };
 #define REGISTER_CPU_KERNEL(TYPE)                                \
-  REGISTER_KERNEL_BUILDER(                                       \
-      Name("RaggedRange").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
-      RaggedRangeOp<TYPE>);
+  REGISTER_KERNEL_BUILDER(Name("RaggedRange")                    \
+                              .Device(DEVICE_CPU)                \
+                              .TypeConstraint<TYPE>("T")         \
+                              .TypeConstraint<int32>("Tsplits"), \
+                          RaggedRangeOp<TYPE, int32>);           \
+  REGISTER_KERNEL_BUILDER(Name("RaggedRange")                    \
+                              .Device(DEVICE_CPU)                \
+                              .TypeConstraint<TYPE>("T")         \
+                              .TypeConstraint<int64>("Tsplits"), \
+                          RaggedRangeOp<TYPE, int64>);
 TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
 TF_CALL_int32(REGISTER_CPU_KERNEL);
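
Note (not part of the diff): a sketch of how the new Tsplits attr surfaces through the generated Python op wrapper. The gen_ragged_math_ops module path is an assumption here; in generated wrappers, attrs with defaults become keyword arguments.

    import tensorflow as tf
    from tensorflow.python.ops import gen_ragged_math_ops

    # Two ranges: 0..3 with step 1, and the empty range 5..5.
    result = gen_ragged_math_ops.ragged_range(
        starts=[0, 5], limits=[3, 5], deltas=[1, 1], Tsplits=tf.int32)
    # result.rt_nested_splits -> [0, 3, 3] with dtype int32
    # result.rt_dense_values  -> [0, 1, 2]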

View File

@@ -26,21 +26,23 @@ namespace tensorflow {
 using errors::InvalidArgument;
+template <typename SPLITS_TYPE>
 class RaggedTensorToSparseOp : public OpKernel {
  public:
   using OpKernel::OpKernel;
+  using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;
   void Compute(OpKernelContext* context) override {
     // Read the `rt_nested_splits` input & convert to Eigen tensors.
     OpInputList rt_nested_splits_in;
     OP_REQUIRES_OK(
         context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
-    const int64 rt_nested_splits_len = rt_nested_splits_in.size();
+    const int rt_nested_splits_len = rt_nested_splits_in.size();
     DCHECK_GT(rt_nested_splits_len, 0);  // Enforced by REGISTER_OP.
-    std::vector<TTypes<int64>::ConstFlat> rt_nested_splits;
+    std::vector<ConstFlatSplits> rt_nested_splits;
     rt_nested_splits.reserve(rt_nested_splits_len);
     for (int i = 0; i < rt_nested_splits_len; ++i) {
-      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<int64>());
+      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
     }
     // Read the `rt_dense_values` input.
@@ -135,7 +137,7 @@ class RaggedTensorToSparseOp : public OpKernel {
     sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
     for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
       const auto& splits = rt_nested_splits[dim];
-      int64 max_width = 0;
+      SPLITS_TYPE max_width = 0;
       for (int i = 1; i < splits.size(); ++i) {
         max_width = std::max(max_width, splits(i) - splits(i - 1));
       }
@@ -150,7 +152,7 @@ class RaggedTensorToSparseOp : public OpKernel {
  private:
   // Validate `rt_nested_splits` to ensure we don't get any segfaults.
   static ::tensorflow::Status ValidateInputs(
-      std::vector<TTypes<int64>::ConstFlat> rt_nested_splits,
+      std::vector<ConstFlatSplits> rt_nested_splits,
       const Tensor& rt_dense_values_in) {
     for (int i = 0; i < rt_nested_splits.size(); ++i) {
       if (rt_nested_splits[i].size() == 0) {
@@ -160,7 +162,7 @@ class RaggedTensorToSparseOp : public OpKernel {
         return InvalidArgument("First value of ragged splits must be 0.");
       }
       if (i > 0) {
-        int64 last_split =
+        SPLITS_TYPE last_split =
             rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
         if (rt_nested_splits[i].size() != last_split + 1) {
           return InvalidArgument(
@@ -206,14 +208,21 @@ class RaggedTensorToSparseOp : public OpKernel {
   // values.
   static bool IsCompleted(
       const std::vector<int64>& pos, int dim,
-      const std::vector<TTypes<int64>::ConstFlat>& rt_nested_splits) {
+      const std::vector<ConstFlatSplits>& rt_nested_splits) {
     int64 current_child = pos[dim + 1];
     int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
     return current_child >= limit_child;
   }
 };
-REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse").Device(DEVICE_CPU),
-                        RaggedTensorToSparseOp);
+REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsplits"),
+                        RaggedTensorToSparseOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsplits"),
+                        RaggedTensorToSparseOp<int64>);
 }  // namespace tensorflow
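
Note (not part of the diff): with both kernels registered, a RaggedTensor carrying int32 row_splits converts to sparse directly; a minimal sketch, assuming RaggedTensor.to_sparse dispatches to this op:

    import tensorflow as tf

    rt = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
    sp = rt.to_sparse()
    # Per the op signature above, sparse_indices and sparse_dense_shape
    # are int64 regardless of Tsplits.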

View File

@@ -350,6 +350,7 @@ class UnicodeTranscodeOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode").Device(DEVICE_CPU),
                         UnicodeTranscodeOp);
+template <typename SPLITS_TYPE>
 class UnicodeDecodeBaseOp : public OpKernel {
  public:
   explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets)
@@ -369,8 +370,8 @@ class UnicodeDecodeBaseOp : public OpKernel {
   }
   void Decode(OpKernelContext* ctx, std::vector<UChar32>* char_values,
-              std::vector<int64>* offset_values, int* current_offset,
-              int64* next_row_split, UChar32 char_value, int char_length,
+              std::vector<SPLITS_TYPE>* offset_values, int* current_offset,
+              SPLITS_TYPE* next_row_split, UChar32 char_value, int char_length,
               bool found_any_format_error) {
     if (error_options_.error_on_malformatting && found_any_format_error) {
       ctx->CtxFailure(
@@ -414,16 +415,16 @@ class UnicodeDecodeBaseOp : public OpKernel {
                                 input_encoding_));
     std::vector<UChar32> char_values;
-    std::vector<int64> offset_values;
+    std::vector<SPLITS_TYPE> offset_values;
     Tensor* output_row_splits;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits",
                                              {input_tensor->NumElements() + 1},
                                              &output_row_splits));
-    auto out_row_splits = output_row_splits->vec<int64>();
+    auto out_row_splits = output_row_splits->vec<SPLITS_TYPE>();
     int row_split_index = 0;
-    int64 next_row_split = 0;
+    SPLITS_TYPE next_row_split = 0;
     for (int i = 0; i < input_vec.size(); ++i) {
       const string& input = input_vec(i);
       // Convert input strings into unicode values. Output to a list of
@@ -443,18 +444,18 @@ class UnicodeDecodeBaseOp : public OpKernel {
     Tensor* output_char_values;
     OP_REQUIRES_OK(
-        ctx, ctx->allocate_output("char_values",
-                                  {static_cast<int64>(char_values.size())},
-                                  &output_char_values));
+        ctx, ctx->allocate_output(
+                 "char_values", {static_cast<SPLITS_TYPE>(char_values.size())},
+                 &output_char_values));
     auto out_char_values = output_char_values->vec<int32>();
     if (generate_offsets_) {
       DCHECK(offset_values.size() == char_values.size());
       Tensor* output_offset_values;
-      OP_REQUIRES_OK(
-          ctx, ctx->allocate_output("char_to_byte_starts",
-                                    {static_cast<int64>(offset_values.size())},
-                                    &output_offset_values));
+      OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                              "char_to_byte_starts",
+                              {static_cast<SPLITS_TYPE>(offset_values.size())},
+                              &output_offset_values));
-      auto out_offset_values = output_offset_values->vec<int64>();
+      auto out_offset_values = output_offset_values->vec<SPLITS_TYPE>();
       // Load output tensors from intermediate value arrays.
       for (int i = 0; i < char_values.size(); ++i) {
@@ -474,23 +475,36 @@ class UnicodeDecodeBaseOp : public OpKernel {
   bool generate_offsets_ = false;
 };
-class UnicodeDecodeOp : public UnicodeDecodeBaseOp {
+template <typename SPLITS_TYPE>
+class UnicodeDecodeOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> {
  public:
   explicit UnicodeDecodeOp(OpKernelConstruction* ctx)
-      : UnicodeDecodeBaseOp(ctx, false) {}
+      : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, false) {}
 };
-class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp {
+template <typename SPLITS_TYPE>
+class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> {
  public:
   explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx)
-      : UnicodeDecodeBaseOp(ctx, true) {}
+      : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, true) {}
 };
-REGISTER_KERNEL_BUILDER(Name("UnicodeDecode").Device(DEVICE_CPU),
-                        UnicodeDecodeOp);
-REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets").Device(DEVICE_CPU),
-                        UnicodeDecodeWithOffsetsOp);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint<int64>("Tsplits"),
+    UnicodeDecodeOp<int64>);
+REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsplits"),
+                        UnicodeDecodeWithOffsetsOp<int64>);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits"),
+    UnicodeDecodeOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsplits"),
+                        UnicodeDecodeWithOffsetsOp<int32>);
+template <typename SPLITS_TYPE>
 class UnicodeEncodeOp : public OpKernel {
  public:
   explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -515,7 +529,7 @@ class UnicodeEncodeOp : public OpKernel {
     const Tensor& input_tensor = context->input(0);
     const auto input_tensor_flat = input_tensor.flat<int32>();
     const Tensor& input_splits = context->input(1);
-    const auto input_splits_flat = input_splits.flat<int64>();
+    const auto input_splits_flat = input_splits.flat<SPLITS_TYPE>();
     // Since we limit to a 2-D input (flat_values of rank 1 and a single splits
     // tensor), our output dimension will be 1 with it's size equal to the
@@ -558,7 +572,11 @@ class UnicodeEncodeOp : public OpKernel {
   ErrorOptions error_options_;
 };
-REGISTER_KERNEL_BUILDER(Name("UnicodeEncode").Device(DEVICE_CPU),
-                        UnicodeEncodeOp);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint<int64>("Tsplits"),
+    UnicodeEncodeOp<int64>);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits"),
+    UnicodeEncodeOp<int32>);
 }  // namespace tensorflow
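
Note (not part of the diff): the decode/encode kernels are now templated on SPLITS_TYPE, but Tsplits defaults to int64, so existing callers see the same output dtypes; an illustrative sketch:

    import tensorflow as tf

    chars = tf.strings.unicode_decode(["hello", "ok"], input_encoding="UTF-8")
    # chars is a RaggedTensor of code points; its row_splits dtype is
    # int64, the Tsplits default.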

View File

@@ -29,13 +29,14 @@ Status RaggedGatherShapeFn(InferenceContext* c);
 //==============================================================================
 REGISTER_OP("RaggedGather")
-    .Input("params_nested_splits: PARAMS_RAGGED_RANK * int64")
+    .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
     .Input("params_dense_values: Tvalues")
     .Input("indices: Tindices")
-    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * int64")
+    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
     .Output("output_dense_values: Tvalues")
     .Attr("Tvalues: type")
     .Attr("Tindices: {int32, int64}")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .Attr("PARAMS_RAGGED_RANK: int >= 1")
     .Attr("OUTPUT_RAGGED_RANK: int >= 0")
     .SetShapeFn(RaggedGatherShapeFn);

View File

@@ -31,13 +31,14 @@ Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
 //==============================================================================
 REGISTER_OP("RaggedTensorToSparse")
-    .Input("rt_nested_splits: RAGGED_RANK * int64")
+    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
     .Input("rt_dense_values: T")
     .Output("sparse_indices: int64")
     .Output("sparse_values: T")
     .Output("sparse_dense_shape: int64")
     .Attr("RAGGED_RANK: int >= 1")
     .Attr("T: type")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedTensorToSparseShapeFn);
 REGISTER_OP("RaggedTensorToVariant")

View File

@@ -32,9 +32,10 @@ REGISTER_OP("RaggedRange")
     .Input("starts: T")
     .Input("limits: T")
     .Input("deltas: T")
-    .Output("rt_nested_splits: int64")
+    .Output("rt_nested_splits: Tsplits")
     .Output("rt_dense_values: T")
     .Attr("T: {bfloat16, float, double, int32, int64} = DT_INT32")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedRangeShapeFn);
 //==============================================================================

View File

@@ -263,10 +263,11 @@ REGISTER_OP("UnicodeScript")
 REGISTER_OP("UnicodeEncode")
     .Input("input_values: int32")
-    .Input("input_splits: int64")
+    .Input("input_splits: Tsplits")
     .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'")
     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .Output("output: string")
     .SetShapeFn([](InferenceContext* c) {
       // Check rank of inner values
@@ -298,12 +299,13 @@ REGISTER_OP("UnicodeTranscode")
 REGISTER_OP("UnicodeDecode")
     .Input("input: string")
-    .Output("row_splits: int64")
+    .Output("row_splits: Tsplits")
     .Output("char_values: int32")
     .Attr("input_encoding: string")
     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
     .Attr("replace_control_characters: bool = false")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn([](InferenceContext* c) {
       // row_splits.shape == [input.size() + 1]
       DimensionHandle num_row_splits;
@@ -319,13 +321,14 @@ REGISTER_OP("UnicodeDecode")
 REGISTER_OP("UnicodeDecodeWithOffsets")
     .Input("input: string")
-    .Output("row_splits: int64")
+    .Output("row_splits: Tsplits")
     .Output("char_values: int32")
     .Output("char_to_byte_starts: int64")
     .Attr("input_encoding: string")
     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
     .Attr("replace_control_characters: bool = false")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn([](InferenceContext* c) {
       // row_splits.shape == [input.size() + 1]
       DimensionHandle num_row_splits;

View File

@@ -27,6 +27,7 @@ py_library(
         ":ragged_batch_gather_ops",
         ":ragged_batch_gather_with_default_op",
         ":ragged_concat_ops",
+        ":ragged_config",
         ":ragged_conversion_ops",
         ":ragged_dispatch",
         ":ragged_factory_ops",
@@ -282,11 +283,21 @@ py_library(
     ],
 )
+py_library(
+    name = "ragged_config",
+    srcs = ["ragged_config.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:dtypes",
+    ],
+)
 py_library(
     name = "ragged_tensor",
     srcs = ["ragged_tensor.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":ragged_config",
         ":ragged_tensor_value",
         ":ragged_util",
         ":segment_id_ops",
@@ -363,6 +374,7 @@ py_library(
     srcs = ["segment_id_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":ragged_config",
         ":ragged_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",

View File

@@ -23,7 +23,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -122,6 +121,8 @@ def boolean_mask(data, mask, keepdims=False, name=None):
   data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
   mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
       mask, dtypes.bool, name='mask')
+  row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
+      data, mask, return_dtype=True)
   # Get static rank of mask.
   if mask.shape.ndims is None:
@@ -132,8 +133,9 @@ def boolean_mask(data, mask, keepdims=False, name=None):
   # If mask is ragged, then recurse with a non-ragged mask.
   if ragged_tensor.is_ragged(mask):
     if not ragged_tensor.is_ragged(data):
-      data = ragged_conversion_ops.from_tensor(
-          data, ragged_rank=mask.ragged_rank)
+      data = ragged_tensor.RaggedTensor.from_tensor(
+          data, ragged_rank=mask.ragged_rank,
+          row_splits_dtype=mask.row_splits.dtype)
     # Check that mask.nested_row_splits is a prefix of
     # data.nested_row_splits.
     splits_list = [
@@ -152,7 +154,7 @@ def boolean_mask(data, mask, keepdims=False, name=None):
       # Count the number of True mask values in each row to find the
       # lengths of the filtered rows; then convert to splits.
       int_mask = ragged_functional_ops.map_flat_values(
-          math_ops.cast, mask, dtype=dtypes.int64)
+          math_ops.cast, mask, dtype=row_splits_dtype)
       masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
       splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
       mask = mask.values
@@ -192,8 +194,9 @@ def boolean_mask(data, mask, keepdims=False, name=None):
   # If mask is non-ragged and has rank>1, then convert it to be ragged,
   # with a ragged rank matching data.
   if ragged_tensor.is_ragged(data):
-    mask = ragged_conversion_ops.from_tensor(
-        mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1))
+    mask = ragged_tensor.RaggedTensor.from_tensor(
+        mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
+        row_splits_dtype=data.row_splits.dtype)
   return boolean_mask(data, mask, keepdims)
   # Otherwise, data and mask are both `Tensor`s.
@@ -206,14 +209,15 @@ def boolean_mask(data, mask, keepdims=False, name=None):
       # number of values it contains. Then flatten that to get a list of
       # cell lengths, and convert it to splits. Finally, combine the splits
       # and values to get the innermost ragged tensor.
-      masked_lengths = math_ops.count_nonzero(mask, axis=-1)
+      masked_lengths = math_ops.count_nonzero(mask, axis=-1,
+                                              dtype=row_splits_dtype)
       flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
       masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
           masked_values, flattened_masked_lengths)
       # Wrap remaining ragged dimensions.
       if mask.shape.ndims > 2 and keepdims:
-        mask_shape = array_ops.shape(mask, out_type=dtypes.int64)
+        mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
         split_size = math_ops.cumprod(mask_shape) + 1
         for dim in range(mask.shape.ndims - 3, -1, -1):
           elt_size = mask_shape[dim + 1]
@@ -254,11 +258,11 @@ def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
   with ops.name_scope(name, 'RaggedTile', [input, multiples]):
     input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         input, name='input')
-    multiples = ragged_util.convert_to_int_tensor(
-        multiples, name='multiples', dtype=dtypes.int64)
-    multiples.shape.assert_has_rank(1)
     if not ragged_tensor.is_ragged(input):
       return array_ops.tile(input, multiples, name)
+    multiples = ragged_util.convert_to_int_tensor(
+        multiples, name='multiples', dtype=input.row_splits.dtype)
+    multiples.shape.assert_has_rank(1)
     # If the constant value of `multiples` is available, then we can use it
     # to skip tiling dimensions where `multiples=1`.
@@ -343,7 +347,7 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
       dimensions where `multiples=1`.
   Returns:
-    A list of 1-D `int64` `Tensor`s (one for each ragged dimension in
+    A list of 1-D integer `Tensor`s (one for each ragged dimension in
     `rt_input`).
   #### Example:
@@ -514,40 +518,6 @@ def size(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
     return array_ops.size(input, out_type=out_type, name=name)
-#===============================================================================
-# Internal Helper Functions
-#===============================================================================
-def _increase_ragged_rank_to(rt_input, ragged_rank):
-  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
-  if ragged_rank > 0:
-    if not ragged_tensor.is_ragged(rt_input):
-      rt_input = ragged_conversion_ops.from_tensor(rt_input)
-    if rt_input.ragged_rank < ragged_rank:
-      rt_input = rt_input.with_values(
-          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1))
-  return rt_input
-def _concat_ragged_splits(splits_list):
-  """Concatenates a list of RaggedTensor splits to form a single splits."""
-  pieces = [splits_list[0]]
-  splits_offset = splits_list[0][-1]
-  for splits in splits_list[1:]:
-    pieces.append(splits[1:] + splits_offset)
-    splits_offset += splits[-1]
-  return array_ops.concat(pieces, axis=0)
-def _nrows(rt_input, out_type=dtypes.int64, name=None):
-  if isinstance(rt_input, ragged_tensor.RaggedTensor):
-    return rt_input.nrows(out_type=out_type, name=name)
-  else:
-    with ops.name_scope(name, 'RaggedNRows', [rt_input]):
-      return array_ops.shape(rt_input, out_type=out_type)[0]
 #===============================================================================
 # ragged.rank
 #===============================================================================
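
Note (not part of the diff): the net effect of threading row_splits_dtype through boolean_mask is that matched int32 inputs should stay int32 end to end; an illustrative sketch:

    import tensorflow as tf

    data = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
    mask = tf.ragged.constant([[True, False], [True]],
                              row_splits_dtype=tf.int32)
    out = tf.ragged.boolean_mask(data, mask)  # [[1], [3]]
    # out.row_splits is expected to keep the int32 dtype.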

View File

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -72,6 +71,7 @@ def batch_gather(params, indices, name=None):
         params, name='params')
     indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         indices, name='indices')
+    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
     indices_ndims = indices.shape.ndims
     if indices_ndims is None:
       raise ValueError(
@@ -97,15 +97,17 @@ def batch_gather(params, indices, name=None):
       if params.shape.ndims is not None and params.shape.ndims < 2:
         raise ValueError('batch shape from indices does '
                          'not match params shape')
-      params = ragged_conversion_ops.from_tensor(params, ragged_rank=1)
+      params = ragged_conversion_ops.from_tensor(
+          params, ragged_rank=1,
+          row_splits_dtype=indices.row_splits.dtype)
       # Adjust indices from within-batch to global (in params.values), and
       # then use ragged.gather to gather them.
       num_indices = indices.row_lengths()
       params_starts = params.row_starts()
       adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
-      adjusted_index_values = math_ops.cast(
-          indices.values, dtypes.int64) + adjustments
+      adjusted_index_values = (
+          math_ops.cast(indices.values, adjustments.dtype) + adjustments)
       return ragged_tensor.RaggedTensor.from_row_splits(
           ragged_gather_ops.gather(params.values, adjusted_index_values),
           indices.row_splits)
@@ -116,7 +118,8 @@ def batch_gather(params, indices, name=None):
     elif indices_ndims == 2:
       # Adjust indices from batch-local to global (in params.values)
       adjustments = array_ops.expand_dims(params.row_starts(), 1)
-      adjusted_indices = math_ops.cast(indices, dtypes.int64) + adjustments
+      adjusted_indices = (
+          math_ops.cast(indices, adjustments.dtype) + adjustments)
       return ragged_gather_ops.gather(params.values, adjusted_indices)
     else:
       raise ValueError('batch shape from indices does not match params shape')

View File

@@ -20,7 +20,6 @@ from __future__ import print_function
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -81,6 +80,9 @@ def batch_gather_with_default(params,
     default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         default_value, name='default_value',
     )
+    row_splits_dtype, (params, indices, default_value) = (
+        ragged_tensor.match_row_splits_dtypes(params, indices, default_value,
+                                              return_dtype=True))
    # TODO(hterry): lift this restriction and support default_values of
    # of rank > 1
    if (default_value.shape.ndims is not 0
@@ -113,7 +115,7 @@ def batch_gather_with_default(params,
        axis=-1)
    upper_bounds = math_ops.cast(row_lengths, indices.dtype)
-    pad_shape = _get_pad_shape(params, indices)
+    pad_shape = _get_pad_shape(params, indices, row_splits_dtype)
    pad = ragged_tensor_shape.broadcast_to(
        default_value, pad_shape)
@@ -144,11 +146,11 @@ def batch_gather_with_default(params,
         params=padded_params, indices=adjusted_indices, name=name)
-def _get_pad_shape(params, indices):
+def _get_pad_shape(params, indices, row_splits_dtype):
   """Gets the RaggedTensorDynamicShape for the pad tensor."""
   num_batch_dimensions = indices.shape.ndims - 1
   params_shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
-      params)
+      params, dim_size_dtype=row_splits_dtype)
   # We want to create a pad tensor that can be concatenated with the params.
   if params.shape.ndims == indices.shape.ndims:
@@ -169,8 +171,8 @@ def _get_pad_shape(params, indices):
     # has size 1.
     pad_dims = None
     if num_batch_dimensions == 0:
-      pad_dims = (constant_op.constant(1, dtype=dtypes.int64),) + (
-          constant_op.constant([1], dtype=dtypes.int64),) * (
+      pad_dims = (constant_op.constant(1, dtype=row_splits_dtype),) + (
+          constant_op.constant([1], dtype=row_splits_dtype),) * (
              params_shape.num_partitioned_dimensions -
              num_batch_dimensions - 1)
    else:

View File

@@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_array_ops
-from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_gather_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_util
@@ -135,6 +134,9 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
       ragged_tensor.convert_to_tensor_or_ragged_tensor(
           rt_input, name='rt_input') for rt_input in rt_inputs
   ]
+  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
+      *rt_inputs, return_dtype=True)
+  rt_inputs = list(rt_inputs)
   # Special case: if there's only one input, then return it as-is.
   if len(rt_inputs) == 1:
@@ -168,12 +170,13 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
   # possible to concatenate Tensors and RaggedTensors together.
   for i in range(len(rt_inputs)):
     if not ragged_tensor.is_ragged(rt_inputs[i]):
-      rt_inputs[i] = ragged_conversion_ops.from_tensor(
-          rt_inputs[i], ragged_rank=1)
+      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
+          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)
   # Convert the input tensors to all have the same ragged_rank.
   ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
-  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank) for rt in rt_inputs]
+  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
+               for rt in rt_inputs]
   if axis == 0:
     return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
@@ -281,14 +284,16 @@ def _copy_row_shape(rt_inputs, splits):
       splits.set_shape(tensor_shape.TensorShape(rt.shape[0] + 1))
-def _increase_ragged_rank_to(rt_input, ragged_rank):
+def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
   """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
   if ragged_rank > 0:
     if not ragged_tensor.is_ragged(rt_input):
-      rt_input = ragged_conversion_ops.from_tensor(rt_input)
+      rt_input = ragged_tensor.RaggedTensor.from_tensor(
+          rt_input, row_splits_dtype=row_splits_dtype)
     if rt_input.ragged_rank < ragged_rank:
       rt_input = rt_input.with_values(
-          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1))
+          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1,
+                                   row_splits_dtype))
   return rt_input

View File

@@ -0,0 +1,33 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Configuration parameters for RaggedTensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+def auto_cast_partition_dtype():
+  """Whether incompatible row-partitioning dtypes should be auto-converted.
+  If true, then operations that combine RaggedTensors but have different
+  row-partitioning tensor dtypes will be automatically cast to a
+  compatible dtype (`tf.int64`).  If false, then such operations will result
+  in an error.
+  Returns:
+    `bool`
+  """
+  return False
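
Note (not part of the diff): because auto_cast_partition_dtype() returns False, operations that combine RaggedTensors with mismatched row_splits dtypes are expected to raise; one explicit fix is to rebuild an operand with a matching splits dtype:

    import tensorflow as tf

    a = tf.ragged.constant([[1], [2, 3]], row_splits_dtype=tf.int32)
    b = tf.ragged.constant([[4, 5], [6]], row_splits_dtype=tf.int64)
    # Align the partition dtypes before combining a and b:
    b32 = tf.RaggedTensor.from_row_splits(
        b.values, tf.cast(b.row_splits, tf.int32))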

View File

@@ -18,15 +18,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops.ragged import ragged_tensor
-def from_tensor(tensor, lengths=None, padding=None, ragged_rank=1, name=None):
+def from_tensor(tensor, lengths=None, padding=None, ragged_rank=1,
+                row_splits_dtype=dtypes.int64, name=None):
   if ragged_tensor.is_ragged(tensor):
     return tensor
   else:
-    return ragged_tensor.RaggedTensor.from_tensor(tensor, lengths, padding,
-                                                  ragged_rank, name)
+    return ragged_tensor.RaggedTensor.from_tensor(
+        tensor,
+        lengths=lengths,
+        padding=padding,
+        ragged_rank=ragged_rank,
+        row_splits_dtype=row_splits_dtype,
+        name=name)
 def to_tensor(rt_input, default_value=None, name=None):

View File

@ -128,6 +128,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
elif not _is_convertible_to_tensor(elt): elif not _is_convertible_to_tensor(elt):
return self.NOT_SUPPORTED return self.NOT_SUPPORTED
if found_ragged: if found_ragged:
x = ragged_tensor.match_row_splits_dtypes(*x)
nested_splits_lists = [ nested_splits_lists = [
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt) elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
] ]
@ -199,6 +200,9 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
except (TypeError, ValueError): except (TypeError, ValueError):
return self.NOT_SUPPORTED return self.NOT_SUPPORTED
if x_is_ragged and y_is_ragged:
x, y = ragged_tensor.match_row_splits_dtypes(x, y)
if ((x_is_ragged and y_is_ragged) or if ((x_is_ragged and y_is_ragged) or
(x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
(y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
@ -272,16 +276,6 @@ class RaggedDispatcher(dispatch.OpDispatcher):
return found_ragged return found_ragged
def ragged_dispatch(original_op, tensor_args):
def decorator(ragged_op):
dispatch.RaggedDispatcher(original_op, ragged_op,
tensor_args).register(original_op)
return ragged_op
return decorator
_UNARY_ELEMENTWISE_OPS = [ _UNARY_ELEMENTWISE_OPS = [
array_ops.check_numerics, array_ops.check_numerics,
array_ops.identity, array_ops.identity,
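Once the dispatchers match row-splits dtypes up front, a registered elementwise op should preserve an int32 partition rather than silently promoting it; roughly:

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

rt = tf.ragged.constant([[1, -2], [-3]], row_splits_dtype=tf.int32)
out = tf.abs(rt)             # dispatched to the ragged implementation
print(out.row_splits.dtype)  # int32 is carried through unchanged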

View File

@ -34,7 +34,8 @@ from tensorflow.python.util.tf_export import tf_export
# Op to construct a constant RaggedTensor from a nested Python list. # Op to construct a constant RaggedTensor from a nested Python list.
#=============================================================================== #===============================================================================
@tf_export("ragged.constant") @tf_export("ragged.constant")
def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None): def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
name=None, row_splits_dtype=dtypes.int64):
"""Constructs a constant RaggedTensor from a nested Python list. """Constructs a constant RaggedTensor from a nested Python list.
Example: Example:
@ -65,6 +66,8 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None):
is not specified. If `ragged_rank` is specified, then a default is chosen is not specified. If `ragged_rank` is specified, then a default is chosen
based on the contents of `pylist`. based on the contents of `pylist`.
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
row_splits_dtype: data type for the constructed `RaggedTensor`'s row_splits.
One of `tf.int32` or `tf.int64`.
Returns: Returns:
A potentially ragged tensor with rank `K` and the specified `ragged_rank`, A potentially ragged tensor with rank `K` and the specified `ragged_rank`,
@ -74,14 +77,19 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None):
ValueError: If the scalar values in `pylist` have inconsistent nesting ValueError: If the scalar values in `pylist` have inconsistent nesting
depth; or if ragged_rank or inner_shape are incompatible with `pylist`. depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
""" """
def _ragged_factory(values, row_splits):
row_splits = constant_op.constant(row_splits, dtype=row_splits_dtype)
return ragged_tensor.RaggedTensor.from_row_splits(values, row_splits)
with ops.name_scope(name, "RaggedConstant"): with ops.name_scope(name, "RaggedConstant"):
return _constant_value(ragged_tensor.RaggedTensor.from_row_splits, return _constant_value(_ragged_factory,
constant_op.constant, pylist, dtype, ragged_rank, constant_op.constant, pylist, dtype, ragged_rank,
inner_shape) inner_shape)
@tf_export(v1=["ragged.constant_value"]) @tf_export(v1=["ragged.constant_value"])
def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None): def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None,
row_splits_dtype="int64"):
"""Constructs a RaggedTensorValue from a nested Python list. """Constructs a RaggedTensorValue from a nested Python list.
Warning: This function returns a `RaggedTensorValue`, not a `RaggedTensor`. Warning: This function returns a `RaggedTensorValue`, not a `RaggedTensor`.
@ -114,18 +122,20 @@ def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None):
values in the returned `RaggedTensorValue`. Defaults to `()` if values in the returned `RaggedTensorValue`. Defaults to `()` if
`ragged_rank` is not specified. If `ragged_rank` is specified, then a `ragged_rank` is not specified. If `ragged_rank` is specified, then a
default is chosen based on the contents of `pylist`. default is chosen based on the contents of `pylist`.
row_splits_dtype: data type for the constructed `RaggedTensorValue`'s
row_splits. One of `numpy.int32` or `numpy.int64`.
Returns: Returns:
A `RaggedTensorValue` or `numpy.array` with rank `K` and the specified A `tf.RaggedTensorValue` or `numpy.array` with rank `K` and the specified
`ragged_rank`, containing the values from `pylist`. `ragged_rank`, containing the values from `pylist`.
Raises: Raises:
ValueError: If the scalar values in `pylist` have inconsistent nesting ValueError: If the scalar values in `pylist` have inconsistent nesting
depth; or if ragged_rank or inner_shape are incompatible with `pylist`. depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
""" """
row_splits_dtype = dtypes.as_dtype(row_splits_dtype).as_numpy_dtype
def _ragged_factory(values, row_splits): def _ragged_factory(values, row_splits):
row_splits = np.array(row_splits, dtype=np.int64) row_splits = np.array(row_splits, dtype=row_splits_dtype)
return ragged_tensor_value.RaggedTensorValue(values, row_splits) return ragged_tensor_value.RaggedTensorValue(values, row_splits)
def _inner_factory(pylist, dtype, shape, name=None): # pylint: disable=unused-argument def _inner_factory(pylist, dtype, shape, name=None): # pylint: disable=unused-argument
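Both factories now thread the requested dtype through to the row_splits they build; for example (illustrative):

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
print(rt.row_splits.dtype)   # int32
rtv = tf.ragged.constant_value([[1, 2], [3]], row_splits_dtype="int32")  # v1 namespace
print(rtv.row_splits.dtype)  # numpy int32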

View File

@ -18,7 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -72,6 +75,17 @@ def map_flat_values(op, *args, **kwargs):
if not nested_splits_lists: if not nested_splits_lists:
return op(*args, **kwargs) return op(*args, **kwargs)
split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
if len(split_dtypes) > 1:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
"use RaggedTensor.with_row_splits_dtype() to convert "
"them to compatible dtypes.")
nested_splits_lists = [
[math_ops.cast(s, dtypes.int64) for s in nested_splits] # pylint: disable=g-complex-comprehension
for nested_splits in nested_splits_lists]
with ops.control_dependencies( with ops.control_dependencies(
ragged_util.assert_splits_match(nested_splits_lists)): ragged_util.assert_splits_match(nested_splits_lists)):
# Delegate to op, and then compose the result from the transformed values # Delegate to op, and then compose the result from the transformed values
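So with matching partition dtypes, `map_flat_values` proceeds as before; with mixed dtypes it raises unless the config flag is flipped. A sketch:

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

x = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
y = tf.ragged.constant([[10, 20], [30]], row_splits_dtype=tf.int32)
z = tf.ragged.map_flat_values(tf.add, x, y)  # OK: both partitions are int32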

View File

@ -96,6 +96,7 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
params, name='params') params, name='params')
indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
indices, name='indices') indices, name='indices')
params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
if ragged_tensor.is_ragged(indices): if ragged_tensor.is_ragged(indices):
return indices.with_values(gather(params, indices.values)) return indices.with_values(gather(params, indices.values))
@ -177,6 +178,7 @@ def gather_nd(params, indices, batch_dims=0, name=None):
params, name='params') params, name='params')
indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
indices, name='indices') indices, name='indices')
params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
indices_shape = indices.shape indices_shape = indices.shape
indices_ndims = indices_shape.ndims indices_ndims = indices_shape.ndims
if indices_ndims is None: if indices_ndims is None:
@ -200,7 +202,8 @@ def gather_nd(params, indices, batch_dims=0, name=None):
indices_is_dense = not ragged_tensor.is_ragged(indices) indices_is_dense = not ragged_tensor.is_ragged(indices)
if indices_is_dense: if indices_is_dense:
indices = ragged_conversion_ops.from_tensor( indices = ragged_conversion_ops.from_tensor(
indices, ragged_rank=indices_ndims - 2) indices, ragged_rank=indices_ndims - 2,
row_splits_dtype=params.row_splits.dtype)
result = indices.with_flat_values(gather_nd(params, indices.flat_values)) result = indices.with_flat_values(gather_nd(params, indices.flat_values))
if (indices_is_dense and ragged_tensor.is_ragged(result) and if (indices_is_dense and ragged_tensor.is_ragged(result) and
result.ragged_rank == indices_ndims - 2): result.ragged_rank == indices_ndims - 2):
@ -235,7 +238,7 @@ def gather_nd(params, indices, batch_dims=0, name=None):
# index tuples point to the correct values in the flattened params; and # index tuples point to the correct values in the flattened params; and
# then use ragged.gather on the flattened index tuples & params. # then use ragged.gather on the flattened index tuples & params.
else: else:
indices = math_ops.cast(indices, dtypes.int64) indices = math_ops.cast(indices, params.row_splits.dtype)
# Flatten the outermost 2 dimensions of the index tuples & params. # Flatten the outermost 2 dimensions of the index tuples & params.
flattened_index_tuples = array_ops.gather(params.row_splits, flattened_index_tuples = array_ops.gather(params.row_splits,
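With `match_row_splits_dtypes` in place, gathering from int32-partitioned params keeps everything in int32; a sketch (assuming `tf.gather` dispatches to this ragged implementation):

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

params = tf.ragged.constant([['a', 'b'], ['c']], row_splits_dtype=tf.int32)
result = tf.gather(params, tf.constant([1, 0, 1]))
print(result.row_splits.dtype)  # int32, matched to params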

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -136,7 +135,8 @@ def _ragged_getitem(rt_input, key_list):
# that puts all values in a single row. # that puts all values in a single row.
if row_key is array_ops.newaxis: if row_key is array_ops.newaxis:
inner_rt = _ragged_getitem(rt_input, inner_keys) inner_rt = _ragged_getitem(rt_input, inner_keys)
nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0] nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_row_splits( return ragged_tensor.RaggedTensor.from_row_splits(
inner_rt, array_ops.stack([0, nsplits - 1])) inner_rt, array_ops.stack([0, nsplits - 1]))
@ -192,7 +192,7 @@ def _slice_ragged_row_dimension(rt_input, row_key):
# Use row_key to slice the starts & limits. # Use row_key to slice the starts & limits.
new_starts = rt_input.row_splits[:-1][row_key] new_starts = rt_input.row_splits[:-1][row_key]
new_limits = rt_input.row_splits[1:][row_key] new_limits = rt_input.row_splits[1:][row_key]
zero_pad = array_ops.zeros([1], dtypes.int64) zero_pad = array_ops.zeros([1], rt_input.row_splits.dtype)
# If there's no slice step, then we can just select a single continuous # If there's no slice step, then we can just select a single continuous
# span of `ragged.values(rt_input)`. # span of `ragged.values(rt_input)`.
@ -245,7 +245,8 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
# RaggedTensor that puts each value in its own row. # RaggedTensor that puts each value in its own row.
if column_key is array_ops.newaxis: if column_key is array_ops.newaxis:
inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:]) inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0] nsplits = array_ops.shape(inner_rt.row_splits,
out_type=inner_rt.row_splits.dtype)[0]
return ragged_tensor.RaggedTensor.from_row_splits(inner_rt, return ragged_tensor.RaggedTensor.from_row_splits(inner_rt,
math_ops.range(nsplits)) math_ops.range(nsplits))
@ -359,10 +360,11 @@ def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
step = 1 step = 1
step = ops.convert_to_tensor(step, name="step") step = ops.convert_to_tensor(step, name="step")
if step.dtype.is_integer: if step.dtype.is_integer:
step = math_ops.cast(step, dtypes.int64) step = math_ops.cast(step, starts.dtype)
else: else:
raise TypeError("slice strides must be integers or None") raise TypeError("slice strides must be integers or None")
value_indices = ragged_math_ops.range(starts, limits, step) value_indices = ragged_math_ops.range(starts, limits, step,
row_splits_dtype=starts.dtype)
# Use `ragged_gather` or `array_ops.gather` to collect the values. # Use `ragged_gather` or `array_ops.gather` to collect the values.
if isinstance(values, ragged_tensor.RaggedTensor): if isinstance(values, ragged_tensor.RaggedTensor):
@ -384,11 +386,11 @@ def _add_offset_to_ranges(offset, starts, limits):
Args: Args:
offset: The offset to add. None, or an int, or a scalar Tensor. offset: The offset to add. None, or an int, or a scalar Tensor.
starts: 1-D int64 tensor containing start indices. starts: 1-D integer tensor containing start indices.
limits: 1-D int64 tensor containing limit indices. limits: 1-D integer tensor containing limit indices.
Returns: Returns:
A 1-D int64 tensor. A 1-D integer tensor.
""" """
def map_positive_offset(offset): def map_positive_offset(offset):
@ -398,7 +400,7 @@ def _add_offset_to_ranges(offset, starts, limits):
return math_ops.maximum(limits + offset, starts) return math_ops.maximum(limits + offset, starts)
if isinstance(offset, ops.Tensor): if isinstance(offset, ops.Tensor):
offset = math_ops.cast(offset, dtypes.int64) offset = math_ops.cast(offset, starts.dtype)
return control_flow_ops.cond(offset >= 0, return control_flow_ops.cond(offset >= 0,
lambda: map_positive_offset(offset), lambda: map_positive_offset(offset),
lambda: map_negative_offset(offset)) lambda: map_negative_offset(offset))
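The net effect for `__getitem__` is that slicing no longer forces an int64 partition; e.g. (illustrative):

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]], row_splits_dtype=tf.int32)
print(rt[1:].row_splits.dtype)     # int32
print(rt[:, :2].row_splits.dtype)  # int32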

View File

@ -222,7 +222,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
def testZip(self): def testZip(self):
x = ragged_factory_ops.constant( x = ragged_factory_ops.constant(
[[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64) [[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64)
y = array_ops.expand_dims(mo.range(x.nrows(), dtype=dtypes.int64), axis=1) y = array_ops.expand_dims(mo.range(x.nrows(out_type=dtypes.int64)), axis=1)
def _zip(foo): def _zip(foo):
y_val, x_val = foo y_val, x_val = foo
@ -273,7 +273,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]]) elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0]) fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0])
with self.assertRaisesWithLiteralMatch( with self.assertRaisesWithLiteralMatch(
ValueError, r'The declared ragged rank (10) mismatches the result (1)'): ValueError, r'The declared ragged rank (10) mismatches the result (2)'):
_ = ragged_map_ops.map_fn( _ = ragged_map_ops.map_fn(
fn, fn,
elems, elems,

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -196,6 +197,7 @@ def map_fn(fn,
return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0] return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
elems_flat = input_flatten(elems) elems_flat = input_flatten(elems)
elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)
with ops.name_scope(name, "map", elems_flat): with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are # TODO(akshayka): Remove the in_graph_mode check once caching devices are
@ -408,8 +410,9 @@ def _maybe_decompose_dtype(d):
result = _RaggedTensorComponents( result = _RaggedTensorComponents(
flat_values=d.dtype, flat_values=d.dtype,
nested_row_lengths=tuple(dtypes.int64 for i in range(d.ragged_rank - 1)), nested_row_lengths=tuple(
outer_row_length=dtypes.int64, d.row_splits_dtype for i in range(d.ragged_rank - 1)),
outer_row_length=d.row_splits_dtype,
) )
return result return result
@ -418,31 +421,42 @@ def _convert_declared(fn_output_flat, output_declared):
"""Convert outputs which are `Tensor`s into `_RaggedTensorComponents`.""" """Convert outputs which are `Tensor`s into `_RaggedTensorComponents`."""
for current, declared in zip(fn_output_flat, output_declared): for current, declared in zip(fn_output_flat, output_declared):
if isinstance(declared, ragged_tensor.RaggedTensorType): if isinstance(declared, ragged_tensor.RaggedTensorType):
if isinstance(current, ragged_tensor.RaggedTensor): yield _convert_declared_ragged(current, declared)
# Check that the ragged ranks match up.
# + 1 to account for the rank of the outermost dimension.
if declared.ragged_rank != current.ragged_rank + 1:
raise ValueError(
"The declared ragged rank (%d) mismatches the result (%d)" %
(declared.ragged_rank, current.ragged_rank))
yield current
else:
# When the output is a Tensor, but the caller has declared that we are
# expecting a RaggedTensor output.
if declared.ragged_rank != 1:
raise ValueError(
"The declared ragged rank (%d) mismatches the result (1)" %
declared.ragged_rank)
if isinstance(current, ragged_tensor.RaggedTensor):
nrows = current.nrows()
else:
nrows = array_ops.shape(current, out_type=dtypes.int64)[0]
row_length = array_ops.expand_dims(nrows, axis=0)
rt = _RaggedTensorComponents(
flat_values=current,
nested_row_lengths=(),
outer_row_length=row_length)
yield rt
else: else:
yield current yield current
def _convert_declared_ragged(current, declared):
"""Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
# Check that the ragged ranks match up.
# + 1 to account for the rank of the outermost dimension.
current_ragged_rank = getattr(current, "ragged_rank", 0)
if declared.ragged_rank != current_ragged_rank + 1:
raise ValueError(
"The declared ragged rank (%d) mismatches the result (%d)" %
(declared.ragged_rank, current_ragged_rank + 1))
# Check that dtypes match up.
if declared.dtype != current.dtype:
raise ValueError(
"The declared dtype (%s) mismatches the result (%s)" %
(declared.dtype, current.dtype))
if (isinstance(current, ragged_tensor.RaggedTensor) and
declared.row_splits_dtype != current.row_splits.dtype):
if not ragged_config.auto_cast_partition_dtype():
raise ValueError(
"The declared row_splits dtype (%s) mismatches the result (%s)."
" Use RaggedTensor.with_row_splits_dtype to convert it."
% (declared.row_splits_dtype, current.row_splits.dtype))
current = current.with_row_splits_dtype(declared.row_splits_dtype)
if isinstance(current, ragged_tensor.RaggedTensor):
return current
else:
nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
row_length = array_ops.expand_dims(nrows, axis=0)
return _RaggedTensorComponents(
flat_values=current,
nested_row_lengths=(),
outer_row_length=row_length)
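A sketch of declaring the output type with a row-splits dtype (assuming `RaggedTensorType` accepts a `row_splits_dtype` argument, as `_maybe_decompose_dtype` above suggests):

# Illustrative sketch -- not part of this diff.
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_tensor
import tensorflow as tf

elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5]],
                                    row_splits_dtype=tf.int32)
doubled = ragged_map_ops.map_fn(
    lambda row: row * 2, elems,
    dtype=ragged_tensor.RaggedTensorType(
        dtype=tf.int32, ragged_rank=1, row_splits_dtype=tf.int32))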

View File

@ -39,7 +39,8 @@ from tensorflow.python.util.tf_export import tf_export
#=============================================================================== #===============================================================================
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export('ragged.range') @tf_export('ragged.range')
def range(starts, limits=None, deltas=1, dtype=None, name=None): def range(starts, limits=None, deltas=1, dtype=None,
name=None, row_splits_dtype=dtypes.int64):
"""Returns a `RaggedTensor` containing the specified sequences of numbers. """Returns a `RaggedTensor` containing the specified sequences of numbers.
Each row of the returned `RaggedTensor` contains a single sequence: Each row of the returned `RaggedTensor` contains a single sequence:
@ -81,10 +82,13 @@ def range(starts, limits=None, deltas=1, dtype=None, name=None):
dtype: The type of the elements of the resulting tensor. If not specified, dtype: The type of the elements of the resulting tensor. If not specified,
then a value is chosen based on the other args. then a value is chosen based on the other args.
name: A name for the operation. name: A name for the operation.
row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
tensor. One of `tf.int32` or `tf.int64`.
Returns: Returns:
A `RaggedTensor` of type `dtype` with `ragged_rank=1`. A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
""" """
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if limits is None: if limits is None:
starts, limits = 0, starts starts, limits = 0, starts
@ -99,7 +103,8 @@ def range(starts, limits=None, deltas=1, dtype=None, name=None):
[starts, limits, deltas], [starts, limits, deltas],
[dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])
result = gen_ragged_math_ops.ragged_range(starts, limits, deltas, name=name) result = gen_ragged_math_ops.ragged_range(
starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values, return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
result.rt_nested_splits) result.rt_nested_splits)
@ -190,6 +195,9 @@ def _ragged_segment_aggregate(unsorted_segment_op,
data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor( segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
segment_ids, name='segment_ids') segment_ids, name='segment_ids')
data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError('segment_ids must have dtype int32 or int64.')
if ragged_tensor.is_ragged(segment_ids): if ragged_tensor.is_ragged(segment_ids):
if not ragged_tensor.is_ragged(data): if not ragged_tensor.is_ragged(data):
@ -203,22 +211,19 @@ def _ragged_segment_aggregate(unsorted_segment_op,
return _ragged_segment_aggregate(unsorted_segment_op, data.values, return _ragged_segment_aggregate(unsorted_segment_op, data.values,
segment_ids.values, num_segments, name) segment_ids.values, num_segments, name)
segment_ids = math_ops.cast(segment_ids, dtypes.int64) # Find the length of each row in data. (shape=[data_nrows])
# Find the length of each row in data. (dtype=int64, shape=[data_nrows])
data_row_lengths = data.row_splits[1:] - data.row_splits[:-1] data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]
# Find the length that each output row will have. The length of the row # Find the length that each output row will have. The length of the row
# corresponding to segment `id` is `max(data_row_lengths[i])` where # corresponding to segment `id` is `max(data_row_lengths[i])` where
# `segment_ids[i]=id`. (dtype=int64, shape=[output_nrows]) # `segment_ids[i]=id`. (shape=[output_nrows])
output_row_lengths = math_ops.maximum( output_row_lengths = math_ops.maximum(
math_ops.unsorted_segment_max(data_row_lengths, segment_ids, math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
num_segments), 0) num_segments), 0)
assert output_row_lengths.dtype == dtypes.int64
# Build the splits tensor for the output RaggedTensor. # Build the splits tensor for the output RaggedTensor.
output_splits = array_ops.concat([ output_splits = array_ops.concat([
array_ops.zeros([1], dtypes.int64), array_ops.zeros([1], output_row_lengths.dtype),
math_ops.cumsum(output_row_lengths) math_ops.cumsum(output_row_lengths)
], ],
axis=0) axis=0)
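For example, `ragged.range` can now emit the splits in int32 directly via the new `Tsplits` attr (illustrative):

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

rt = tf.ragged.range([3, 0, 2], row_splits_dtype=tf.int32)
print(rt.row_splits.dtype)  # int32
# rt is <tf.RaggedTensor [[0, 1, 2], [], [0, 1]]>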

View File

@ -43,8 +43,8 @@ class RaggedSplitsToSegmentIdsOpTest(ragged_test_util.RaggedTensorTestCase):
self.assertRaisesRegexp(ValueError, r'Invalid row_splits: \[\]', self.assertRaisesRegexp(ValueError, r'Invalid row_splits: \[\]',
segment_id_ops.row_splits_to_segment_ids, []) segment_id_ops.row_splits_to_segment_ids, [])
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, r'Tensor conversion requested dtype int64 for ' ValueError, r'splits must have dtype int32 or int64',
'Tensor with dtype float32', segment_id_ops.row_splits_to_segment_ids, segment_id_ops.row_splits_to_segment_ids,
constant_op.constant([0.5])) constant_op.constant([0.5]))
self.assertRaisesRegexp(ValueError, r'Shape \(\) must have rank 1', self.assertRaisesRegexp(ValueError, r'Shape \(\) must have rank 1',
segment_id_ops.row_splits_to_segment_ids, 0) segment_id_ops.row_splits_to_segment_ids, 0)

View File

@ -75,7 +75,7 @@ def squeeze(input, axis=None, name=None): # pylint: disable=redefined-builtin
# Make sure the specified ragged dimensions are squeezable. # Make sure the specified ragged dimensions are squeezable.
assertion_list = [] assertion_list = []
scalar_tensor_one = constant_op.constant(1, dtype=dtypes.int64) scalar_tensor_one = constant_op.constant(1, dtype=input.row_splits.dtype)
for i, r in enumerate(input.nested_row_lengths()): for i, r in enumerate(input.nested_row_lengths()):
if i + 1 in ragged_dims: if i + 1 in ragged_dims:
assertion_list.append( assertion_list.append(

View File

@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_array_ops from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -146,8 +145,8 @@ def unicode_encode(input,
if input_tensor.shape.ndims == 2: if input_tensor.shape.ndims == 2:
# The input tensor is of the correct 2-D shape, it's just not ragged. # The input tensor is of the correct 2-D shape, it's just not ragged.
return unicode_encode( return unicode_encode(
ragged_conversion_ops.from_tensor(input_tensor), output_encoding, ragged_tensor.RaggedTensor.from_tensor(input_tensor),
errors, replacement_char) output_encoding, errors, replacement_char)
elif input_tensor.shape.ndims > 2: elif input_tensor.shape.ndims > 2:
# We need to initially flatten the input tensor to 2-D, and then can # We need to initially flatten the input tensor to 2-D, and then can
# reshape the output of our processed flattened tensor. # reshape the output of our processed flattened tensor.
@ -166,7 +165,7 @@ def unicode_encode(input,
ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits( ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits(
input_tensor, input_tensor,
array_ops.stack( array_ops.stack(
[0, array_ops.shape(input_tensor, out_type=dtypes.int64)[0]])) [0, array_ops.shape(input_tensor, out_type=dtypes.int32)[0]]))
output_tensor = unicode_encode(ragged_input_tensor, output_encoding, output_tensor = unicode_encode(ragged_input_tensor, output_encoding,
errors, replacement_char) errors, replacement_char)
return array_ops.reshape(output_tensor, []) return array_ops.reshape(output_tensor, [])
@ -404,11 +403,11 @@ def _unicode_decode(input, input_encoding, errors, replacement_char,
if input_ndims > 1: if input_ndims > 1:
# Convert to a ragged tensor with ragged_rank = input_ndims - 1. # Convert to a ragged tensor with ragged_rank = input_ndims - 1.
if not ragged_tensor.is_ragged(input): if not ragged_tensor.is_ragged(input):
input = ragged_conversion_ops.from_tensor( input = ragged_tensor.RaggedTensor.from_tensor(
input, ragged_rank=input_ndims - 1) input, ragged_rank=input_ndims - 1)
elif input.ragged_rank < input_ndims - 1: elif input.ragged_rank < input_ndims - 1:
input = input.with_flat_values( input = input.with_flat_values(
ragged_conversion_ops.from_tensor( ragged_tensor.RaggedTensor.from_tensor(
input.flat_values, input.flat_values,
ragged_rank=input_ndims - input.ragged_rank + 1)) ragged_rank=input_ndims - input.ragged_rank + 1))

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_ragged_conversion_ops from tensorflow.python.ops import gen_ragged_conversion_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops from tensorflow.python.ops.ragged import segment_id_ops
@ -115,8 +116,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
`[nvals]`, corresponding one-to-one with `values`, which specifies `[nvals]`, corresponding one-to-one with `values`, which specifies
each value's row index. In particular, the row `rt[row]` consists of the each value's row index. In particular, the row `rt[row]` consists of the
values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an
int64 scalar that specifies the number of rows in the `RaggedTensor`. integer scalar that specifies the number of rows in the
(`nrows` is used to indicate trailing empty rows.) `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.)
* `row_starts`: a vector with shape `[nrows]`, which specifies the start * `row_starts`: a vector with shape `[nrows]`, which specifies the start
offset of each row. Equivalent to `row_splits[:-1]`. offset of each row. Equivalent to `row_splits[:-1]`.
@ -220,10 +221,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`. values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
row_splits: A 1-D int64 tensor with shape `[nrows+1]`. row_splits: A 1-D integer tensor with shape `[nrows+1]`.
cached_row_lengths: A 1-D int64 tensor with shape `[nrows]` cached_row_lengths: A 1-D integer tensor with shape `[nrows]`
cached_value_rowids: A 1-D int64 tensor with shape `[nvals]`. cached_value_rowids: A 1-D integer tensor with shape `[nvals]`.
cached_nrows: A 1-D int64 scalar tensor. cached_nrows: A 1-D integer scalar tensor.
internal: True if the constructor is being called by one of the factory internal: True if the constructor is being called by one of the factory
methods. If false, an exception will be raised. methods. If false, an exception will be raised.
@ -244,9 +245,13 @@ class RaggedTensor(composite_tensor.CompositeTensor):
raise TypeError("values must be a Tensor or RaggedTensor.") raise TypeError("values must be a Tensor or RaggedTensor.")
if not isinstance(row_splits, ops.Tensor): if not isinstance(row_splits, ops.Tensor):
raise TypeError("Row-partitioning argument must be a Tensor.") raise TypeError("Row-partitioning argument must be a Tensor.")
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("Row-partitioning argument must be int32 or int64")
values.shape.with_rank_at_least(1) values.shape.with_rank_at_least(1)
row_splits.shape.assert_has_rank(1) row_splits.shape.assert_has_rank(1)
row_splits.set_shape([None]) row_splits.set_shape([None])
if isinstance(values, RaggedTensor):
assert row_splits.dtype == values.row_splits.dtype
self._values = values self._values = values
self._row_splits = row_splits self._row_splits = row_splits
@ -255,8 +260,11 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# round-trip conversions when a RaggedTensor is constructed from # round-trip conversions when a RaggedTensor is constructed from
# lengths or rowids, and we later want those lengths/rowids back. # lengths or rowids, and we later want those lengths/rowids back.
for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]: for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]:
if tensor is not None and not isinstance(tensor, ops.Tensor): if tensor is not None:
raise TypeError("Cached value must be a Tensor or None.") if not isinstance(tensor, ops.Tensor):
raise TypeError("Cached value must be a Tensor or None.")
elif tensor.dtype not in (dtypes.int32, dtypes.int64):
raise TypeError("Cached value must be int32 or int64.")
self._cached_row_lengths = cached_row_lengths self._cached_row_lengths = cached_row_lengths
self._cached_value_rowids = cached_value_rowids self._cached_value_rowids = cached_value_rowids
self._cached_nrows = cached_nrows self._cached_nrows = cached_nrows
@ -276,15 +284,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
for row in range(nrows)] for row in range(nrows)]
``` ```
Warning: currently, this needs to cast value_rowids to int64 before
converting, since `tf.math.bincount` only supports `int32`.
Args: Args:
values: A potentially ragged tensor with shape `[nvals, ...]`. values: A potentially ragged tensor with shape `[nvals, ...]`.
value_rowids: A 1-D int64 tensor with shape `[nvals]`, which corresponds value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
one-to-one with `values`, and specifies each value's row index. Must be one-to-one with `values`, and specifies each value's row index. Must be
nonnegative, and must be sorted in ascending order. nonnegative, and must be sorted in ascending order.
nrows: An int64 scalar specifying the number of rows. This should be nrows: An integer scalar specifying the number of rows. This should be
specified if the `RaggedTensor` may contain empty trailing rows. Must specified if the `RaggedTensor` may contain empty trailing rows. Must
be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty). be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty). Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty).
@ -308,9 +313,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
""" """
with ops.name_scope(name, "RaggedFromValueRowIds", with ops.name_scope(name, "RaggedFromValueRowIds",
[values, value_rowids, nrows]): [values, value_rowids, nrows]):
values = convert_to_tensor_or_ragged_tensor(values, name="values") values, value_rowids = cls._convert_values_and_row_partition(
value_rowids = ops.convert_to_tensor( values, value_rowids, "value_rowids")
value_rowids, dtypes.int64, name="value_rowids")
if nrows is None: if nrows is None:
const_rowids = tensor_util.constant_value(value_rowids) const_rowids = tensor_util.constant_value(value_rowids)
if const_rowids is None: if const_rowids is None:
@ -318,9 +322,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
const_nrows = None const_nrows = None
else: else:
const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0 const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
nrows = ops.convert_to_tensor(const_nrows, dtypes.int64, name="nrows") nrows = ops.convert_to_tensor(const_nrows, value_rowids.dtype,
name="nrows")
else: else:
nrows = ops.convert_to_tensor(nrows, dtypes.int64, "nrows") nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
const_nrows = tensor_util.constant_value(nrows) const_nrows = tensor_util.constant_value(nrows)
if const_nrows is not None: if const_nrows is not None:
if const_nrows < 0: if const_nrows < 0:
@ -340,14 +345,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Note: we don't use segment_ids_to_row_splits() here because we want # Note: we don't use segment_ids_to_row_splits() here because we want
# to save the intermediate value `row_lengths`, so we can cache it. # to save the intermediate value `row_lengths`, so we can cache it.
# TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
# cast (Remove the warning in the docstring when we do.) # cast.
value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32) value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
nrows_int32 = math_ops.cast(nrows, dtypes.int32) nrows_int32 = math_ops.cast(nrows, dtypes.int32)
row_lengths = math_ops.bincount( row_lengths = math_ops.bincount(
value_rowids_int32, value_rowids_int32,
minlength=nrows_int32, minlength=nrows_int32,
maxlength=nrows_int32, maxlength=nrows_int32,
dtype=dtypes.int64) dtype=value_rowids.dtype)
row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0) row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
if const_nrows is not None: if const_nrows is not None:
row_lengths.set_shape([const_nrows]) row_lengths.set_shape([const_nrows])
@ -374,9 +379,9 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
values: A potentially ragged tensor with shape `[nvals, ...]`. values: A potentially ragged tensor with shape `[nvals, ...]`.
row_splits: A 1-D int64 tensor with shape `[nrows+1]`. Must not be empty, row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
and must be sorted in ascending order. `row_splits[0]` must be zero and empty, and must be sorted in ascending order. `row_splits[0]` must be
`row_splits[-1]` must be `nvals`. zero and `row_splits[-1]` must be `nvals`.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
Returns: Returns:
@ -397,8 +402,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if isinstance(row_splits, (list, tuple)) and not row_splits: if isinstance(row_splits, (list, tuple)) and not row_splits:
raise ValueError("row_splits tensor may not be empty.") raise ValueError("row_splits tensor may not be empty.")
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
values = convert_to_tensor_or_ragged_tensor(values, name="values") values, row_splits = cls._convert_values_and_row_partition(
row_splits = ops.convert_to_tensor(row_splits, dtypes.int64, "row_splits") values, row_splits, "row_splits")
row_splits.shape.assert_has_rank(1) row_splits.shape.assert_has_rank(1)
return cls(values=values, row_splits=row_splits, internal=True) return cls(values=values, row_splits=row_splits, internal=True)
@ -415,7 +420,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
values: A potentially ragged tensor with shape `[nvals, ...]`. values: A potentially ragged tensor with shape `[nvals, ...]`.
row_lengths: A 1-D int64 tensor with shape `[nrows]`. Must be row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be
nonnegative. `sum(row_lengths)` must be `nvals`. nonnegative. `sum(row_lengths)` must be `nvals`.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
@ -432,9 +437,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
``` ```
""" """
with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]): with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
values = convert_to_tensor_or_ragged_tensor(values, name="values") values, row_lengths = cls._convert_values_and_row_partition(
row_lengths = ops.convert_to_tensor(row_lengths, dtypes.int64, values, row_lengths, "row_lengths")
"row_lengths")
row_lengths.shape.assert_has_rank(1) row_lengths.shape.assert_has_rank(1)
row_limits = math_ops.cumsum(row_lengths) row_limits = math_ops.cumsum(row_lengths)
row_splits = array_ops.concat([[0], row_limits], axis=0) row_splits = array_ops.concat([[0], row_limits], axis=0)
@ -452,9 +456,9 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
values: A potentially ragged tensor with shape `[nvals, ...]`. values: A potentially ragged tensor with shape `[nvals, ...]`.
row_starts: A 1-D int64 tensor with shape `[nrows]`. Must be nonnegative row_starts: A 1-D integer tensor with shape `[nrows]`. Must be
and sorted in ascending order. If `nrows>0`, then `row_starts[0]` must nonnegative and sorted in ascending order. If `nrows>0`, then
be zero. `row_starts[0]` must be zero.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
Returns: Returns:
@ -470,10 +474,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
``` ```
""" """
with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]): with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
values = convert_to_tensor_or_ragged_tensor(values, name="values") values, row_starts = cls._convert_values_and_row_partition(
row_starts = ops.convert_to_tensor(row_starts, dtypes.int64, "row_starts") values, row_starts, "row_starts")
row_starts.shape.assert_has_rank(1) row_starts.shape.assert_has_rank(1)
nvals = array_ops.shape(values, out_type=dtypes.int64)[:1] nvals = array_ops.shape(values, out_type=row_starts.dtype)[:1]
row_splits = array_ops.concat([row_starts, nvals], axis=0) row_splits = array_ops.concat([row_starts, nvals], axis=0)
return cls(values=values, row_splits=row_splits, internal=True) return cls(values=values, row_splits=row_splits, internal=True)
@ -485,7 +489,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
values: A potentially ragged tensor with shape `[nvals, ...]`. values: A potentially ragged tensor with shape `[nvals, ...]`.
row_limits: A 1-D int64 tensor with shape `[nrows]`. Must be sorted in row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in
ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`. ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
@ -502,10 +506,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
``` ```
""" """
with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]): with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
values = convert_to_tensor_or_ragged_tensor(values, name="values") values, row_limits = cls._convert_values_and_row_partition(
row_limits = ops.convert_to_tensor(row_limits, dtypes.int64, "row_limits") values, row_limits, "row_limits")
row_limits.shape.assert_has_rank(1) row_limits.shape.assert_has_rank(1)
zero = array_ops.zeros([1], dtypes.int64) zero = array_ops.zeros([1], row_limits.dtype)
row_splits = array_ops.concat([zero, row_limits], axis=0) row_splits = array_ops.concat([zero, row_limits], axis=0)
return cls(values=values, row_splits=row_splits, internal=True) return cls(values=values, row_splits=row_splits, internal=True)
@ -527,9 +531,9 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
flat_values: A potentially ragged tensor. flat_values: A potentially ragged tensor.
nested_value_rowids: A list of 1-D int64 tensors. The `i`th tensor is nested_value_rowids: A list of 1-D integer tensors. The `i`th tensor is
used as the `value_rowids` for the `i`th ragged dimension. used as the `value_rowids` for the `i`th ragged dimension.
nested_nrows: A list of int64 scalars. The `i`th scalar is used as the nested_nrows: A list of integer scalars. The `i`th scalar is used as the
`nrows` for the `i`th ragged dimension. `nrows` for the `i`th ragged dimension.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
@ -573,8 +577,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
flat_values: A potentially ragged tensor. flat_values: A potentially ragged tensor.
nested_row_splits: A list of 1-D int64 tensors. The `i`th tensor is used nested_row_splits: A list of 1-D integer tensors. The `i`th tensor is
as the `row_splits` for the `i`th ragged dimension. used as the `row_splits` for the `i`th ragged dimension.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
Returns: Returns:
@ -603,8 +607,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
flat_values: A potentially ragged tensor. flat_values: A potentially ragged tensor.
nested_row_lengths: A list of 1-D int64 tensors. The `i`th tensor is used nested_row_lengths: A list of 1-D integer tensors. The `i`th tensor is
as the `row_lengths` for the `i`th ragged dimension. used as the `row_lengths` for the `i`th ragged dimension.
name: A name prefix for the RaggedTensor (optional). name: A name prefix for the RaggedTensor (optional).
Returns: Returns:
@ -619,6 +623,50 @@ class RaggedTensor(composite_tensor.CompositeTensor):
result = cls.from_row_lengths(result, lengths) result = cls.from_row_lengths(result, lengths)
return result return result
@classmethod
def _convert_values_and_row_partition(cls, values, partition, name):
"""Converts `values` and `partition` to Tensors.
If `values` is a `RaggedTensor`, then converts `values` and `partition`
to have compatible row-partitioning dtypes. In particular, if any of the
row partitioning tensors are `int64`, then all of the other row
partitioning tensors will be cast to `int64` (if auto_cast_partition_dtype()
is true) or an error will be raised (if auto_cast_partition_dtype() is
false).
Args:
values: The `values` for the `RaggedTensor` being constructed.
partition: A row-partitioning tensor for the `RaggedTensor` being
constructed. I.e., one of: row_splits, row_lengths, row_starts,
row_limits, value_rowids.
name: The name of the row-partitioning tensor.
Returns:
A tuple (values, partition).
"""
if isinstance(values, RaggedTensor):
if isinstance(partition, ops.Tensor):
if partition.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("%s must have dtype int32 or int64" % name)
if values.row_splits.dtype != partition.dtype:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("dtype mismatch: %s (%s) vs values.row_splits (%s)"
% (name, partition.dtype, values.row_splits.dtype))
partition = math_ops.cast(partition, dtypes.int64)
values = values.with_row_splits_dtype(dtypes.int64)
else:
partition = ops.convert_to_tensor(partition, values.row_splits.dtype,
name=name)
else:
values = ops.convert_to_tensor(values, name="values")
partition = ops.convert_to_tensor(
partition, preferred_dtype=dtypes.int64,
name=name)
if partition.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("%s must have dtype int32 or int64" % name)
return (values, partition)
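In practice this means the factory methods infer the partition dtype from ragged `values` when the caller passes a plain Python list; a sketch:

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

inner = tf.RaggedTensor.from_row_splits(
    [1, 2, 3, 4], tf.constant([0, 2, 4], tf.int32))
outer = tf.RaggedTensor.from_row_splits(inner, [0, 1, 2])
print(outer.row_splits.dtype)  # int32, matched to inner.row_splits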
#============================================================================= #=============================================================================
# Accessors # Accessors
#============================================================================= #=============================================================================
@ -696,7 +744,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
Returns: Returns:
A 1-D `int64` `Tensor` with shape `[self.nrows+1]`. A 1-D integer `Tensor` with shape `[self.nrows+1]`.
The returned tensor is non-empty, and is sorted in ascending order. The returned tensor is non-empty, and is sorted in ascending order.
`self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
`self.values.shape[0]`. `self.values.shape[0]`.
@ -752,7 +800,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
* `value_splits = rt.values.nested_row_splits` otherwise. * `value_splits = rt.values.nested_row_splits` otherwise.
Returns: Returns:
A `tuple` of 1-D `int64` `Tensor`s. A `tuple` of 1-D integer `Tensor`s.
#### Example: #### Example:
@ -785,7 +833,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
Returns: Returns:
A 1-D `int64` `Tensor` with shape `self.values.shape[:1]`. A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
The returned tensor is nonnegative, and is sorted in ascending order. The returned tensor is nonnegative, and is sorted in ascending order.
#### Example: #### Example:
@ -803,13 +851,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
with ops.name_scope(name, "RaggedValueRowIds", [self]): with ops.name_scope(name, "RaggedValueRowIds", [self]):
return segment_id_ops.row_splits_to_segment_ids(self.row_splits) return segment_id_ops.row_splits_to_segment_ids(self.row_splits)
def nrows(self, out_type=dtypes.int64, name=None): def nrows(self, out_type=None, name=None):
"""Returns the number of rows in this ragged tensor. """Returns the number of rows in this ragged tensor.
I.e., the size of the outermost dimension of the tensor. I.e., the size of the outermost dimension of the tensor.
Args: Args:
out_type: `dtype` for the returned tensor. out_type: `dtype` for the returned tensor. Defaults to
`self.row_splits.dtype`.
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
Returns: Returns:
@ -824,7 +873,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
""" """
if self._cached_nrows is not None: if self._cached_nrows is not None:
return self._cached_nrows return self._cached_nrows
if out_type is None:
out_type = self._row_splits.dtype
else:
out_type = dtypes.as_dtype(out_type)
with ops.name_scope(name, "RaggedNRows", [self]): with ops.name_scope(name, "RaggedNRows", [self]):
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1 return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
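So `nrows()` now follows the partition dtype unless overridden (illustrative):

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
print(rt.nrows().dtype)                   # int32 (defaults to row_splits dtype)
print(rt.nrows(out_type=tf.int64).dtype)  # int64 when requested explicitly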
@ -838,7 +890,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
Returns: Returns:
A 1-D Tensor of int64 with shape `[nrows]`. A 1-D integer Tensor with shape `[nrows]`.
The returned tensor is nonnegative, and is sorted in ascending order. The returned tensor is nonnegative, and is sorted in ascending order.
#### Example: #### Example:
@ -863,7 +915,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
Returns: Returns:
A 1-D Tensor of int64 with shape `[nrows]`. A 1-D integer Tensor with shape `[nrows]`.
The returned tensor is nonnegative, and is sorted in ascending order. The returned tensor is nonnegative, and is sorted in ascending order.
#### Example: #### Example:
@ -890,7 +942,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
Returns: Returns:
A potentially ragged Tensor of int64 with shape `self.shape[:axis]`. A potentially ragged integer Tensor with shape `self.shape[:axis]`.
Raises: Raises:
ValueError: If `axis` is out of bounds. ValueError: If `axis` is out of bounds.
@ -917,9 +969,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
elif isinstance(self.values, RaggedTensor): elif isinstance(self.values, RaggedTensor):
return self.with_values(self.values.row_lengths(axis - 1)) return self.with_values(self.values.row_lengths(axis - 1))
else: else:
shape = array_ops.shape(self.values, out_type=dtypes.int64) shape = array_ops.shape(self.values, out_type=self._row_splits.dtype)
return self.with_values( return self.with_values(
array_ops.ones(shape[:axis - 1], dtypes.int64) * shape[axis - 1]) array_ops.ones(shape[:axis - 1], self._row_splits.dtype) *
shape[axis - 1])
def nested_row_lengths(self, name=None): def nested_row_lengths(self, name=None):
"""Returns a tuple containing the row_lengths for all ragged dimensions. """Returns a tuple containing the row_lengths for all ragged dimensions.
@ -931,7 +984,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensors (optional). name: A name prefix for the returned tensors (optional).
Returns: Returns:
A `tuple` of 1-D `int64` `Tensors`. The length of the tuple is equal to A `tuple` of 1-D integer `Tensors`. The length of the tuple is equal to
`self.ragged_rank`. `self.ragged_rank`.
""" """
with ops.name_scope(name, "RaggedNestedRowLengths", [self]): with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
@ -942,7 +995,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
rt = rt.values rt = rt.values
return tuple(rt_nested_row_lengths) return tuple(rt_nested_row_lengths)
def bounding_shape(self, axis=None, name=None): def bounding_shape(self, axis=None, name=None, out_type=None):
"""Returns the tight bounding box shape for this `RaggedTensor`. """Returns the tight bounding box shape for this `RaggedTensor`.
Args: Args:
@ -950,13 +1003,15 @@ class RaggedTensor(composite_tensor.CompositeTensor):
bounding box for. If not specified, then the full bounding box is bounding box for. If not specified, then the full bounding box is
returned. returned.
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
out_type: `dtype` for the returned tensor. Defaults to
`self.row_splits.dtype`.
Returns: Returns:
An int64 `Tensor`. If `axis` is not specified, then `output` An integer `Tensor` (`dtype=self.row_splits.dtype`). If `axis` is not
is a vector with `output.shape=[self.shape.ndims]`. If `axis` is a specified, then `output` is a vector with
scalar, then the `output` is a scalar. If `axis` is a vector, then `output.shape=[self.shape.ndims]`. If `axis` is a scalar, then the
`output` is a vector, where `output[i]` is the bounding size for `output` is a scalar. If `axis` is a vector, then `output` is a vector,
dimension `axis[i]`. where `output[i]` is the bounding size for dimension `axis[i]`.
#### Example: #### Example:
```python ```python
@ -965,6 +1020,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
[5, 4] [5, 4]
``` ```
""" """
if out_type is None:
out_type = self._row_splits.dtype
else:
out_type = dtypes.as_dtype(out_type)
with ops.name_scope(name, "RaggedBoundingBox", [self, axis]): with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
nested_splits = self.nested_row_splits nested_splits = self.nested_row_splits
rt_flat_values = self.flat_values rt_flat_values = self.flat_values
@ -972,12 +1031,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Optimized special cases for when axis=0 or axis=1: # Optimized special cases for when axis=0 or axis=1:
if isinstance(axis, int): if isinstance(axis, int):
if axis == 0: if axis == 0:
return array_ops.shape(nested_splits[0], out_type=dtypes.int64)[0] - 1 return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1
elif axis == 1: elif axis == 1:
return math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0) return math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)
splits_shape = array_ops.shape(self.row_splits, out_type=dtypes.int64) splits_shape = array_ops.shape(self.row_splits, out_type=out_type)
flat_values_shape = array_ops.shape(rt_flat_values, out_type=dtypes.int64) flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type)
ragged_dimensions = array_ops.stack([splits_shape[0] - 1] + [ ragged_dimensions = array_ops.stack([splits_shape[0] - 1] + [
math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0) math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
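`bounding_shape` behaves the same way (illustrative):

# Illustrative sketch -- not part of this diff.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]], row_splits_dtype=tf.int32)
print(rt.bounding_shape())                         # [3, 3], dtype=int32
print(rt.bounding_shape(out_type=tf.int64).dtype)  # int64 on request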
@ -1009,6 +1068,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
""" """
new_values.shape.with_rank_at_least(1) new_values.shape.with_rank_at_least(1)
self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
if (isinstance(new_values, RaggedTensor) and
self._row_splits.dtype != new_values.row_splits.dtype):
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("self and new_values have mismatched row_splits "
"dtypes; use RaggedTensor.with_row_splits_dtype() to "
"convert them to compatible dtypes.")
new_values = new_values.with_row_splits_dtype(dtypes.int64)
return self.with_row_splits_dtype(dtypes.int64).with_values(new_values)
return RaggedTensor( return RaggedTensor(
new_values, new_values,
self._row_splits, self._row_splits,
@ -1038,6 +1105,43 @@ class RaggedTensor(composite_tensor.CompositeTensor):
else: else:
return self.with_values(self.values.with_flat_values(new_values)) return self.with_values(self.values.with_flat_values(new_values))
def with_row_splits_dtype(self, dtype):
"""Returns a copy of this RaggedTensor with the given `row_splits` dtype.
For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
nested `RaggedTensor` objects are cast to the given dtype.
Args:
dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`.
Returns:
A copy of this RaggedTensor, with the `row_splits` cast to the given
type.
"""
dtype = dtypes.as_dtype(dtype)
if dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("dtype must be int32 or int64")
if self._row_splits.dtype == dtype:
return self
row_splits = math_ops.cast(self._row_splits, dtype)
values = self._values
if isinstance(values, RaggedTensor):
values = values.with_row_splits_dtype(dtype)
cached_row_lengths = self._cached_row_lengths
if cached_row_lengths is not None:
cached_row_lengths = math_ops.cast(cached_row_lengths, dtype)
cached_value_rowids = self._cached_value_rowids
if cached_value_rowids is not None:
cached_value_rowids = math_ops.cast(cached_value_rowids, dtype)
cached_nrows = self._cached_nrows
if cached_nrows is not None:
cached_nrows = math_ops.cast(cached_nrows, dtype)
return RaggedTensor(values, row_splits, cached_row_lengths,
cached_value_rowids, cached_nrows, internal=True)
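A short usage sketch for `with_row_splits_dtype` (hypothetical values; assumes eager execution):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])   # row_splits default to int64
rt32 = rt.with_row_splits_dtype(tf.int32)
print(rt32.row_splits.dtype)             # <dtype: 'int32'>
# Nested RaggedTensors and any cached row_lengths/value_rowids/nrows
# tensors are cast to the same dtype.
print(rt32.with_row_splits_dtype(tf.int32) is rt32)  # True: no-op returns self
```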
#============================================================================= #=============================================================================
# Tensor Type Conversions # Tensor Type Conversions
#============================================================================= #=============================================================================
@ -1048,7 +1152,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
lengths=None, lengths=None,
padding=None, padding=None,
ragged_rank=1, ragged_rank=1,
name=None): name=None,
row_splits_dtype=dtypes.int64):
"""Converts a `tf.Tensor` into a `RaggedTensor`. """Converts a `tf.Tensor` into a `RaggedTensor`.
The set of absent/default values may be specified using a vector of lengths The set of absent/default values may be specified using a vector of lengths
@ -1096,6 +1201,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
ragged_rank: Integer specifying the ragged rank for the returned ragged_rank: Integer specifying the ragged rank for the returned
`RaggedTensor`. Must be greater than zero. `RaggedTensor`. Must be greater than zero.
name: A name prefix for the returned tensors (optional). name: A name prefix for the returned tensors (optional).
row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
tensor. One of `tf.int32` or `tf.int64`.
Returns: Returns:
A `RaggedTensor` with the specified `ragged_rank`. The shape of the A `RaggedTensor` with the specified `ragged_rank`. The shape of the
@ -1103,6 +1210,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Raises: Raises:
ValueError: If both `lengths` and `padding` are specified. ValueError: If both `lengths` and `padding` are specified.
""" """
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if lengths is not None and padding is not None: if lengths is not None and padding is not None:
raise ValueError("Specify lengths or padding, but not both") raise ValueError("Specify lengths or padding, but not both")
if not isinstance(ragged_rank, int): if not isinstance(ragged_rank, int):
@ -1114,7 +1222,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]): with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
tensor = ops.convert_to_tensor(tensor, name="tensor") tensor = ops.convert_to_tensor(tensor, name="tensor")
tensor.shape.with_rank_at_least(ragged_rank + 1) tensor.shape.with_rank_at_least(ragged_rank + 1)
input_shape = array_ops.shape(tensor, out_type=dtypes.int64) input_shape = array_ops.shape(tensor, out_type=row_splits_dtype)
ncols = input_shape[1] ncols = input_shape[1]
# Handle ragged_rank>1 via recursion: # Handle ragged_rank>1 via recursion:
@ -1125,12 +1233,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if ragged_rank > 1: if ragged_rank > 1:
# Flatten `tensor` to eliminate all but the last ragged dimension. # Flatten `tensor` to eliminate all but the last ragged dimension.
new_shape = array_ops.concat([ new_shape = array_ops.concat([
constant_op.constant([-1], dtypes.int64), input_shape[ragged_rank:] constant_op.constant([-1], row_splits_dtype),
input_shape[ragged_rank:]
], ],
axis=0) axis=0)
flattened = array_ops.reshape(tensor, new_shape) flattened = array_ops.reshape(tensor, new_shape)
# Recursively convert the flattened tensor. # Recursively convert the flattened tensor.
values = cls.from_tensor(flattened, lengths, padding) values = cls.from_tensor(flattened, lengths, padding,
row_splits_dtype=row_splits_dtype)
# The total number of elements in each dimension. E.g., if # The total number of elements in each dimension. E.g., if
# input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total. # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
dim_size = math_ops.cumprod(input_shape) dim_size = math_ops.cumprod(input_shape)
@ -1167,12 +1277,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
has_default.set_shape(tensor_shape.TensorShape([None, None])) has_default.set_shape(tensor_shape.TensorShape([None, None]))
has_default.set_shape(tensor.shape[:2]) has_default.set_shape(tensor.shape[:2])
# Use has_default it to find the length of each row: for each # Use has_default to find the length of each row: for each
# non-default item in a row, calculate the length that the row needs to # non-default item in a row, calculate the length that the row needs to
# have to include that item; and then take the max of those values # have to include that item; and then take the max of those values
# (across each row). # (across each row).
has_nondefault = math_ops.logical_not(has_default) has_nondefault = math_ops.logical_not(has_default)
has_nondefault = math_ops.cast(has_nondefault, dtypes.int64) has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype)
length_for_nondefault_value = ( length_for_nondefault_value = (
has_nondefault * array_ops.expand_dims( has_nondefault * array_ops.expand_dims(
math_ops.range(1, ncols + 1), 0)) math_ops.range(1, ncols + 1), 0))
@ -1198,13 +1308,13 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# paddings), then use those to construct splits; and then use masking # paddings), then use those to construct splits; and then use masking
# to get the corresponding values. # to get the corresponding values.
lengths = ragged_util.convert_to_int_tensor(lengths, "lengths", lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
dtypes.int64) row_splits_dtype)
lengths.shape.assert_has_rank(1) lengths.shape.assert_has_rank(1)
lengths = math_ops.minimum(lengths, ncols) lengths = math_ops.minimum(lengths, ncols)
lengths = math_ops.maximum(lengths, 0) lengths = math_ops.maximum(lengths, 0)
limits = math_ops.cumsum(lengths) limits = math_ops.cumsum(lengths)
splits = array_ops.concat( splits = array_ops.concat(
[array_ops.zeros([1], dtypes.int64), limits], axis=0) [array_ops.zeros([1], row_splits_dtype), limits], axis=0)
mask = array_ops.sequence_mask(lengths, maxlen=ncols) mask = array_ops.sequence_mask(lengths, maxlen=ncols)
values = array_ops.boolean_mask(tensor, mask) values = array_ops.boolean_mask(tensor, mask)
return cls.from_row_splits(values, splits) return cls.from_row_splits(values, splits)
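A minimal sketch of `from_tensor` with the new `row_splits_dtype` argument (hypothetical input; assumes eager execution):

```python
import tensorflow as tf

dt = tf.constant([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
rt = tf.RaggedTensor.from_tensor(dt, padding=0, row_splits_dtype=tf.int32)
print(rt)                    # <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]>
print(rt.row_splits.dtype)   # <dtype: 'int32'>
```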
@ -1267,9 +1377,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Get the expected dense shape ([nrows, ncols] + value_shape). # Get the expected dense shape ([nrows, ncols] + value_shape).
rt_row_lengths = [self.row_splits[1:] - self.row_splits[:-1]] rt_row_lengths = [self.row_splits[1:] - self.row_splits[:-1]]
nrows = array_ops.shape(self.row_splits, out_type=dtypes.int64)[0] - 1 nrows = array_ops.shape(self.row_splits,
out_type=self._row_splits.dtype)[0] - 1
ncols = math_ops.maximum(math_ops.reduce_max(rt_row_lengths), 0) ncols = math_ops.maximum(math_ops.reduce_max(rt_row_lengths), 0)
values_shape = array_ops.shape(values, out_type=dtypes.int64) values_shape = array_ops.shape(values, out_type=self._row_splits.dtype)
value_shape = values_shape[1:] value_shape = values_shape[1:]
nvals = values_shape[0] nvals = values_shape[0]
@ -1305,7 +1416,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
return array_ops.gather(values_and_default, indices) return array_ops.gather(values_and_default, indices)
@classmethod @classmethod
def from_sparse(cls, st_input, name=None): def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
"""Converts a 2D `tf.SparseTensor` to a `RaggedTensor`. """Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.
Each row of the `output` `RaggedTensor` will contain the explicit values Each row of the `output` `RaggedTensor` will contain the explicit values
@ -1327,6 +1438,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args: Args:
st_input: The sparse tensor to convert. Must have rank 2. st_input: The sparse tensor to convert. Must have rank 2.
name: A name prefix for the returned tensors (optional). name: A name prefix for the returned tensors (optional).
row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
tensor. One of `tf.int32` or `tf.int64`.
Returns: Returns:
A `RaggedTensor` with the same values as `st_input`. A `RaggedTensor` with the same values as `st_input`.
@ -1336,6 +1449,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
ValueError: If the number of dimensions in `st_input` is not known ValueError: If the number of dimensions in `st_input` is not known
statically, or is not two. statically, or is not two.
""" """
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if not sparse_tensor.is_sparse(st_input): if not sparse_tensor.is_sparse(st_input):
raise TypeError("Expected SparseTensor, got %s" % type(st_input).__name__) raise TypeError("Expected SparseTensor, got %s" % type(st_input).__name__)
with ops.name_scope(name, "RaggedFromSparse", [st_input]): with ops.name_scope(name, "RaggedFromSparse", [st_input]):
@ -1360,8 +1474,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Treat sparse row indices as segment ids to generate a splits tensor # Treat sparse row indices as segment ids to generate a splits tensor
# that we can pair with the sparse tensor values. (Ignore sparse column # that we can pair with the sparse tensor values. (Ignore sparse column
# indices.) # indices.)
segment_ids = st_input.indices[:, 0] segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
num_segments = st_input.dense_shape[0] num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
return cls.from_value_rowids(st_input.values, segment_ids, num_segments) return cls.from_value_rowids(st_input.values, segment_ids, num_segments)
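For illustration, a sketch of `from_sparse` with `row_splits_dtype` (hypothetical input; the sparse tensor must be "ragged-right", i.e. each row's values are left-justified):

```python
import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
                     values=[1, 2, 3],
                     dense_shape=[3, 3])
rt = tf.RaggedTensor.from_sparse(st, row_splits_dtype=tf.int32)
print(rt)                    # <tf.RaggedTensor [[1, 2], [], [3]]>
print(rt.row_splits.dtype)   # <dtype: 'int32'>
```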
def to_sparse(self, name=None): def to_sparse(self, name=None):
@ -1518,6 +1632,50 @@ def is_ragged(value):
(RaggedTensor, ragged_tensor_value.RaggedTensorValue)) (RaggedTensor, ragged_tensor_value.RaggedTensorValue))
def match_row_splits_dtypes(*tensors, **kwargs):
"""Return a copy of `tensors` with row_splits all having the same dtype.
Args:
*tensors: A list of Tensors or RaggedTensors.
**kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
where `dtype` is the data type used by row-splits, and `tensors` is the
converted list of `Tensors` and `RaggedTensors`.
Returns:
The converted list of `Tensors` and `RaggedTensors`.
"""
return_dtype = kwargs.pop("return_dtype", False)
if kwargs:
raise ValueError("Unexpected keyword args %r" % kwargs)
has_int32 = False
has_int64 = False
for tensor in tensors:
if isinstance(tensor, RaggedTensor):
if tensor.row_splits.dtype == dtypes.int32:
has_int32 = True
else:
has_int64 = True
if has_int32 and has_int64:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
"use RaggedTensor.with_row_splits_dtype() to convert "
"them to compatible dtypes.")
dtype = dtypes.int64
tensors = tuple(t.with_row_splits_dtype(dtypes.int64)
if isinstance(t, RaggedTensor) else t for t in tensors)
elif has_int32:
dtype = dtypes.int32
else:
dtype = dtypes.int64
if return_dtype:
return (dtype, tensors)
else:
return tensors
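`match_row_splits_dtypes` is a module-level helper in `ragged_tensor.py` rather than an exported `tf.ragged` symbol; a hypothetical sketch of how it might be called (the internal import path below is an assumption):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor  # internal module

a = tf.ragged.constant([[1], [2, 3]], row_splits_dtype=tf.int32)
b = tf.ragged.constant([[4, 5], [6]], row_splits_dtype=tf.int32)
dtype, (a2, b2) = ragged_tensor.match_row_splits_dtypes(a, b, return_dtype=True)
print(dtype)  # <dtype: 'int32'> -- both inputs already agree, so no casting
```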
#=============================================================================== #===============================================================================
# Convert value -> tensor # Convert value -> tensor
#=============================================================================== #===============================================================================
@ -1606,18 +1764,23 @@ class RaggedTensorType(object):
`RaggedTensor`. `RaggedTensor`.
""" """
def __init__(self, dtype, ragged_rank): def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64):
"""Initializes a RaggedTensorType object. """Initializes a RaggedTensorType object.
Args: Args:
dtype: data type of the `RaggedTensor`'s inner values. dtype: data type of the `RaggedTensor`'s inner values.
ragged_rank: ragged_rank of the declared `RaggedTensor`. ragged_rank: ragged_rank of the declared `RaggedTensor`.
row_splits_dtype: data type for the `RaggedTensor`'s row splits.
One of: `tf.int32` or `tf.int64`.
""" """
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
self._dtype = dtype self._dtype = dtype
self._ragged_rank = ragged_rank self._ragged_rank = ragged_rank
self._row_splits_dtype = row_splits_dtype
dtype = property(lambda self: self._dtype) dtype = property(lambda self: self._dtype)
ragged_rank = property(lambda self: self._ragged_rank) ragged_rank = property(lambda self: self._ragged_rank)
row_splits_dtype = property(lambda self: self._row_splits_dtype)
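A sketch of constructing a `RaggedTensorType` with the new `row_splits_dtype` field; its use as a result-type declaration (e.g. for ragged `map_fn`) is an assumption, as is the internal import path:

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor  # internal module

# Declares "int32 values, one ragged dimension, int32 row splits".
rt_type = ragged_tensor.RaggedTensorType(
    dtype=tf.int32, ragged_rank=1, row_splits_dtype=tf.int32)
print(rt_type.dtype, rt_type.ragged_rank, rt_type.row_splits_dtype)
```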
#=============================================================================== #===============================================================================
View File
@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_conversion_ops from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
@ -82,7 +83,8 @@ class RaggedTensorDynamicShape(object):
`[[[1, 2], [3]], [[4, 5]]]` | 2 | `2, (2, 1), (2, 1, 2)` | `[[[1, 2], [3]], [[4, 5]]]` | 2 | `2, (2, 1), (2, 1, 2)` |
""" """
def __init__(self, partitioned_dim_sizes, inner_dim_sizes): def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
dim_size_dtype=None):
"""Creates a RaggedTensorDynamicShape. """Creates a RaggedTensorDynamicShape.
Args: Args:
@ -96,16 +98,19 @@ class RaggedTensorDynamicShape(object):
number of inner dimensions. `inner_dim_sizes[n]` is the size of all number of inner dimensions. `inner_dim_sizes[n]` is the size of all
slices across the `n`th inner dimension (which is the slices across the `n`th inner dimension (which is the
`(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor). `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
dim_size_dtype: dtype for dimension sizes. If not specified, then it
is chosen based on the dtypes of `partitioned_dim_sizes` and
`inner_dim_sizes`.
""" """
assert isinstance(partitioned_dim_sizes, (list, tuple)) assert isinstance(partitioned_dim_sizes, (list, tuple))
with ops.name_scope(None, 'RaggedTensorDynamicShape', with ops.name_scope(None, 'RaggedTensorDynamicShape',
(partitioned_dim_sizes, inner_dim_sizes)): (partitioned_dim_sizes, inner_dim_sizes)):
partitioned_dim_sizes = tuple( partitioned_dim_sizes = tuple(
ragged_util.convert_to_int_tensor( ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
size, dtype=dtypes.int64, name='partitioned_dimension_size') for (i, size) in enumerate(partitioned_dim_sizes))
for size in partitioned_dim_sizes) inner_dim_sizes = ops.convert_to_tensor(
inner_dim_sizes = ragged_util.convert_to_int_tensor( inner_dim_sizes, name='inner_dim_sizes')
inner_dim_sizes, dtype=dtypes.int64, name='inner_dim_sizes')
# Validate shapes. # Validate shapes.
if partitioned_dim_sizes: if partitioned_dim_sizes:
@ -120,6 +125,22 @@ class RaggedTensorDynamicShape(object):
raise ValueError('innermost partitioned dimension must be ragged') raise ValueError('innermost partitioned dimension must be ragged')
inner_dim_sizes.shape.assert_has_rank(1) inner_dim_sizes.shape.assert_has_rank(1)
# Convert dimension size tensors to a single dtype.
if dim_size_dtype is None:
dim_size_dtypes = set([p.dtype for p in partitioned_dim_sizes
if p.shape.ndims == 1])
if not dim_size_dtypes:
dim_size_dtype = dtypes.int64
elif len(dim_size_dtypes) == 1:
dim_size_dtype = dim_size_dtypes.pop()
else:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError('partitioned_dim_sizes must have matching dtypes')
dim_size_dtype = dtypes.int64
partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
for p in partitioned_dim_sizes)
inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)
self._partitioned_dim_sizes = partitioned_dim_sizes self._partitioned_dim_sizes = partitioned_dim_sizes
self._inner_dim_sizes = inner_dim_sizes self._inner_dim_sizes = inner_dim_sizes
@ -137,7 +158,7 @@ class RaggedTensorDynamicShape(object):
ragged. ragged.
Args: Args:
dim_sizes: List of int64 scalars or vectors. dim_sizes: List of int32 or int64 scalars or vectors.
Returns: Returns:
A RaggedTensorDynamicShape. A RaggedTensorDynamicShape.
@ -145,8 +166,8 @@ class RaggedTensorDynamicShape(object):
with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes', with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
[dim_sizes]): [dim_sizes]):
dim_sizes = tuple( dim_sizes = tuple(
ragged_util.convert_to_int_tensor( ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
size, dtype=dtypes.int64, name='dim_sizes') for size in dim_sizes) name='dim_sizes') for size in dim_sizes)
# Split the dimensions into partitioned & inner dimensions. # Split the dimensions into partitioned & inner dimensions.
inner_split = 0 inner_split = 0
for dim, dim_size in enumerate(dim_sizes): for dim, dim_size in enumerate(dim_sizes):
@ -158,7 +179,7 @@ class RaggedTensorDynamicShape(object):
dim_sizes[inner_split:]) dim_sizes[inner_split:])
@classmethod @classmethod
def from_tensor(cls, rt_input): def from_tensor(cls, rt_input, dim_size_dtype=None):
"""Constructs a ragged shape for a potentially ragged tensor.""" """Constructs a ragged shape for a potentially ragged tensor."""
with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]): with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input) rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
@ -169,7 +190,8 @@ class RaggedTensorDynamicShape(object):
(rt_input.nrows(),) + rt_input.nested_row_lengths()) (rt_input.nrows(),) + rt_input.nested_row_lengths())
return RaggedTensorDynamicShape( return RaggedTensorDynamicShape(
partitioned_dim_sizes, partitioned_dim_sizes,
array_ops.shape(rt_input.flat_values)[1:]) array_ops.shape(rt_input.flat_values)[1:],
dim_size_dtype=dim_size_dtype)
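A sketch of `RaggedTensorDynamicShape.from_tensor` with the new `dim_size_dtype` argument (the internal module path `ragged_tensor_shape` is an assumption):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor_shape  # internal module

rt = tf.ragged.constant([[1, 2], [3]])
shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
    rt, dim_size_dtype=tf.int32)
print(shape.dim_size_dtype)  # <dtype: 'int32'>
```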
def dimension_size(self, axis): def dimension_size(self, axis):
"""Returns the size of slices across the specified dimension.""" """Returns the size of slices across the specified dimension."""
@ -231,6 +253,11 @@ class RaggedTensorDynamicShape(object):
"""The number of inner dimensions, or `None` if not statically known.""" """The number of inner dimensions, or `None` if not statically known."""
return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0]) return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
@property
def dim_size_dtype(self):
"""DType used by this shape for dimension sizes."""
return self._inner_dim_sizes.dtype
def broadcast_to_rank(self, rank): def broadcast_to_rank(self, rank):
"""Adds leading size-1 dimensions to broadcast `self` to the given rank. """Adds leading size-1 dimensions to broadcast `self` to the given rank.
@ -260,7 +287,8 @@ class RaggedTensorDynamicShape(object):
return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes) return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
else: else:
inner_dims = array_ops.concat( inner_dims = array_ops.concat(
[array_ops.ones([dims_to_add], dtypes.int64), self.inner_dim_sizes], [array_ops.ones([dims_to_add], self.dim_size_dtype),
self.inner_dim_sizes],
axis=0) axis=0)
return RaggedTensorDynamicShape([], inner_dims) return RaggedTensorDynamicShape([], inner_dims)
@ -290,7 +318,7 @@ class RaggedTensorDynamicShape(object):
A `RaggedTensorDynamicShape`. A `RaggedTensorDynamicShape`.
""" """
lengths = ragged_util.convert_to_int_tensor( lengths = ragged_util.convert_to_int_tensor(
lengths, name='lengths', dtype=dtypes.int64) lengths, name='lengths', dtype=self.dim_size_dtype)
# Check whether lengths is a scalar (for uniform dimensions) or # Check whether lengths is a scalar (for uniform dimensions) or
# vector (for ragged dimensions). # vector (for ragged dimensions).
if lengths.shape.ndims is None: if lengths.shape.ndims is None:
@ -347,7 +375,7 @@ class RaggedTensorDynamicShape(object):
def num_slices_in_dimension(self, axis): def num_slices_in_dimension(self, axis):
"""Returns the total number of slices across the indicated dimension.""" """Returns the total number of slices across the indicated dimension."""
if axis < 0: if axis < 0:
return constant_op.constant(1, dtype=dtypes.int64) return constant_op.constant(1, dtype=self.dim_size_dtype)
elif self.is_ragged(axis): elif self.is_ragged(axis):
return math_ops.reduce_sum(self._partitioned_dim_sizes[axis]) return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
else: else:
@ -365,7 +393,7 @@ class RaggedTensorDynamicShape(object):
splits = array_ops.stack([0, self.num_slices_in_dimension(axis)]) splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
else: else:
splits = math_ops.range( splits = math_ops.range(
array_ops.size(lengths, out_type=dtypes.int64) + 1) array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
repeats = lengths repeats = lengths
partitioned_sizes.append(lengths) partitioned_sizes.append(lengths)
@ -404,6 +432,15 @@ class RaggedTensorDynamicShape(object):
inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:] inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes) return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)
def with_dim_size_dtype(self, dtype):
if dtype not in (dtypes.int32, dtypes.int64):
raise ValueError('dtype must be int32 or int64')
if self.dim_size_dtype == dtype:
return self
return RaggedTensorDynamicShape(
[math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
math_ops.cast(self._inner_dim_sizes, dtype))
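And a sketch of `with_dim_size_dtype` (same internal-module assumption as above; the dimension sizes are hypothetical):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor_shape  # internal module

# One uniform outer dimension of size 3, then a ragged dimension with
# row lengths (2, 0, 1); dim sizes default to int64.
shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_dim_sizes(
    [3, (2, 0, 1)])
shape32 = shape.with_dim_size_dtype(tf.int32)
print(shape32.dim_size_dtype)  # <dtype: 'int32'>
```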
def broadcast_dynamic_shape(shape_x, shape_y): def broadcast_dynamic_shape(shape_x, shape_y):
"""Returns the shape formed by broadcasting two shapes to be compatible. """Returns the shape formed by broadcasting two shapes to be compatible.
@ -479,6 +516,17 @@ def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions): def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
"""Broadcasts rt_input to the ragged shape `dst_shape`.""" """Broadcasts rt_input to the ragged shape `dst_shape`."""
# Check that rt_input and dst_shape have the same row_splits dtype.
if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
if not ragged_config.auto_cast_partition_dtype():
raise ValueError('rt_input and dst_shape have different row_split '
'dtypes; use RaggedTensor.with_row_splits_dtype() or '
'RaggedTensorDynamicShape.with_dim_size_dtype() to '
'convert to a compatible dtype.')
rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)
# dst_shape's rank and ragged_rank must be greater than or equal to rt_input's # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
if rt_input.shape.ndims is None or dst_shape.rank is None: if rt_input.shape.ndims is None or dst_shape.rank is None:
raise ValueError('Unable to broadcast: unknown rank') raise ValueError('Unable to broadcast: unknown rank')
@ -500,7 +548,8 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
if ragged_tensor.is_ragged(rt_input): if ragged_tensor.is_ragged(rt_input):
nrows = rt_input.nrows() nrows = rt_input.nrows()
else: else:
nrows = array_ops.shape(rt_input, out_type=dtypes.int64)[0] nrows = array_ops.shape(rt_input,
out_type=dst_shape.dim_size_dtype)[0]
rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows]) rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows])
# Add ragged dimensions to match dst_shape. # Add ragged dimensions to match dst_shape.
@ -509,11 +558,13 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions) rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
if inner_rank_diff > 0: if inner_rank_diff > 0:
rt_input = rt_input.with_flat_values( rt_input = rt_input.with_flat_values(
ragged_conversion_ops.from_tensor( ragged_tensor.RaggedTensor.from_tensor(
rt_input.flat_values, ragged_rank=inner_rank_diff)) rt_input.flat_values, ragged_rank=inner_rank_diff,
row_splits_dtype=dst_shape.dim_size_dtype))
else: else:
rt_input = ragged_conversion_ops.from_tensor( rt_input = ragged_tensor.RaggedTensor.from_tensor(
rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1) rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
row_splits_dtype=dst_shape.dim_size_dtype)
# Do broadcasting for any dimensions that will remain uniform. We can do # Do broadcasting for any dimensions that will remain uniform. We can do
# these all at once, since they're independent of one another. # these all at once, since they're independent of one another.
@ -541,21 +592,24 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
for axis in range(dst_shape.num_partitioned_dimensions): for axis in range(dst_shape.num_partitioned_dimensions):
if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis): if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
dst_size = dst_shape.dimension_size(axis) dst_size = dst_shape.dimension_size(axis)
rt_input = _ragged_tile_axis(rt_input, axis, dst_size) rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
dst_shape.dim_size_dtype)
return rt_input return rt_input
def _ragged_tile_axis(rt_input, axis, repeats): def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
"""Tile a dimension of a RaggedTensor to match a ragged shape.""" """Tile a dimension of a RaggedTensor to match a ragged shape."""
assert axis > 0 # Outermost dimension may not be ragged. assert axis > 0 # Outermost dimension may not be ragged.
if not ragged_tensor.is_ragged(rt_input): if not ragged_tensor.is_ragged(rt_input):
rt_input = ragged_conversion_ops.from_tensor(rt_input, ragged_rank=1) rt_input = ragged_conversion_ops.from_tensor(
rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)
if axis > 1: if axis > 1:
return rt_input.with_values( return rt_input.with_values(
_ragged_tile_axis(rt_input.values, axis - 1, repeats)) _ragged_tile_axis(rt_input.values, axis - 1, repeats,
row_splits_dtype))
else: else:
src_row_splits = rt_input.nested_row_splits src_row_splits = rt_input.nested_row_splits
src_row_lengths = rt_input.nested_row_lengths() src_row_lengths = rt_input.nested_row_lengths()
View File
@ -32,8 +32,8 @@ from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class RaggedTensorBoundingShapeOp(ragged_test_util.RaggedTensorTestCase, class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
parameterized.TestCase): parameterized.TestCase):
def assertShapeEq(self, x, y): def assertShapeEq(self, x, y):
assert isinstance(x, RaggedTensorDynamicShape) assert isinstance(x, RaggedTensorDynamicShape)
View File
@ -181,7 +181,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt_value = ragged_tensor_value.RaggedTensorValue(values, splits) rt_value = ragged_tensor_value.RaggedTensorValue(values, splits)
self.assertEqual(rt_value.row_splits.dtype, np.int64) self.assertEqual(rt_value.row_splits.dtype, np.int64)
self.assertEqual(rt_value.shape, (5, None)) self.assertEqual(rt_value.shape, (5, None))
self.assertEqual(len(rt_value.nested_row_splits), 1) self.assertLen(rt_value.nested_row_splits, 1)
self.assertAllEqual(splits, rt_value.row_splits) self.assertAllEqual(splits, rt_value.row_splits)
self.assertAllEqual(values, rt_value.values) self.assertAllEqual(values, rt_value.values)
self.assertAllEqual(splits, rt_value.nested_row_splits[0]) self.assertAllEqual(splits, rt_value.nested_row_splits[0])
@ -193,7 +193,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
row_splits=splits2) row_splits=splits2)
self.assertEqual(rt_value.row_splits.dtype, np.int64) self.assertEqual(rt_value.row_splits.dtype, np.int64)
self.assertEqual(rt_value.shape, (2, None, None)) self.assertEqual(rt_value.shape, (2, None, None))
self.assertEqual(len(rt_value.nested_row_splits), 2) self.assertLen(rt_value.nested_row_splits, 2)
self.assertAllEqual(splits2, rt_value.row_splits) self.assertAllEqual(splits2, rt_value.row_splits)
self.assertAllEqual(splits, rt_value.values.row_splits) self.assertAllEqual(splits, rt_value.values.row_splits)
self.assertAllEqual(splits2, rt_value.nested_row_splits[0]) self.assertAllEqual(splits2, rt_value.nested_row_splits[0])
@ -1078,15 +1078,17 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g'] values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
row_splits = [0, 2, 5, 6, 6, 7] row_splits = [0, 2, 5, 6, 6, 7]
rt = RaggedTensor.from_row_splits(values, row_splits) rt = RaggedTensor.from_row_splits(values, row_splits)
splits_type = 'int64'
if context.executing_eagerly(): if context.executing_eagerly():
expected_str = '<tf.RaggedTensor {}>'.format([[b'a', b'b'], expected_str = '<tf.RaggedTensor {}>'.format([[b'a', b'b'],
[b'c', b'd', b'e'], [b'f'], [b'c', b'd', b'e'], [b'f'],
[], [b'g']]) [], [b'g']])
expected_repr = ( expected_repr = (
'tf.RaggedTensor(values=tf.Tensor([{}], shape=(7,), dtype=string), ' 'tf.RaggedTensor(values=tf.Tensor([{}], shape=(7,), dtype=string), '
'row_splits=tf.Tensor([{}], shape=(6,), dtype=int64))'.format( 'row_splits=tf.Tensor([{}], shape=(6,), dtype={}))'.format(
' '.join(repr(x) for x in values), ' '.join(repr(x) for x in values),
' '.join(repr(x) for x in row_splits))) ' '.join(repr(x) for x in row_splits),
splits_type))
self.assertEqual(str(rt), expected_str) self.assertEqual(str(rt), expected_str)
self.assertEqual(repr(rt), expected_repr) self.assertEqual(repr(rt), expected_repr)
else: else:
@ -1094,7 +1096,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", ' 'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
'shape=(7,), dtype=string), row_splits=' 'shape=(7,), dtype=string), row_splits='
'Tensor("RaggedFromRowSplits/row_splits:0", ' 'Tensor("RaggedFromRowSplits/row_splits:0", '
'shape=(6,), dtype=int64))') 'shape=(6,), dtype={}))').format(splits_type)
self.assertEqual(repr(rt), expected_repr) self.assertEqual(repr(rt), expected_repr)
self.assertEqual(str(rt), expected_repr) self.assertEqual(str(rt), expected_repr)
@ -1145,7 +1147,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]]) rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]])
with self.test_session() as session: with self.test_session() as session:
result = session.run({'rt1': rt1, 'rt2': rt2}) result = session.run({'rt1': rt1, 'rt2': rt2})
self.assertCountEqual(sorted(result.keys()), ['rt1', 'rt2']) self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]]) self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]]) self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
@ -1166,15 +1168,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]]) rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]])
with self.test_session() as session: with self.test_session() as session:
result = session.run({ fetches = {'rt1': rt1, 'rt2': rt2}
'rt1': rt1, feeds = {rt1: rt1_feed_val, rt2: rt2_feed_val}
'rt2': rt2 result = session.run(fetches, feed_dict=feeds)
}, self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
feed_dict={
rt1: rt1_feed_val,
rt2: rt2_feed_val
})
self.assertCountEqual(sorted(result.keys()), ['rt1', 'rt2'])
self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]]) self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]]) self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
View File
@ -38,13 +38,17 @@ class RaggedTensorValue(object):
Args: Args:
values: A numpy array of any type and shape; or a RaggedTensorValue. values: A numpy array of any type and shape; or a RaggedTensorValue.
row_splits: A 1-D int64 numpy array. row_splits: A 1-D int32 or int64 numpy array.
""" """
if not (isinstance(row_splits, (np.ndarray, np.generic)) and if not (isinstance(row_splits, (np.ndarray, np.generic)) and
row_splits.dtype == np.int64 and row_splits.ndim == 1): row_splits.dtype in (np.int64, np.int32) and row_splits.ndim == 1):
raise TypeError("row_splits must be a 1D int64 numpy array") raise TypeError("row_splits must be a 1D int32 or int64 numpy array")
if not isinstance(values, (np.ndarray, np.generic, RaggedTensorValue)): if not isinstance(values, (np.ndarray, np.generic, RaggedTensorValue)):
raise TypeError("values must be a numpy array or a RaggedTensorValue") raise TypeError("values must be a numpy array or a RaggedTensorValue")
if (isinstance(values, RaggedTensorValue) and
row_splits.dtype != values.row_splits.dtype):
raise ValueError("row_splits and values.row_splits must have "
"the same dtype")
self._values = values self._values = values
self._row_splits = row_splits self._row_splits = row_splits
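A sketch of building a `RaggedTensorValue` with int32 splits, which this change newly permits (the internal import path is an assumption):

```python
import numpy as np
from tensorflow.python.ops.ragged import ragged_tensor_value  # internal module

rtv = ragged_tensor_value.RaggedTensorValue(
    values=np.array([1, 2, 3]),
    row_splits=np.array([0, 2, 3], dtype=np.int32))
print(rtv.shape)  # (2, None): two rows, ragged second dimension
```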
View File
@ -191,7 +191,6 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
g1, g2 = gradients_impl.gradients(st.values, g1, g2 = gradients_impl.gradients(st.values,
[rt1.flat_values, rt2.flat_values]) [rt1.flat_values, rt2.flat_values])
print(g1, g2)
self.assertRaggedEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]) self.assertRaggedEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
self.assertRaggedEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]) self.assertRaggedEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
View File
@ -268,7 +268,7 @@ def repeat_ranges(params, splits, repeats):
else: else:
# Optimization: we can just call repeat once, and then slice the result. # Optimization: we can just call repeat once, and then slice the result.
repeated_splits = repeat(splits, repeats, axis=0) repeated_splits = repeat(splits, repeats, axis=0)
n_splits = array_ops.shape(repeated_splits, out_type=dtypes.int64)[0] n_splits = array_ops.shape(repeated_splits, out_type=repeats.dtype)[0]
repeated_starts = repeated_splits[:n_splits - repeats] repeated_starts = repeated_splits[:n_splits - repeats]
repeated_limits = repeated_splits[repeats:] repeated_limits = repeated_splits[repeats:]
View File
@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -108,6 +107,7 @@ def where(condition, x=None, y=None, name=None):
else: else:
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x') x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y') y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
return _elementwise_where(condition, x, y) return _elementwise_where(condition, x, y)
@ -145,6 +145,7 @@ def _coordinate_where(condition):
selected_coords = _coordinate_where(condition.values) selected_coords = _coordinate_where(condition.values)
# Convert the first index in each coordinate to a row index and column index. # Convert the first index in each coordinate to a row index and column index.
condition = condition.with_row_splits_dtype(selected_coords.dtype)
first_index = selected_coords[:, 0] first_index = selected_coords[:, 0]
selected_rows = array_ops.gather(condition.value_rowids(), first_index) selected_rows = array_ops.gather(condition.value_rowids(), first_index)
selected_row_starts = array_ops.gather(condition.row_splits, selected_rows) selected_row_starts = array_ops.gather(condition.row_splits, selected_rows)
@ -158,9 +159,8 @@ def _coordinate_where(condition):
axis=1) axis=1)
def _nrows(rt_input, out_type=dtypes.int64, name=None): def _nrows(rt_input):
if isinstance(rt_input, ragged_tensor.RaggedTensor): if isinstance(rt_input, ragged_tensor.RaggedTensor):
return rt_input.nrows(out_type=out_type, name=name) return rt_input.nrows()
else: else:
with ops.name_scope(name, 'RaggedNRows', [rt_input]): return array_ops.shape(rt_input)[0]
return array_ops.shape(rt_input, out_type=out_type)[0]
View File
@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
# For background on "segments" and "segment ids", see: # For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation # https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids") @tf_export("ragged.row_splits_to_segment_ids")
def row_splits_to_segment_ids(splits, name=None): def row_splits_to_segment_ids(splits, name=None, out_type=None):
"""Generates the segmentation corresponding to a RaggedTensor `row_splits`. """Generates the segmentation corresponding to a RaggedTensor `row_splits`.
Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if
@ -43,22 +43,32 @@ def row_splits_to_segment_ids(splits, name=None):
``` ```
Args: Args:
splits: A sorted 1-D int64 Tensor. `splits[0]` must be zero. splits: A sorted 1-D integer Tensor. `splits[0]` must be zero.
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
out_type: The dtype for the return value. Defaults to `splits.dtype`,
or `tf.int64` if `splits` does not have a dtype.
Returns: Returns:
A sorted 1-D int64 Tensor, with `shape=[splits[-1]]` A sorted 1-D integer Tensor, with `shape=[splits[-1]]`
Raises: Raises:
ValueError: If `splits` is invalid. ValueError: If `splits` is invalid.
""" """
with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name: with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
splits = ops.convert_to_tensor(splits, dtype=dtypes.int64, name="splits") splits = ops.convert_to_tensor(
splits, name="splits",
preferred_dtype=dtypes.int64)
if splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("splits must have dtype int32 or int64")
splits.shape.assert_has_rank(1) splits.shape.assert_has_rank(1)
if tensor_shape.dimension_value(splits.shape[0]) == 0: if tensor_shape.dimension_value(splits.shape[0]) == 0:
raise ValueError("Invalid row_splits: []") raise ValueError("Invalid row_splits: []")
if out_type is None:
out_type = splits.dtype
else:
out_type = dtypes.as_dtype(out_type)
row_lengths = splits[1:] - splits[:-1] row_lengths = splits[1:] - splits[:-1]
nrows = array_ops.shape(splits, out_type=dtypes.int64)[-1] - 1 nrows = array_ops.shape(splits, out_type=out_type)[-1] - 1
indices = math_ops.range(nrows) indices = math_ops.range(nrows)
return ragged_util.repeat(indices, repeats=row_lengths, axis=0) return ragged_util.repeat(indices, repeats=row_lengths, axis=0)
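A minimal sketch of the updated `row_splits_to_segment_ids` (hypothetical values; assumes eager execution):

```python
import tensorflow as tf

splits = tf.constant([0, 3, 3, 5, 6], dtype=tf.int32)
ids = tf.ragged.row_splits_to_segment_ids(splits)
print(ids)        # [0 0 0 2 2 3] -- row 1 is empty, so segment id 1 is absent
print(ids.dtype)  # int32: out_type now defaults to splits.dtype
```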
@ -66,7 +76,8 @@ def row_splits_to_segment_ids(splits, name=None):
# For background on "segments" and "segment ids", see: # For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation # https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits") @tf_export("ragged.segment_ids_to_row_splits")
def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None): def segment_ids_to_row_splits(segment_ids, num_segments=None,
out_type=None, name=None):
"""Generates the RaggedTensor `row_splits` corresponding to a segmentation. """Generates the RaggedTensor `row_splits` corresponding to a segmentation.
Returns an integer vector `splits`, where `splits[0] = 0` and Returns an integer vector `splits`, where `splits[0] = 0` and
@ -81,24 +92,39 @@ def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None):
segment_ids: A 1-D integer Tensor. segment_ids: A 1-D integer Tensor.
num_segments: A scalar integer indicating the number of segments. Defaults num_segments: A scalar integer indicating the number of segments. Defaults
to `max(segment_ids) + 1` (or zero if `segment_ids` is empty). to `max(segment_ids) + 1` (or zero if `segment_ids` is empty).
out_type: The dtype for the return value. Defaults to `segment_ids.dtype`,
or `tf.int64` if `segment_ids` does not have a dtype.
name: A name prefix for the returned tensor (optional). name: A name prefix for the returned tensor (optional).
Returns: Returns:
A sorted 1-D int64 Tensor, with `shape=[num_segments + 1]`. A sorted 1-D integer Tensor, with `shape=[num_segments + 1]`.
""" """
if out_type is None:
if isinstance(segment_ids, ops.Tensor):
out_type = segment_ids.dtype
elif isinstance(num_segments, ops.Tensor):
out_type = num_segments.dtype
else:
out_type = dtypes.int64
else:
out_type = dtypes.as_dtype(out_type)
with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name: with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name:
segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids") # Note: we cast int64 tensors to int32, since bincount currently only
# supports int32 inputs.
segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids",
dtype=dtypes.int32)
segment_ids.shape.assert_has_rank(1) segment_ids.shape.assert_has_rank(1)
if num_segments is not None: if num_segments is not None:
num_segments = ragged_util.convert_to_int_tensor(num_segments, num_segments = ragged_util.convert_to_int_tensor(num_segments,
"num_segments") "num_segments",
dtype=dtypes.int32)
num_segments.shape.assert_has_rank(0) num_segments.shape.assert_has_rank(0)
row_lengths = math_ops.bincount( row_lengths = math_ops.bincount(
segment_ids, segment_ids,
minlength=num_segments, minlength=num_segments,
maxlength=num_segments, maxlength=num_segments,
dtype=dtypes.int64) dtype=out_type)
splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0) splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
# Update shape information, if possible. # Update shape information, if possible.
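And the inverse direction with an explicit `out_type` (hypothetical values):

```python
import tensorflow as tf

ids = tf.constant([0, 0, 0, 2, 2, 3], dtype=tf.int32)
splits = tf.ragged.segment_ids_to_row_splits(ids, out_type=tf.int64)
print(splits)  # [0 3 3 5 6], dtype=int64 (bincount runs on int32 internally)
```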
View File
@ -37,7 +37,7 @@ tf_class {
} }
member_method { member_method {
name: "bounding_shape" name: "bounding_shape"
argspec: "args=[\'self\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'self\', \'axis\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "consumers" name: "consumers"
@ -73,11 +73,11 @@ tf_class {
} }
member_method { member_method {
name: "from_sparse" name: "from_sparse"
argspec: "args=[\'cls\', \'st_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'cls\', \'st_input\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "from_tensor" name: "from_tensor"
argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\'], " argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "from_value_rowids" name: "from_value_rowids"
@ -89,7 +89,7 @@ tf_class {
} }
member_method { member_method {
name: "nrows" name: "nrows"
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], " argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "row_lengths" name: "row_lengths"
@ -123,6 +123,10 @@ tf_class {
name: "with_flat_values" name: "with_flat_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "with_row_splits_dtype"
argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "with_values" name: "with_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
View File
@ -6,11 +6,11 @@ tf_module {
} }
member_method { member_method {
name: "constant" name: "constant"
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "constant_value" name: "constant_value"
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'int64\'], "
} }
member_method { member_method {
name: "map_flat_values" name: "map_flat_values"
@ -22,14 +22,14 @@ tf_module {
} }
member_method { member_method {
name: "range" name: "range"
argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], " argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "row_splits_to_segment_ids" name: "row_splits_to_segment_ids"
argspec: "args=[\'splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'splits\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "segment_ids_to_row_splits" name: "segment_ids_to_row_splits"
argspec: "args=[\'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'segment_ids\', \'num_segments\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
} }
} }
View File
@ -2662,7 +2662,7 @@ tf_module {
} }
member_method { member_method {
name: "RaggedRange" name: "RaggedRange"
argspec: "args=[\'starts\', \'limits\', \'deltas\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'starts\', \'limits\', \'deltas\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "RaggedTensorFromVariant" name: "RaggedTensorFromVariant"
@ -4302,11 +4302,11 @@ tf_module {
} }
member_method { member_method {
name: "UnicodeDecode" name: "UnicodeDecode"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], " argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "UnicodeDecodeWithOffsets" name: "UnicodeDecodeWithOffsets"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], " argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "UnicodeEncode" name: "UnicodeEncode"
View File
@ -37,7 +37,7 @@ tf_class {
} }
member_method { member_method {
name: "bounding_shape" name: "bounding_shape"
argspec: "args=[\'self\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'self\', \'axis\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "consumers" name: "consumers"
@ -73,11 +73,11 @@ tf_class {
} }
member_method { member_method {
name: "from_sparse" name: "from_sparse"
argspec: "args=[\'cls\', \'st_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'cls\', \'st_input\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "from_tensor" name: "from_tensor"
argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\'], " argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "from_value_rowids" name: "from_value_rowids"
@ -89,7 +89,7 @@ tf_class {
} }
member_method { member_method {
name: "nrows" name: "nrows"
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], " argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "row_lengths" name: "row_lengths"
@ -123,6 +123,10 @@ tf_class {
name: "with_flat_values" name: "with_flat_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "with_row_splits_dtype"
argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "with_values" name: "with_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
View File
@ -2,7 +2,7 @@ path: "tensorflow.ragged"
tf_module { tf_module {
member_method { member_method {
name: "constant" name: "constant"
argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "map_flat_values" name: "map_flat_values"
@ -10,14 +10,14 @@ tf_module {
} }
member_method { member_method {
name: "range" name: "range"
argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], " argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
} }
member_method { member_method {
name: "row_splits_to_segment_ids" name: "row_splits_to_segment_ids"
argspec: "args=[\'splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'splits\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "segment_ids_to_row_splits" name: "segment_ids_to_row_splits"
argspec: "args=[\'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'segment_ids\', \'num_segments\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
} }
} }
View File
@ -2662,7 +2662,7 @@ tf_module {
} }
member_method { member_method {
name: "RaggedRange" name: "RaggedRange"
argspec: "args=[\'starts\', \'limits\', \'deltas\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'starts\', \'limits\', \'deltas\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "RaggedTensorFromVariant" name: "RaggedTensorFromVariant"
@ -4302,11 +4302,11 @@ tf_module {
} }
member_method { member_method {
name: "UnicodeDecode" name: "UnicodeDecode"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], " argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "UnicodeDecodeWithOffsets" name: "UnicodeDecodeWithOffsets"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], " argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "UnicodeEncode" name: "UnicodeEncode"