Update RaggedTensors to support int32 row_splits.

PiperOrigin-RevId: 245157497
parent 02bd711c79
commit c45be92834
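A minimal sketch (not part of the diff) of the user-facing effect, assuming a TensorFlow build that includes this change; `tf.ragged.constant` and its new `row_splits_dtype` argument are the public API this diff extends:

    import tensorflow as tf

    # row_splits_dtype is the knob added by this commit; it defaults to
    # tf.int64, so existing callers are unaffected.
    rt64 = tf.ragged.constant([[1, 2], [3]])
    rt32 = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)

    print(rt64.row_splits.dtype)  # tf.int64 (unchanged default)
    print(rt32.row_splits.dtype)  # tf.int32 (newly supported)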
@@ -30,10 +30,11 @@ namespace {
 // For each slice in `(start, limit)` in `value_slices`, append
 // `params_dense_values_in[start:limit]` to `values_out`.  `value_size` indicates
 // the number of scalars contained in each value params_dense_values_in[i].
-template <typename VALUE_TYPE>
-void WriteValueSlices(const Tensor& params_dense_values_in,
-                      const std::vector<std::pair<int64, int64>>& value_slices,
-                      int64 value_size, Tensor* values_out) {
+template <typename VALUE_TYPE, typename SPLITS_TYPE>
+void WriteValueSlices(
+    const Tensor& params_dense_values_in,
+    const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+    SPLITS_TYPE value_size, Tensor* values_out) {
   const auto& params_dense_values =
       params_dense_values_in.flat_outer_dims<VALUE_TYPE, 2>();
   auto values = values_out->flat_outer_dims<VALUE_TYPE, 2>();
@@ -50,7 +51,7 @@ void WriteValueSlices(const Tensor& params_dense_values_in,
 
 }  // namespace
 
-template <typename INDEX_TYPE>
+template <typename INDEX_TYPE, typename SPLITS_TYPE>
 class RaggedGatherOpBase : public OpKernel {
  public:
   using OpKernel::OpKernel;
@@ -66,18 +67,18 @@ class RaggedGatherOpBase : public OpKernel {
         context->input(params_nested_splits_in.size() + 1);
 
     DCHECK_GT(params_nested_splits_in.size(), 0);  // Enforced by REGISTER_OP.
-    int64 num_params = params_nested_splits_in[0].dim_size(0) - 1;
+    SPLITS_TYPE num_params = params_nested_splits_in[0].dim_size(0) - 1;
     OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params));
 
     OP_REQUIRES(context, params_dense_values_in.dims() > 0,
                 errors::InvalidArgument("params.rank must be nonzero"));
-    int64 num_params_dense_values = params_dense_values_in.dim_size(0);
+    SPLITS_TYPE num_params_dense_values = params_dense_values_in.dim_size(0);
 
     // Calculate the `splits`, and store the value slices that we need to
     // copy in `value_slices`.
-    std::vector<std::pair<int64, int64>> value_slices;
-    int64 num_values = 0;
-    std::vector<std::vector<int64>> out_splits;
+    std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>> value_slices;
+    SPLITS_TYPE num_values = 0;
+    std::vector<std::vector<SPLITS_TYPE>> out_splits;
     OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in,
                                        num_params_dense_values, &out_splits,
                                        &value_slices, &num_values));
@@ -90,12 +91,14 @@ class RaggedGatherOpBase : public OpKernel {
   }
 
  private:
+  using ConstFlatType = typename TTypes<SPLITS_TYPE>::ConstFlat;
+
   // Check if any indices are out-of-bounds.
   ::tensorflow::Status ValidateIndices(const Tensor& indices_in,
-                                       int64 num_params) {
+                                       SPLITS_TYPE num_params) {
     const auto& indices = indices_in.flat<INDEX_TYPE>();
-    for (int64 i = 0; i < indices.size(); ++i) {
-      int64 index = indices(i);
+    for (SPLITS_TYPE i = 0; i < indices.size(); ++i) {
+      SPLITS_TYPE index = indices(i);
       if (index < 0 || index >= num_params) {
         return errors::InvalidArgument(
             "indices", SliceDebugString(indices_in.shape(), i), " = ", index,
@@ -111,9 +114,10 @@ class RaggedGatherOpBase : public OpKernel {
   // we need for allocating the output values tensor) is stored in `num_values`.
   ::tensorflow::Status MakeSplits(
       const Tensor& indices_in, const OpInputList& params_nested_splits_in,
-      int64 num_params_dense_values,
-      std::vector<std::vector<int64>>* out_splits,
-      std::vector<std::pair<int64, int64>>* value_slices, int64* num_values) {
+      SPLITS_TYPE num_params_dense_values,
+      std::vector<std::vector<SPLITS_TYPE>>* out_splits,
+      std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>* value_slices,
+      SPLITS_TYPE* num_values) {
     *num_values = 0;
     value_slices->clear();
 
@@ -122,10 +126,10 @@ class RaggedGatherOpBase : public OpKernel {
 
     // Get Eigen tensors.
     const auto& indices = indices_in.flat<INDEX_TYPE>();
-    std::vector<TTypes<int64>::ConstFlat> params_nested_splits;
+    std::vector<ConstFlatType> params_nested_splits;
     params_nested_splits.reserve(params_nested_splits_in.size());
     for (const auto& splits_in : params_nested_splits_in) {
-      params_nested_splits.push_back(splits_in.flat<int64>());
+      params_nested_splits.push_back(splits_in.flat<SPLITS_TYPE>());
     }
 
     TF_RETURN_IF_ERROR(
@@ -165,7 +169,7 @@ class RaggedGatherOpBase : public OpKernel {
       const auto& splits = params_nested_splits[dim];
       int out_dim = dim + indices_in.dims() - 1;
       if (out_dim >= 0) {
-        int64 delta = out_splits->at(out_dim).back() - splits(start);
+        SPLITS_TYPE delta = out_splits->at(out_dim).back() - splits(start);
         for (int j = start; j < limit; ++j) {
           out_splits->at(out_dim).push_back(splits(j + 1) + delta);
         }
@@ -182,14 +186,14 @@ class RaggedGatherOpBase : public OpKernel {
   }
 
   ::tensorflow::Status ValidateSplits(
-      const std::vector<TTypes<int64>::ConstFlat>& params_nested_splits,
-      int64 num_params_dense_values) {
+      const std::vector<ConstFlatType>& params_nested_splits,
+      SPLITS_TYPE num_params_dense_values) {
     // Validate
     for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
       const auto& splits = params_nested_splits[dim];
-      int64 last_split = (dim == params_nested_splits.size() - 1)
-                             ? num_params_dense_values
-                             : params_nested_splits[dim + 1].size();
+      SPLITS_TYPE last_split = (dim == params_nested_splits.size() - 1)
+                                   ? num_params_dense_values
+                                   : params_nested_splits[dim + 1].size();
       if (splits.size() == 0) {
         return errors::InvalidArgument("Ragged splits may not be empty");
       }
@@ -210,17 +214,17 @@ class RaggedGatherOpBase : public OpKernel {
   }
 
   ::tensorflow::Status WriteSplits(
-      const std::vector<std::vector<int64>>& out_splits,
+      const std::vector<std::vector<SPLITS_TYPE>>& out_splits,
       OpKernelContext* context) {
     OpOutputList splits_out;
     TF_RETURN_IF_ERROR(
        context->output_list("output_nested_splits", &splits_out));
     for (int i = 0; i < out_splits.size(); ++i) {
       Tensor* splits;
-      int64 num_splits = out_splits[i].size();
+      SPLITS_TYPE num_splits = out_splits[i].size();
       TF_RETURN_IF_ERROR(
           splits_out.allocate(i, TensorShape({num_splits}), &splits));
-      auto splits_flat = splits->flat<int64>();
+      auto splits_flat = splits->flat<SPLITS_TYPE>();
       std::copy_n(out_splits[i].data(), out_splits[i].size(),
                   splits_flat.data());
     }
@@ -229,15 +233,16 @@ class RaggedGatherOpBase : public OpKernel {
 
   ::tensorflow::Status WriteValues(
       const Tensor& params_dense_values_in,
-      const std::vector<std::pair<int64, int64>>& value_slices,
-      int values_index, int64 num_values, OpKernelContext* context) const {
+      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+      int values_index, SPLITS_TYPE num_values,
+      OpKernelContext* context) const {
     Tensor* values_out = nullptr;
     TensorShape values_shape = params_dense_values_in.shape();
     values_shape.set_dim(0, num_values);
     TF_RETURN_IF_ERROR(
         context->allocate_output(values_index, values_shape, &values_out));
-    const int64 num_elements = params_dense_values_in.NumElements();
-    const int64 value_size =
+    const SPLITS_TYPE num_elements = params_dense_values_in.NumElements();
+    const SPLITS_TYPE value_size =
         num_elements == 0 ? 0
                           : (num_elements / params_dense_values_in.dim_size(0));
     CallWriteValueSlices(params_dense_values_in, value_slices, value_size,
@@ -253,34 +258,39 @@ class RaggedGatherOpBase : public OpKernel {
   // which cuts the binary size of this op from ~300k to <90k.
   virtual void CallWriteValueSlices(
       const Tensor& params_dense_values_in,
-      const std::vector<std::pair<int64, int64>>& value_slices,
-      int64 value_size, Tensor* values_out) const = 0;
+      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+      SPLITS_TYPE value_size, Tensor* values_out) const = 0;
 };
 
-template <typename INDEX_TYPE, typename VALUE_TYPE>
-class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE> {
+template <typename INDEX_TYPE, typename VALUE_TYPE, typename SPLITS_TYPE>
+class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE> {
  public:
-  using RaggedGatherOpBase<INDEX_TYPE>::RaggedGatherOpBase;
+  using RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE>::RaggedGatherOpBase;
 
  private:
  void CallWriteValueSlices(
      const Tensor& params_dense_values_in,
-      const std::vector<std::pair<int64, int64>>& value_slices,
-      int64 value_size, Tensor* values_out) const override {
+      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
+      SPLITS_TYPE value_size, Tensor* values_out) const override {
    WriteValueSlices<VALUE_TYPE>(params_dense_values_in, value_slices,
                                 value_size, values_out);
  }
 };
 
-#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type)   \
-  REGISTER_KERNEL_BUILDER(Name("RaggedGather")                        \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<index_type>("Tindices") \
-                              .TypeConstraint<value_type>("Tvalues"), \
-                          RaggedGatherOp<index_type, value_type>);
-#define REGISTER_CPU_KERNEL(value_type)                  \
-  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type) \
-  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type)
+#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type, \
+                                            splits_type)            \
+  REGISTER_KERNEL_BUILDER(                                          \
+      Name("RaggedGather")                                          \
+          .Device(DEVICE_CPU)                                       \
+          .TypeConstraint<index_type>("Tindices")                   \
+          .TypeConstraint<value_type>("Tvalues")                    \
+          .TypeConstraint<splits_type>("Tsplits"),                  \
+      RaggedGatherOp<index_type, value_type, splits_type>);
+#define REGISTER_CPU_KERNEL(value_type)                         \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int32) \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int32) \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int64) \
+  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int64)
 TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
 TF_CALL_string(REGISTER_CPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
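The gather kernel above is now templated on SPLITS_TYPE and registered once per (Tindices, Tvalues, Tsplits) combination. A hedged sketch of what that enables at the Python level (assuming ragged dispatch of `tf.gather`, which this commit does not itself add):

    import tensorflow as tf

    params = tf.ragged.constant([['a', 'b'], [], ['c']],
                                row_splits_dtype=tf.int32)
    # Gathering rows works regardless of which splits dtype params uses;
    # the Tsplits attr routes to the matching RaggedGather kernel.
    print(tf.gather(params, [2, 0]))  # [[b'c'], [b'a', b'b']]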
@@ -26,7 +26,7 @@ namespace tensorflow {
 
 using errors::InvalidArgument;
 
-template <typename T>
+template <typename T, typename SPLITS_TYPE>
 class RaggedRangeOp : public OpKernel {
  public:
   using OpKernel::OpKernel;
@@ -60,7 +60,7 @@ class RaggedRangeOp : public OpKernel {
                 InvalidArgument("starts, limits, and deltas must have the "
                                 "same shape"));
     }
-    int64 nrows = in_sizes.empty() ? 1 : in_sizes[0];
+    SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes[0];
 
     const auto& starts = starts_in.flat<T>();
     const auto& limits = limits_in.flat<T>();
@@ -71,7 +71,7 @@ class RaggedRangeOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({nrows + 1}),
                                             &rt_nested_splits_out));
-    auto rt_nested_splits = rt_nested_splits_out->flat<int64>();
+    auto rt_nested_splits = rt_nested_splits_out->flat<SPLITS_TYPE>();
     rt_nested_splits(0) = 0;
     for (int row = 0; row < nrows; ++row) {
       T start = broadcast_starts ? starts(0) : starts(row);
@@ -81,7 +81,7 @@ class RaggedRangeOp : public OpKernel {
       rt_nested_splits(row + 1) =
           rt_nested_splits(row) + RangeSize(start, limit, delta);
     }
-    int64 nvals = rt_nested_splits(nrows);
+    SPLITS_TYPE nvals = rt_nested_splits(nrows);
 
     // Construct the rt_dense_values tensor.
     Tensor* rt_dense_values_out = nullptr;
@@ -90,10 +90,10 @@ class RaggedRangeOp : public OpKernel {
     auto rt_dense_values = rt_dense_values_out->flat<T>();
     int value_index = 0;
     for (int row = 0; row < nrows; ++row) {
-      int64 row_size = rt_nested_splits(row + 1) - rt_nested_splits(row);
+      SPLITS_TYPE row_size = rt_nested_splits(row + 1) - rt_nested_splits(row);
       T value = broadcast_starts ? starts(0) : starts(row);
       T delta = broadcast_deltas ? deltas(0) : deltas(row);
-      for (int64 i = 0; i < row_size; ++i) {
+      for (SPLITS_TYPE i = 0; i < row_size; ++i) {
         rt_dense_values(value_index++) = T(value);
         value += delta;
       }
@@ -102,7 +102,7 @@ class RaggedRangeOp : public OpKernel {
 
  private:
   // Returns the number of elements in the specified range.
-  int64 RangeSize(T start, T limit, T delta) {
+  SPLITS_TYPE RangeSize(T start, T limit, T delta) {
     if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
       return 0;
     }
@@ -114,10 +114,17 @@ class RaggedRangeOp : public OpKernel {
   }
 };
 
-#define REGISTER_CPU_KERNEL(TYPE)                                       \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("RaggedRange").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
-      RaggedRangeOp<TYPE>);
+#define REGISTER_CPU_KERNEL(TYPE)                                   \
+  REGISTER_KERNEL_BUILDER(Name("RaggedRange")                       \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<TYPE>("T")            \
+                              .TypeConstraint<int32>("Tsplits"),    \
+                          RaggedRangeOp<TYPE, int32>);              \
+  REGISTER_KERNEL_BUILDER(Name("RaggedRange")                       \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<TYPE>("T")            \
+                              .TypeConstraint<int64>("Tsplits"),    \
+                          RaggedRangeOp<TYPE, int64>);
 TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
 TF_CALL_int32(REGISTER_CPU_KERNEL);
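A short sketch of the corresponding Python entry point (assuming `tf.ragged.range` forwards the new Tsplits attr, as the updated `ragged_math_ops.range` calls elsewhere in this diff suggest):

    import tensorflow as tf

    # Request int32 row_splits for the generated ragged ranges.
    rt = tf.ragged.range([3, 5, 2], row_splits_dtype=tf.int32)
    print(rt.row_splits.dtype)  # tf.int32
    print(rt)                   # [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]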
@@ -26,21 +26,23 @@ namespace tensorflow {
 
 using errors::InvalidArgument;
 
+template <typename SPLITS_TYPE>
 class RaggedTensorToSparseOp : public OpKernel {
  public:
   using OpKernel::OpKernel;
+  using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;
 
   void Compute(OpKernelContext* context) override {
     // Read the `rt_nested_splits` input & convert to Eigen tensors.
     OpInputList rt_nested_splits_in;
     OP_REQUIRES_OK(
         context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
-    const int64 rt_nested_splits_len = rt_nested_splits_in.size();
+    const int rt_nested_splits_len = rt_nested_splits_in.size();
     DCHECK_GT(rt_nested_splits_len, 0);  // Enforced by REGISTER_OP.
-    std::vector<TTypes<int64>::ConstFlat> rt_nested_splits;
+    std::vector<ConstFlatSplits> rt_nested_splits;
     rt_nested_splits.reserve(rt_nested_splits_len);
     for (int i = 0; i < rt_nested_splits_len; ++i) {
-      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<int64>());
+      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
     }
 
     // Read the `rt_dense_values` input.
@@ -135,7 +137,7 @@ class RaggedTensorToSparseOp : public OpKernel {
     sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
     for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
       const auto& splits = rt_nested_splits[dim];
-      int64 max_width = 0;
+      SPLITS_TYPE max_width = 0;
       for (int i = 1; i < splits.size(); ++i) {
         max_width = std::max(max_width, splits(i) - splits(i - 1));
       }
@@ -150,7 +152,7 @@ class RaggedTensorToSparseOp : public OpKernel {
  private:
  // Validate `rt_nested_splits` to ensure we don't get any segfaults.
  static ::tensorflow::Status ValidateInputs(
-      std::vector<TTypes<int64>::ConstFlat> rt_nested_splits,
+      std::vector<ConstFlatSplits> rt_nested_splits,
      const Tensor& rt_dense_values_in) {
    for (int i = 0; i < rt_nested_splits.size(); ++i) {
      if (rt_nested_splits[i].size() == 0) {
@@ -160,7 +162,7 @@ class RaggedTensorToSparseOp : public OpKernel {
         return InvalidArgument("First value of ragged splits must be 0.");
       }
       if (i > 0) {
-        int64 last_split =
+        SPLITS_TYPE last_split =
             rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
         if (rt_nested_splits[i].size() != last_split + 1) {
           return InvalidArgument(
@@ -206,14 +208,21 @@ class RaggedTensorToSparseOp : public OpKernel {
   // values.
   static bool IsCompleted(
       const std::vector<int64>& pos, int dim,
-      const std::vector<TTypes<int64>::ConstFlat>& rt_nested_splits) {
+      const std::vector<ConstFlatSplits>& rt_nested_splits) {
     int64 current_child = pos[dim + 1];
     int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
     return current_child >= limit_child;
   }
 };
 
-REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse").Device(DEVICE_CPU),
-                        RaggedTensorToSparseOp);
+REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsplits"),
+                        RaggedTensorToSparseOp<int32>);
+
+REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsplits"),
+                        RaggedTensorToSparseOp<int64>);
 
 }  // namespace tensorflow
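Only the splits input is templated here; the op's outputs are unchanged. A sketch under that assumption:

    import tensorflow as tf

    rt = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
    st = rt.to_sparse()
    # sparse_indices / sparse_dense_shape stay int64, per the op registration.
    print(st.indices.dtype)      # tf.int64
    print(st.dense_shape.dtype)  # tf.int64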
@@ -350,6 +350,7 @@ class UnicodeTranscodeOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode").Device(DEVICE_CPU),
                         UnicodeTranscodeOp);
 
+template <typename SPLITS_TYPE>
 class UnicodeDecodeBaseOp : public OpKernel {
  public:
   explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets)
@@ -369,8 +370,8 @@ class UnicodeDecodeBaseOp : public OpKernel {
   }
 
   void Decode(OpKernelContext* ctx, std::vector<UChar32>* char_values,
-              std::vector<int64>* offset_values, int* current_offset,
-              int64* next_row_split, UChar32 char_value, int char_length,
+              std::vector<SPLITS_TYPE>* offset_values, int* current_offset,
+              SPLITS_TYPE* next_row_split, UChar32 char_value, int char_length,
               bool found_any_format_error) {
     if (error_options_.error_on_malformatting && found_any_format_error) {
       ctx->CtxFailure(
@@ -414,16 +415,16 @@ class UnicodeDecodeBaseOp : public OpKernel {
                                 input_encoding_));
 
     std::vector<UChar32> char_values;
-    std::vector<int64> offset_values;
+    std::vector<SPLITS_TYPE> offset_values;
 
     Tensor* output_row_splits;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits",
                                              {input_tensor->NumElements() + 1},
                                              &output_row_splits));
-    auto out_row_splits = output_row_splits->vec<int64>();
+    auto out_row_splits = output_row_splits->vec<SPLITS_TYPE>();
 
     int row_split_index = 0;
-    int64 next_row_split = 0;
+    SPLITS_TYPE next_row_split = 0;
     for (int i = 0; i < input_vec.size(); ++i) {
       const string& input = input_vec(i);
       // Convert input strings into unicode values. Output to a list of
@@ -443,18 +444,18 @@ class UnicodeDecodeBaseOp : public OpKernel {
 
     Tensor* output_char_values;
     OP_REQUIRES_OK(
-        ctx, ctx->allocate_output("char_values",
-                                  {static_cast<int64>(char_values.size())},
-                                  &output_char_values));
+        ctx, ctx->allocate_output(
+                 "char_values", {static_cast<SPLITS_TYPE>(char_values.size())},
+                 &output_char_values));
     auto out_char_values = output_char_values->vec<int32>();
     if (generate_offsets_) {
       DCHECK(offset_values.size() == char_values.size());
       Tensor* output_offset_values;
-      OP_REQUIRES_OK(
-          ctx, ctx->allocate_output("char_to_byte_starts",
-                                    {static_cast<int64>(offset_values.size())},
-                                    &output_offset_values));
-      auto out_offset_values = output_offset_values->vec<int64>();
+      OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                              "char_to_byte_starts",
+                              {static_cast<SPLITS_TYPE>(offset_values.size())},
+                              &output_offset_values));
+      auto out_offset_values = output_offset_values->vec<SPLITS_TYPE>();
 
       // Load output tensors from intermediate value arrays.
       for (int i = 0; i < char_values.size(); ++i) {
@@ -474,23 +475,36 @@ class UnicodeDecodeBaseOp : public OpKernel {
   bool generate_offsets_ = false;
 };
 
-class UnicodeDecodeOp : public UnicodeDecodeBaseOp {
+template <typename SPLITS_TYPE>
+class UnicodeDecodeOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> {
  public:
   explicit UnicodeDecodeOp(OpKernelConstruction* ctx)
-      : UnicodeDecodeBaseOp(ctx, false) {}
+      : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, false) {}
 };
 
-class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp {
+template <typename SPLITS_TYPE>
+class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> {
  public:
   explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx)
-      : UnicodeDecodeBaseOp(ctx, true) {}
+      : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, true) {}
 };
 
-REGISTER_KERNEL_BUILDER(Name("UnicodeDecode").Device(DEVICE_CPU),
-                        UnicodeDecodeOp);
-REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets").Device(DEVICE_CPU),
-                        UnicodeDecodeWithOffsetsOp);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint<int64>("Tsplits"),
+    UnicodeDecodeOp<int64>);
+REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsplits"),
+                        UnicodeDecodeWithOffsetsOp<int64>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits"),
+    UnicodeDecodeOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsplits"),
+                        UnicodeDecodeWithOffsetsOp<int32>);
 
+template <typename SPLITS_TYPE>
 class UnicodeEncodeOp : public OpKernel {
  public:
   explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -515,7 +529,7 @@ class UnicodeEncodeOp : public OpKernel {
     const Tensor& input_tensor = context->input(0);
     const auto input_tensor_flat = input_tensor.flat<int32>();
     const Tensor& input_splits = context->input(1);
-    const auto input_splits_flat = input_splits.flat<int64>();
+    const auto input_splits_flat = input_splits.flat<SPLITS_TYPE>();
 
     // Since we limit to a 2-D input (flat_values of rank 1 and a single splits
     // tensor), our output dimension will be 1 with its size equal to the
@@ -558,7 +572,11 @@ class UnicodeEncodeOp : public OpKernel {
   ErrorOptions error_options_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("UnicodeEncode").Device(DEVICE_CPU),
-                        UnicodeEncodeOp);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint<int64>("Tsplits"),
+    UnicodeEncodeOp<int64>);
+REGISTER_KERNEL_BUILDER(
+    Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits"),
+    UnicodeEncodeOp<int32>);
 
 }  // namespace tensorflow
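The decode/encode kernels now exist for both Tsplits dtypes, while the high-level string APIs keep their int64 default. A sketch (assuming the standard tf.strings wrappers):

    import tensorflow as tf

    chars = tf.strings.unicode_decode(["caf\u00e9"], input_encoding="UTF-8")
    # The default Tsplits is still int64, so existing graphs are unchanged.
    print(chars.row_splits.dtype)  # tf.int64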
@@ -29,13 +29,14 @@ Status RaggedGatherShapeFn(InferenceContext* c);
 //==============================================================================
 
 REGISTER_OP("RaggedGather")
-    .Input("params_nested_splits: PARAMS_RAGGED_RANK * int64")
+    .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
     .Input("params_dense_values: Tvalues")
     .Input("indices: Tindices")
-    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * int64")
+    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
     .Output("output_dense_values: Tvalues")
     .Attr("Tvalues: type")
     .Attr("Tindices: {int32, int64}")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .Attr("PARAMS_RAGGED_RANK: int >= 1")
     .Attr("OUTPUT_RAGGED_RANK: int >= 0")
     .SetShapeFn(RaggedGatherShapeFn);
@@ -31,13 +31,14 @@ Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
 //==============================================================================
 
 REGISTER_OP("RaggedTensorToSparse")
-    .Input("rt_nested_splits: RAGGED_RANK * int64")
+    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
     .Input("rt_dense_values: T")
     .Output("sparse_indices: int64")
     .Output("sparse_values: T")
     .Output("sparse_dense_shape: int64")
     .Attr("RAGGED_RANK: int >= 1")
     .Attr("T: type")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedTensorToSparseShapeFn);
 
 REGISTER_OP("RaggedTensorToVariant")
@@ -32,9 +32,10 @@ REGISTER_OP("RaggedRange")
     .Input("starts: T")
     .Input("limits: T")
     .Input("deltas: T")
-    .Output("rt_nested_splits: int64")
+    .Output("rt_nested_splits: Tsplits")
     .Output("rt_dense_values: T")
     .Attr("T: {bfloat16, float, double, int32, int64} = DT_INT32")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedRangeShapeFn);
 
 //==============================================================================
@@ -263,10 +263,11 @@ REGISTER_OP("UnicodeScript")
 
 REGISTER_OP("UnicodeEncode")
     .Input("input_values: int32")
-    .Input("input_splits: int64")
+    .Input("input_splits: Tsplits")
     .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'")
     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .Output("output: string")
     .SetShapeFn([](InferenceContext* c) {
       // Check rank of inner values
@@ -298,12 +299,13 @@ REGISTER_OP("UnicodeTranscode")
 
 REGISTER_OP("UnicodeDecode")
     .Input("input: string")
-    .Output("row_splits: int64")
+    .Output("row_splits: Tsplits")
     .Output("char_values: int32")
     .Attr("input_encoding: string")
     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
     .Attr("replace_control_characters: bool = false")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn([](InferenceContext* c) {
       // row_splits.shape == [input.size() + 1]
       DimensionHandle num_row_splits;
@@ -319,13 +321,14 @@ REGISTER_OP("UnicodeDecode")
 
 REGISTER_OP("UnicodeDecodeWithOffsets")
     .Input("input: string")
-    .Output("row_splits: int64")
+    .Output("row_splits: Tsplits")
     .Output("char_values: int32")
     .Output("char_to_byte_starts: int64")
     .Attr("input_encoding: string")
     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
     .Attr("replace_control_characters: bool = false")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn([](InferenceContext* c) {
       // row_splits.shape == [input.size() + 1]
       DimensionHandle num_row_splits;
@@ -27,6 +27,7 @@ py_library(
         ":ragged_batch_gather_ops",
         ":ragged_batch_gather_with_default_op",
         ":ragged_concat_ops",
+        ":ragged_config",
         ":ragged_conversion_ops",
         ":ragged_dispatch",
         ":ragged_factory_ops",
@@ -282,11 +283,21 @@ py_library(
     ],
 )
 
+py_library(
+    name = "ragged_config",
+    srcs = ["ragged_config.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:dtypes",
+    ],
+)
+
 py_library(
     name = "ragged_tensor",
     srcs = ["ragged_tensor.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":ragged_config",
         ":ragged_tensor_value",
         ":ragged_util",
         ":segment_id_ops",
@@ -363,6 +374,7 @@ py_library(
     srcs = ["segment_id_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":ragged_config",
         ":ragged_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
@@ -23,7 +23,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -122,6 +121,8 @@ def boolean_mask(data, mask, keepdims=False, name=None):
     data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
     mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         mask, dtypes.bool, name='mask')
+    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
+        data, mask, return_dtype=True)
 
     # Get static rank of mask.
     if mask.shape.ndims is None:
@@ -132,8 +133,9 @@ def boolean_mask(data, mask, keepdims=False, name=None):
     # If mask is ragged, then recurse with a non-ragged mask.
     if ragged_tensor.is_ragged(mask):
       if not ragged_tensor.is_ragged(data):
-        data = ragged_conversion_ops.from_tensor(
-            data, ragged_rank=mask.ragged_rank)
+        data = ragged_tensor.RaggedTensor.from_tensor(
+            data, ragged_rank=mask.ragged_rank,
+            row_splits_dtype=mask.row_splits.dtype)
       # Check that mask.nested_row_splits is a prefix of
       # data.nested_row_splits.
       splits_list = [
@@ -152,7 +154,7 @@ def boolean_mask(data, mask, keepdims=False, name=None):
       # Count the number of True mask values in each row to find the
       # lengths of the filtered rows; then convert to splits.
       int_mask = ragged_functional_ops.map_flat_values(
-          math_ops.cast, mask, dtype=dtypes.int64)
+          math_ops.cast, mask, dtype=row_splits_dtype)
       masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
       splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
       mask = mask.values
@@ -192,8 +194,9 @@ def boolean_mask(data, mask, keepdims=False, name=None):
     # If mask is non-ragged and has rank>1, then convert it to be ragged,
     # with a ragged rank matching data.
     if ragged_tensor.is_ragged(data):
-      mask = ragged_conversion_ops.from_tensor(
-          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1))
+      mask = ragged_tensor.RaggedTensor.from_tensor(
+          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
+          row_splits_dtype=data.row_splits.dtype)
       return boolean_mask(data, mask, keepdims)
 
     # Otherwise, data and mask are both `Tensor`s.
@@ -206,14 +209,15 @@ def boolean_mask(data, mask, keepdims=False, name=None):
     # number of values it contains. Then flatten that to get a list of
     # cell lengths, and convert it to splits. Finally, combine the splits
     # and values to get the innermost ragged tensor.
-    masked_lengths = math_ops.count_nonzero(mask, axis=-1)
+    masked_lengths = math_ops.count_nonzero(mask, axis=-1,
+                                            dtype=row_splits_dtype)
     flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
     masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
         masked_values, flattened_masked_lengths)
 
     # Wrap remaining ragged dimensions.
     if mask.shape.ndims > 2 and keepdims:
-      mask_shape = array_ops.shape(mask, out_type=dtypes.int64)
+      mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
       split_size = math_ops.cumprod(mask_shape) + 1
       for dim in range(mask.shape.ndims - 3, -1, -1):
         elt_size = mask_shape[dim + 1]
@@ -254,11 +258,11 @@ def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
   with ops.name_scope(name, 'RaggedTile', [input, multiples]):
     input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         input, name='input')
-    multiples = ragged_util.convert_to_int_tensor(
-        multiples, name='multiples', dtype=dtypes.int64)
-    multiples.shape.assert_has_rank(1)
     if not ragged_tensor.is_ragged(input):
       return array_ops.tile(input, multiples, name)
+    multiples = ragged_util.convert_to_int_tensor(
+        multiples, name='multiples', dtype=input.row_splits.dtype)
+    multiples.shape.assert_has_rank(1)
 
     # If the constant value of `multiples` is available, then we can use it
     # to skip tiling dimensions where `multiples=1`.
@@ -343,7 +347,7 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
       dimensions where `multiples=1`.
 
   Returns:
-    A list of 1-D `int64` `Tensor`s (one for each ragged dimension in
+    A list of 1-D integer `Tensor`s (one for each ragged dimension in
     `rt_input`).
 
   #### Example:
@@ -514,40 +518,6 @@ def size(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
     return array_ops.size(input, out_type=out_type, name=name)
 
 
-#===============================================================================
-# Internal Helper Functions
-#===============================================================================
-
-
-def _increase_ragged_rank_to(rt_input, ragged_rank):
-  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
-  if ragged_rank > 0:
-    if not ragged_tensor.is_ragged(rt_input):
-      rt_input = ragged_conversion_ops.from_tensor(rt_input)
-    if rt_input.ragged_rank < ragged_rank:
-      rt_input = rt_input.with_values(
-          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1))
-  return rt_input
-
-
-def _concat_ragged_splits(splits_list):
-  """Concatenates a list of RaggedTensor splits to form a single splits."""
-  pieces = [splits_list[0]]
-  splits_offset = splits_list[0][-1]
-  for splits in splits_list[1:]:
-    pieces.append(splits[1:] + splits_offset)
-    splits_offset += splits[-1]
-  return array_ops.concat(pieces, axis=0)
-
-
-def _nrows(rt_input, out_type=dtypes.int64, name=None):
-  if isinstance(rt_input, ragged_tensor.RaggedTensor):
-    return rt_input.nrows(out_type=out_type, name=name)
-  else:
-    with ops.name_scope(name, 'RaggedNRows', [rt_input]):
-      return array_ops.shape(rt_input, out_type=out_type)[0]
-
-
 #===============================================================================
 # ragged.rank
 #===============================================================================
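With the dtypes matched up front, boolean_mask keeps whichever splits dtype its inputs share. A sketch (assuming the `tf.ragged.boolean_mask` export):

    import tensorflow as tf

    data = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
    mask = tf.ragged.constant([[True, False], [True]],
                              row_splits_dtype=tf.int32)
    masked = tf.ragged.boolean_mask(data, mask)
    print(masked)                   # [[1], [3]]
    print(masked.row_splits.dtype)  # tf.int32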
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -72,6 +71,7 @@ def batch_gather(params, indices, name=None):
         params, name='params')
     indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         indices, name='indices')
+    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
     indices_ndims = indices.shape.ndims
     if indices_ndims is None:
       raise ValueError(
@@ -97,15 +97,17 @@ def batch_gather(params, indices, name=None):
       if params.shape.ndims is not None and params.shape.ndims < 2:
         raise ValueError('batch shape from indices does '
                          'not match params shape')
-      params = ragged_conversion_ops.from_tensor(params, ragged_rank=1)
+      params = ragged_conversion_ops.from_tensor(
+          params, ragged_rank=1,
+          row_splits_dtype=indices.row_splits.dtype)
 
       # Adjust indices from within-batch to global (in params.values), and
       # then use ragged.gather to gather them.
       num_indices = indices.row_lengths()
       params_starts = params.row_starts()
       adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
-      adjusted_index_values = math_ops.cast(
-          indices.values, dtypes.int64) + adjustments
+      adjusted_index_values = (
+          math_ops.cast(indices.values, adjustments.dtype) + adjustments)
       return ragged_tensor.RaggedTensor.from_row_splits(
           ragged_gather_ops.gather(params.values, adjusted_index_values),
           indices.row_splits)
@@ -116,7 +118,8 @@ def batch_gather(params, indices, name=None):
   elif indices_ndims == 2:
     # Adjust indices from batch-local to global (in params.values)
     adjustments = array_ops.expand_dims(params.row_starts(), 1)
-    adjusted_indices = math_ops.cast(indices, dtypes.int64) + adjustments
+    adjusted_indices = (
+        math_ops.cast(indices, adjustments.dtype) + adjustments)
     return ragged_gather_ops.gather(params.values, adjusted_indices)
   else:
     raise ValueError('batch shape from indices does not match params shape')
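batch_gather now normalizes its inputs' row-partition dtypes up front. A sketch of the helper's contract as used throughout this diff (the helper itself lives in ragged_tensor.py, which is not shown here):

    import tensorflow as tf
    from tensorflow.python.ops.ragged import ragged_tensor

    a = tf.ragged.constant([[1, 2]], row_splits_dtype=tf.int32)
    b = tf.ragged.constant([[3]], row_splits_dtype=tf.int32)

    # With return_dtype=True it also reports the common dtype.
    dtype, (a2, b2) = ragged_tensor.match_row_splits_dtypes(
        a, b, return_dtype=True)
    print(dtype)  # tf.int32
    # Mixing int32 and int64 inputs instead raises an error unless
    # ragged_config.auto_cast_partition_dtype() is changed to return True.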
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -81,6 +80,9 @@ def batch_gather_with_default(params,
     default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         default_value, name='default_value',
     )
+    row_splits_dtype, (params, indices, default_value) = (
+        ragged_tensor.match_row_splits_dtypes(params, indices, default_value,
+                                              return_dtype=True))
     # TODO(hterry): lift this restriction and support default_values of
     # of rank > 1
     if (default_value.shape.ndims is not 0
@@ -113,7 +115,7 @@ def batch_gather_with_default(params,
         axis=-1)
     upper_bounds = math_ops.cast(row_lengths, indices.dtype)
 
-    pad_shape = _get_pad_shape(params, indices)
+    pad_shape = _get_pad_shape(params, indices, row_splits_dtype)
 
     pad = ragged_tensor_shape.broadcast_to(
         default_value, pad_shape)
@@ -144,11 +146,11 @@ def batch_gather_with_default(params,
         params=padded_params, indices=adjusted_indices, name=name)
 
 
-def _get_pad_shape(params, indices):
+def _get_pad_shape(params, indices, row_splits_dtype):
   """Gets the RaggedTensorDynamicShape for the pad tensor."""
   num_batch_dimensions = indices.shape.ndims - 1
   params_shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
-      params)
+      params, dim_size_dtype=row_splits_dtype)
 
   # We want to create a pad tensor that can be concatenated with the params.
   if params.shape.ndims == indices.shape.ndims:
@@ -169,8 +171,8 @@ def _get_pad_shape(params, indices):
       # has size 1.
       pad_dims = None
       if num_batch_dimensions == 0:
-        pad_dims = (constant_op.constant(1, dtype=dtypes.int64),) + (
-            constant_op.constant([1], dtype=dtypes.int64),) * (
+        pad_dims = (constant_op.constant(1, dtype=row_splits_dtype),) + (
+            constant_op.constant([1], dtype=row_splits_dtype),) * (
                 params_shape.num_partitioned_dimensions -
                 num_batch_dimensions - 1)
       else:
@@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_array_ops
-from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_gather_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_util
@@ -135,6 +134,9 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
       ragged_tensor.convert_to_tensor_or_ragged_tensor(
           rt_input, name='rt_input') for rt_input in rt_inputs
   ]
+  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
+      *rt_inputs, return_dtype=True)
+  rt_inputs = list(rt_inputs)
 
   # Special case: if there's only one input, then return it as-is.
   if len(rt_inputs) == 1:
@@ -168,12 +170,13 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
     # possible to concatenate Tensors and RaggedTensors together.
     for i in range(len(rt_inputs)):
       if not ragged_tensor.is_ragged(rt_inputs[i]):
-        rt_inputs[i] = ragged_conversion_ops.from_tensor(
-            rt_inputs[i], ragged_rank=1)
+        rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
+            rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)
 
     # Convert the input tensors to all have the same ragged_rank.
     ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
-    rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank) for rt in rt_inputs]
+    rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
+                 for rt in rt_inputs]
 
     if axis == 0:
       return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
@@ -281,14 +284,16 @@ def _copy_row_shape(rt_inputs, splits):
     splits.set_shape(tensor_shape.TensorShape(rt.shape[0] + 1))
 
 
-def _increase_ragged_rank_to(rt_input, ragged_rank):
+def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
   """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
   if ragged_rank > 0:
     if not ragged_tensor.is_ragged(rt_input):
-      rt_input = ragged_conversion_ops.from_tensor(rt_input)
+      rt_input = ragged_tensor.RaggedTensor.from_tensor(
+          rt_input, row_splits_dtype=row_splits_dtype)
     if rt_input.ragged_rank < ragged_rank:
      rt_input = rt_input.with_values(
-          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1))
+          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1,
+                                   row_splits_dtype))
   return rt_input
tensorflow/python/ops/ragged/ragged_config.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Configuration parameters for RaggedTensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def auto_cast_partition_dtype():
+  """Whether incompatible row-partitioning dtypes should be auto-converted.
+
+  If true, then operations that combine RaggedTensors but have different
+  row-partitioning tensor dtypes will be automatically cast to a
+  compatible dtype (`tf.int64`).  If false, then such operations will result
+  in an error.
+
+  Returns:
+    `bool`
+  """
+  return False
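A sketch of the explicit conversion path that the new error messages point users to (the method name is taken from the ValueError text added in ragged_functional_ops.py below):

    import tensorflow as tf

    rt = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
    rt64 = rt.with_row_splits_dtype(tf.int64)
    print(rt64.row_splits.dtype)  # tf.int64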
@@ -18,15 +18,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops.ragged import ragged_tensor
 
 
-def from_tensor(tensor, lengths=None, padding=None, ragged_rank=1, name=None):
+def from_tensor(tensor, lengths=None, padding=None, ragged_rank=1,
+                row_splits_dtype=dtypes.int64, name=None):
   if ragged_tensor.is_ragged(tensor):
     return tensor
   else:
-    return ragged_tensor.RaggedTensor.from_tensor(tensor, lengths, padding,
-                                                  ragged_rank, name)
+    return ragged_tensor.RaggedTensor.from_tensor(
+        tensor,
+        lengths=lengths,
+        padding=padding,
+        ragged_rank=ragged_rank,
+        row_splits_dtype=row_splits_dtype,
+        name=name)
 
 
 def to_tensor(rt_input, default_value=None, name=None):
@@ -128,6 +128,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
       elif not _is_convertible_to_tensor(elt):
         return self.NOT_SUPPORTED
     if found_ragged:
+      x = ragged_tensor.match_row_splits_dtypes(*x)
       nested_splits_lists = [
           elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
       ]
@@ -199,6 +200,9 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
     except (TypeError, ValueError):
       return self.NOT_SUPPORTED
 
+    if x_is_ragged and y_is_ragged:
+      x, y = ragged_tensor.match_row_splits_dtypes(x, y)
+
     if ((x_is_ragged and y_is_ragged) or
         (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
         (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
@@ -272,16 +276,6 @@ class RaggedDispatcher(dispatch.OpDispatcher):
     return found_ragged
 
 
-def ragged_dispatch(original_op, tensor_args):
-
-  def decorator(ragged_op):
-    dispatch.RaggedDispatcher(original_op, ragged_op,
-                              tensor_args).register(original_op)
-    return ragged_op
-
-  return decorator
-
-
 _UNARY_ELEMENTWISE_OPS = [
     array_ops.check_numerics,
     array_ops.identity,
@@ -34,7 +34,8 @@ from tensorflow.python.util.tf_export import tf_export
 # Op to construct a constant RaggedTensor from a nested Python list.
 #===============================================================================
 @tf_export("ragged.constant")
-def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None):
+def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
+             name=None, row_splits_dtype=dtypes.int64):
   """Constructs a constant RaggedTensor from a nested Python list.
 
   Example:
@@ -65,6 +66,8 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None):
       is not specified.  If `ragged_rank` is specified, then a default is chosen
       based on the contents of `pylist`.
     name: A name prefix for the returned tensor (optional).
+    row_splits_dtype: data type for the constructed `RaggedTensor`'s row_splits.
+      One of `tf.int32` or `tf.int64`.
 
   Returns:
     A potentially ragged tensor with rank `K` and the specified `ragged_rank`,
@@ -74,14 +77,19 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None):
     ValueError: If the scalar values in `pylist` have inconsistent nesting
       depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
   """
+  def _ragged_factory(values, row_splits):
+    row_splits = constant_op.constant(row_splits, dtype=row_splits_dtype)
+    return ragged_tensor.RaggedTensor.from_row_splits(values, row_splits)
+
   with ops.name_scope(name, "RaggedConstant"):
-    return _constant_value(ragged_tensor.RaggedTensor.from_row_splits,
+    return _constant_value(_ragged_factory,
                            constant_op.constant, pylist, dtype, ragged_rank,
                            inner_shape)
 
 
 @tf_export(v1=["ragged.constant_value"])
-def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None):
+def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None,
+                   row_splits_dtype="int64"):
   """Constructs a RaggedTensorValue from a nested Python list.
 
   Warning: This function returns a `RaggedTensorValue`, not a `RaggedTensor`.
@@ -114,18 +122,20 @@ def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None):
       values in the returned `RaggedTensorValue`.  Defaults to `()` if
       `ragged_rank` is not specified.  If `ragged_rank` is specified, then a
       default is chosen based on the contents of `pylist`.
+    row_splits_dtype: data type for the constructed `RaggedTensorValue`'s
+      row_splits.  One of `numpy.int32` or `numpy.int64`.
 
   Returns:
-    A `RaggedTensorValue` or `numpy.array` with rank `K` and the specified
+    A `tf.RaggedTensorValue` or `numpy.array` with rank `K` and the specified
     `ragged_rank`, containing the values from `pylist`.
 
   Raises:
     ValueError: If the scalar values in `pylist` have inconsistent nesting
       depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
   """
-
+  row_splits_dtype = dtypes.as_dtype(row_splits_dtype).as_numpy_dtype
   def _ragged_factory(values, row_splits):
-    row_splits = np.array(row_splits, dtype=np.int64)
+    row_splits = np.array(row_splits, dtype=row_splits_dtype)
     return ragged_tensor_value.RaggedTensorValue(values, row_splits)
 
   def _inner_factory(pylist, dtype, shape, name=None):  # pylint: disable=unused-argument
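A companion sketch for the graph-free factory (`row_splits_dtype` here accepts anything `dtypes.as_dtype` understands, per the code above):

    import numpy as np
    from tensorflow.python.ops.ragged import ragged_factory_ops

    rtv = ragged_factory_ops.constant_value([[1, 2], [3]],
                                            row_splits_dtype="int32")
    print(rtv.row_splits.dtype)  # int32 (a numpy dtype)
    print(isinstance(rtv.row_splits, np.ndarray))  # True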
@@ -18,7 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_config
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_util
 from tensorflow.python.util.tf_export import tf_export
@@ -72,6 +75,17 @@ def map_flat_values(op, *args, **kwargs):
   if not nested_splits_lists:
     return op(*args, **kwargs)
 
+  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
+  if len(split_dtypes) > 1:
+    if not ragged_config.auto_cast_partition_dtype():
+      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
+                       "use RaggedTensor.with_row_splits_dtype() to convert "
+                       "them to compatible dtypes.")
+
+    nested_splits_lists = [
+        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
+        for nested_splits in nested_splits_lists]
+
   with ops.control_dependencies(
       ragged_util.assert_splits_match(nested_splits_lists)):
     # Delegate to op, and then compose the result from the transformed values
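A sketch of the new failure mode this guard introduces (`tf.add` stands in for any elementwise op):

    import tensorflow as tf
    from tensorflow.python.ops.ragged import ragged_functional_ops

    x = tf.ragged.constant([[1, 2]], row_splits_dtype=tf.int32)
    y = tf.ragged.constant([[3, 4]], row_splits_dtype=tf.int64)
    try:
      ragged_functional_ops.map_flat_values(tf.add, x, y)
    except ValueError as e:
      print(e)  # mismatched row_splits dtypes; auto-casting is off by default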
@@ -96,6 +96,7 @@ def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
       params, name='params')
   indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
       indices, name='indices')
+  params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
 
   if ragged_tensor.is_ragged(indices):
     return indices.with_values(gather(params, indices.values))
@@ -177,6 +178,7 @@ def gather_nd(params, indices, batch_dims=0, name=None):
       params, name='params')
   indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
       indices, name='indices')
+  params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
   indices_shape = indices.shape
   indices_ndims = indices_shape.ndims
   if indices_ndims is None:
@@ -200,7 +202,8 @@ def gather_nd(params, indices, batch_dims=0, name=None):
     indices_is_dense = not ragged_tensor.is_ragged(indices)
     if indices_is_dense:
       indices = ragged_conversion_ops.from_tensor(
-          indices, ragged_rank=indices_ndims - 2)
+          indices, ragged_rank=indices_ndims - 2,
+          row_splits_dtype=params.row_splits.dtype)
     result = indices.with_flat_values(gather_nd(params, indices.flat_values))
     if (indices_is_dense and ragged_tensor.is_ragged(result) and
         result.ragged_rank == indices_ndims - 2):
@@ -235,7 +238,7 @@ def gather_nd(params, indices, batch_dims=0, name=None):
   # index tuples point to the correct values in the flattened params; and
   # then use ragged.gather on the flattened index tuples & params.
   else:
-    indices = math_ops.cast(indices, dtypes.int64)
+    indices = math_ops.cast(indices, params.row_splits.dtype)
 
     # Flatten the outermost 2 dimensions of the index tuples & params.
     flattened_index_tuples = array_ops.gather(params.row_splits,
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -136,7 +135,8 @@ def _ragged_getitem(rt_input, key_list):
   # that puts all values in a single row.
   if row_key is array_ops.newaxis:
     inner_rt = _ragged_getitem(rt_input, inner_keys)
-    nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0]
+    nsplits = array_ops.shape(inner_rt.row_splits,
+                              out_type=inner_rt.row_splits.dtype)[0]
     return ragged_tensor.RaggedTensor.from_row_splits(
         inner_rt, array_ops.stack([0, nsplits - 1]))
 
@@ -192,7 +192,7 @@ def _slice_ragged_row_dimension(rt_input, row_key):
   # Use row_key to slice the starts & limits.
   new_starts = rt_input.row_splits[:-1][row_key]
   new_limits = rt_input.row_splits[1:][row_key]
-  zero_pad = array_ops.zeros([1], dtypes.int64)
+  zero_pad = array_ops.zeros([1], rt_input.row_splits.dtype)
 
   # If there's no slice step, then we can just select a single continuous
   # span of `ragged.values(rt_input)`.
@@ -245,7 +245,8 @@ def _ragged_getitem_inner_dimensions(rt_input, key_list):
   # RaggedTensor that puts each value in its own row.
   if column_key is array_ops.newaxis:
     inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
-    nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0]
+    nsplits = array_ops.shape(inner_rt.row_splits,
+                              out_type=inner_rt.row_splits.dtype)[0]
     return ragged_tensor.RaggedTensor.from_row_splits(inner_rt,
                                                       math_ops.range(nsplits))
 
@@ -359,10 +360,11 @@ def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
     step = 1
   step = ops.convert_to_tensor(step, name="step")
   if step.dtype.is_integer:
-    step = math_ops.cast(step, dtypes.int64)
+    step = math_ops.cast(step, starts.dtype)
   else:
     raise TypeError("slice strides must be integers or None")
-  value_indices = ragged_math_ops.range(starts, limits, step)
+  value_indices = ragged_math_ops.range(starts, limits, step,
+                                        row_splits_dtype=starts.dtype)
 
   # Use `ragged_gather` or `array_ops.gather` to collect the values.
   if isinstance(values, ragged_tensor.RaggedTensor):
@@ -384,11 +386,11 @@ def _add_offset_to_ranges(offset, starts, limits):
 
   Args:
     offset: The offset to add.  None, or an int, or a scalar Tensor.
-    starts: 1-D int64 tensor containing start indices.
-    limits: 1-D int64 tensor containing limit indices.
+    starts: 1-D integer tensor containing start indices.
+    limits: 1-D integer tensor containing limit indices.
 
   Returns:
-    A 1-D int64 tensor.
+    A 1-D integer tensor.
   """
 
   def map_positive_offset(offset):
@@ -398,7 +400,7 @@ def _add_offset_to_ranges(offset, starts, limits):
     return math_ops.maximum(limits + offset, starts)
 
   if isinstance(offset, ops.Tensor):
-    offset = math_ops.cast(offset, dtypes.int64)
+    offset = math_ops.cast(offset, starts.dtype)
   return control_flow_ops.cond(offset >= 0,
                                lambda: map_positive_offset(offset),
                                lambda: map_negative_offset(offset))
@ -222,7 +222,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
def testZip(self):
x = ragged_factory_ops.constant(
[[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64)
y = array_ops.expand_dims(mo.range(x.nrows(), dtype=dtypes.int64), axis=1)
y = array_ops.expand_dims(mo.range(x.nrows(out_type=dtypes.int64)), axis=1)

def _zip(foo):
y_val, x_val = foo
@ -273,7 +273,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0])
with self.assertRaisesWithLiteralMatch(
ValueError, r'The declared ragged rank (10) mismatches the result (1)'):
ValueError, r'The declared ragged rank (10) mismatches the result (2)'):
_ = ragged_map_ops.map_fn(
fn,
elems,

@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -196,6 +197,7 @@ def map_fn(fn,
return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

elems_flat = input_flatten(elems)
elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)

with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
@ -408,8 +410,9 @@ def _maybe_decompose_dtype(d):

result = _RaggedTensorComponents(
flat_values=d.dtype,
nested_row_lengths=tuple(dtypes.int64 for i in range(d.ragged_rank - 1)),
outer_row_length=dtypes.int64,
nested_row_lengths=tuple(
d.row_splits_dtype for i in range(d.ragged_rank - 1)),
outer_row_length=d.row_splits_dtype,
)
return result

@ -418,31 +421,42 @@ def _convert_declared(fn_output_flat, output_declared):
"""Convert outputs which are `Tensor`s into `_RaggedTensorComponents`."""
for current, declared in zip(fn_output_flat, output_declared):
if isinstance(declared, ragged_tensor.RaggedTensorType):
if isinstance(current, ragged_tensor.RaggedTensor):
# Check that the ragged ranks match up.
# + 1 to account for the rank of the outermost dimension.
if declared.ragged_rank != current.ragged_rank + 1:
raise ValueError(
"The declared ragged rank (%d) mismatches the result (%d)" %
(declared.ragged_rank, current.ragged_rank))
yield current
else:
# The output is a Tensor, but the caller has declared that we are
# expecting a RaggedTensor output.
if declared.ragged_rank != 1:
raise ValueError(
"The declared ragged rank (%d) mismatches the result (1)" %
declared.ragged_rank)

if isinstance(current, ragged_tensor.RaggedTensor):
nrows = current.nrows()
else:
nrows = array_ops.shape(current, out_type=dtypes.int64)[0]
row_length = array_ops.expand_dims(nrows, axis=0)
rt = _RaggedTensorComponents(
flat_values=current,
nested_row_lengths=(),
outer_row_length=row_length)
yield rt
yield _convert_declared_ragged(current, declared)
else:
yield current


def _convert_declared_ragged(current, declared):
"""Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
# Check that the ragged ranks match up.
# + 1 to account for the rank of the outermost dimension.
current_ragged_rank = getattr(current, "ragged_rank", 0)
if declared.ragged_rank != current_ragged_rank + 1:
raise ValueError(
"The declared ragged rank (%d) mismatches the result (%d)" %
(declared.ragged_rank, current_ragged_rank + 1))

# Check that dtypes match up.
if declared.dtype != current.dtype:
raise ValueError(
"The declared dtype (%s) mismatches the result (%s)" %
(declared.dtype, current.dtype))
if (isinstance(current, ragged_tensor.RaggedTensor) and
declared.row_splits_dtype != current.row_splits.dtype):
if not ragged_config.auto_cast_partition_dtype():
raise ValueError(
"The declared row_splits dtype (%s) mismatches the result (%s)."
" Use RaggedTensor.with_row_splits_dtype to convert it."
% (declared.row_splits_dtype, current.row_splits.dtype))
current = current.with_row_splits_dtype(declared.row_splits_dtype)

if isinstance(current, ragged_tensor.RaggedTensor):
return current
else:
nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
row_length = array_ops.expand_dims(nrows, axis=0)
return _RaggedTensorComponents(
flat_values=current,
nested_row_lengths=(),
outer_row_length=row_length)
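A sketch of how a declared ragged output exercises these checks (illustrative values; uses the internal `ragged_map_ops` modules, assuming eager execution):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_tensor

elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
# The declared row_splits_dtype is matched (or auto-cast) against each
# per-row result's row_splits dtype by _convert_declared_ragged.
out = ragged_map_ops.map_fn(
    lambda row: row * 2,
    elems,
    dtype=ragged_tensor.RaggedTensorType(
        dtype=tf.int32, ragged_rank=1, row_splits_dtype=tf.int64))
print(out)  # <tf.RaggedTensor [[2, 4, 6], [8, 10]]>
```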


@ -39,7 +39,8 @@ from tensorflow.python.util.tf_export import tf_export
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
def range(starts, limits=None, deltas=1, dtype=None, name=None):
def range(starts, limits=None, deltas=1, dtype=None,
name=None, row_splits_dtype=dtypes.int64):
"""Returns a `RaggedTensor` containing the specified sequences of numbers.

Each row of the returned `RaggedTensor` contains a single sequence:
@ -81,10 +82,13 @@ def range(starts, limits=None, deltas=1, dtype=None, name=None):
dtype: The type of the elements of the resulting tensor. If not specified,
then a value is chosen based on the other args.
name: A name for the operation.
row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
tensor. One of `tf.int32` or `tf.int64`.

Returns:
A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
"""
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if limits is None:
starts, limits = 0, starts

@ -99,7 +103,8 @@ def range(starts, limits=None, deltas=1, dtype=None, name=None):
[starts, limits, deltas],
[dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

result = gen_ragged_math_ops.ragged_range(starts, limits, deltas, name=name)
result = gen_ragged_math_ops.ragged_range(
starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
result.rt_nested_splits)
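A minimal usage sketch of the new parameter (sample values are illustrative; assumes eager execution):

```python
import tensorflow as tf

rt = tf.ragged.range([2, 0, 3], row_splits_dtype=tf.int32)
print(rt)                   # <tf.RaggedTensor [[0, 1], [], [0, 1, 2]]>
print(rt.row_splits.dtype)  # tf.int32 (the default remains tf.int64)
```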

@ -190,6 +195,9 @@ def _ragged_segment_aggregate(unsorted_segment_op,
data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
segment_ids, name='segment_ids')
data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError('segment_ids must have dtype int32 or int64.')

if ragged_tensor.is_ragged(segment_ids):
if not ragged_tensor.is_ragged(data):
@ -203,22 +211,19 @@ def _ragged_segment_aggregate(unsorted_segment_op,
return _ragged_segment_aggregate(unsorted_segment_op, data.values,
segment_ids.values, num_segments, name)

segment_ids = math_ops.cast(segment_ids, dtypes.int64)

# Find the length of each row in data. (dtype=int64, shape=[data_nrows])
# Find the length of each row in data. (shape=[data_nrows])
data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

# Find the length that each output row will have. The length of the row
# corresponding to segment `id` is `max(data_row_lengths[i])` where
# `segment_ids[i]=id`. (dtype=int64, shape=[output_nrows])
# `segment_ids[i]=id`. (shape=[output_nrows])
output_row_lengths = math_ops.maximum(
math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
num_segments), 0)
assert output_row_lengths.dtype == dtypes.int64

# Build the splits tensor for the output RaggedTensor.
output_splits = array_ops.concat([
array_ops.zeros([1], dtypes.int64),
array_ops.zeros([1], output_row_lengths.dtype),
math_ops.cumsum(output_row_lengths)
],
axis=0)

@ -43,8 +43,8 @@ class RaggedSplitsToSegmentIdsOpTest(ragged_test_util.RaggedTensorTestCase):
self.assertRaisesRegexp(ValueError, r'Invalid row_splits: \[\]',
segment_id_ops.row_splits_to_segment_ids, [])
self.assertRaisesRegexp(
ValueError, r'Tensor conversion requested dtype int64 for '
'Tensor with dtype float32', segment_id_ops.row_splits_to_segment_ids,
ValueError, r'splits must have dtype int32 or int64',
segment_id_ops.row_splits_to_segment_ids,
constant_op.constant([0.5]))
self.assertRaisesRegexp(ValueError, r'Shape \(\) must have rank 1',
segment_id_ops.row_splits_to_segment_ids, 0)

@ -75,7 +75,7 @@ def squeeze(input, axis=None, name=None): # pylint: disable=redefined-builtin

# Make sure the specified ragged dimensions are squeezable.
assertion_list = []
scalar_tensor_one = constant_op.constant(1, dtype=dtypes.int64)
scalar_tensor_one = constant_op.constant(1, dtype=input.row_splits.dtype)
for i, r in enumerate(input.nested_row_lengths()):
if i + 1 in ragged_dims:
assertion_list.append(

@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@ -146,8 +145,8 @@ def unicode_encode(input,
if input_tensor.shape.ndims == 2:
# The input tensor is of the correct 2-D shape, it's just not ragged.
return unicode_encode(
ragged_conversion_ops.from_tensor(input_tensor), output_encoding,
errors, replacement_char)
ragged_tensor.RaggedTensor.from_tensor(input_tensor),
output_encoding, errors, replacement_char)
elif input_tensor.shape.ndims > 2:
# We need to initially flatten the input tensor to 2-D, and then can
# reshape the output of our processed flattened tensor.
@ -166,7 +165,7 @@ def unicode_encode(input,
ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits(
input_tensor,
array_ops.stack(
[0, array_ops.shape(input_tensor, out_type=dtypes.int64)[0]]))
[0, array_ops.shape(input_tensor, out_type=dtypes.int32)[0]]))
output_tensor = unicode_encode(ragged_input_tensor, output_encoding,
errors, replacement_char)
return array_ops.reshape(output_tensor, [])
@ -404,11 +403,11 @@ def _unicode_decode(input, input_encoding, errors, replacement_char,
if input_ndims > 1:
# Convert to a ragged tensor with ragged_rank = input_ndims - 1.
if not ragged_tensor.is_ragged(input):
input = ragged_conversion_ops.from_tensor(
input = ragged_tensor.RaggedTensor.from_tensor(
input, ragged_rank=input_ndims - 1)
elif input.ragged_rank < input_ndims - 1:
input = input.with_flat_values(
ragged_conversion_ops.from_tensor(
ragged_tensor.RaggedTensor.from_tensor(
input.flat_values,
ragged_rank=input_ndims - input.ragged_rank + 1))


@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_ragged_conversion_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
@ -115,8 +116,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
`[nvals]`, corresponding one-to-one with `values`, which specifies
each value's row index. In particular, the row `rt[row]` consists of the
values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an
int64 scalar that specifies the number of rows in the `RaggedTensor`.
(`nrows` is used to indicate trailing empty rows.)
integer scalar that specifies the number of rows in the
`RaggedTensor`. (`nrows` is used to indicate trailing empty rows.)

* `row_starts`: a vector with shape `[nrows]`, which specifies the start
offset of each row. Equivalent to `row_splits[:-1]`.
@ -220,10 +221,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
row_splits: A 1-D int64 tensor with shape `[nrows+1]`.
cached_row_lengths: A 1-D int64 tensor with shape `[nrows]`
cached_value_rowids: A 1-D int64 tensor with shape `[nvals]`.
cached_nrows: A 1-D int64 scalar tensor.
row_splits: A 1-D integer tensor with shape `[nrows+1]`.
cached_row_lengths: A 1-D integer tensor with shape `[nrows]`
cached_value_rowids: A 1-D integer tensor with shape `[nvals]`.
cached_nrows: A 1-D integer scalar tensor.
internal: True if the constructor is being called by one of the factory
methods. If false, an exception will be raised.

@ -244,9 +245,13 @@ class RaggedTensor(composite_tensor.CompositeTensor):
raise TypeError("values must be a Tensor or RaggedTensor.")
if not isinstance(row_splits, ops.Tensor):
raise TypeError("Row-partitioning argument must be a Tensor.")
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("Row-partitioning argument must be int32 or int64")
values.shape.with_rank_at_least(1)
row_splits.shape.assert_has_rank(1)
row_splits.set_shape([None])
if isinstance(values, RaggedTensor):
assert row_splits.dtype == values.row_splits.dtype

self._values = values
self._row_splits = row_splits
@ -255,8 +260,11 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# round-trip conversions when a RaggedTensor is constructed from
# lengths or rowids, and we later want those lengths/rowids back.
for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]:
if tensor is not None and not isinstance(tensor, ops.Tensor):
raise TypeError("Cached value must be a Tensor or None.")
if tensor is not None:
if not isinstance(tensor, ops.Tensor):
raise TypeError("Cached value must be a Tensor or None.")
elif tensor.dtype not in (dtypes.int32, dtypes.int64):
raise TypeError("Cached value must be int32 or int64.")
self._cached_row_lengths = cached_row_lengths
self._cached_value_rowids = cached_value_rowids
self._cached_nrows = cached_nrows
@ -276,15 +284,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
for row in range(nrows)]
```

Warning: currently, this needs to cast value_rowids to int64 before
converting, since `tf.math.bincount` only supports `int32`.

Args:
values: A potentially ragged tensor with shape `[nvals, ...]`.
value_rowids: A 1-D int64 tensor with shape `[nvals]`, which corresponds
value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
one-to-one with `values`, and specifies each value's row index. Must be
nonnegative, and must be sorted in ascending order.
nrows: An int64 scalar specifying the number of rows. This should be
nrows: An integer scalar specifying the number of rows. This should be
specified if the `RaggedTensor` may contain trailing empty rows. Must
be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty).
@ -308,9 +313,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"""
with ops.name_scope(name, "RaggedFromValueRowIds",
[values, value_rowids, nrows]):
values = convert_to_tensor_or_ragged_tensor(values, name="values")
value_rowids = ops.convert_to_tensor(
value_rowids, dtypes.int64, name="value_rowids")
values, value_rowids = cls._convert_values_and_row_partition(
values, value_rowids, "value_rowids")
if nrows is None:
const_rowids = tensor_util.constant_value(value_rowids)
if const_rowids is None:
@ -318,9 +322,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
const_nrows = None
else:
const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
nrows = ops.convert_to_tensor(const_nrows, dtypes.int64, name="nrows")
nrows = ops.convert_to_tensor(const_nrows, value_rowids.dtype,
name="nrows")
else:
nrows = ops.convert_to_tensor(nrows, dtypes.int64, "nrows")
nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
const_nrows = tensor_util.constant_value(nrows)
if const_nrows is not None:
if const_nrows < 0:
@ -340,14 +345,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Note: we don't use segment_ids_to_row_splits() here because we want
# to save the intermediate value `row_lengths`, so we can cache it.
# TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
# cast (Remove the warning in the docstring when we do.)
# cast.
value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
nrows_int32 = math_ops.cast(nrows, dtypes.int32)
row_lengths = math_ops.bincount(
value_rowids_int32,
minlength=nrows_int32,
maxlength=nrows_int32,
dtype=dtypes.int64)
dtype=value_rowids.dtype)
row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
if const_nrows is not None:
row_lengths.set_shape([const_nrows])
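A quick sketch of how the partition dtype now flows through this factory (illustrative values; assumes eager execution):

```python
import tensorflow as tf

rowids = tf.constant([0, 0, 2, 2, 2], dtype=tf.int32)
rt = tf.RaggedTensor.from_value_rowids(
    values=[10, 20, 30, 40, 50], value_rowids=rowids, nrows=4)
print(rt)                   # <tf.RaggedTensor [[10, 20], [], [30, 40, 50], []]>
print(rt.row_splits.dtype)  # tf.int32, inherited from value_rowids
```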
@ -374,9 +379,9 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
values: A potentially ragged tensor with shape `[nvals, ...]`.
row_splits: A 1-D int64 tensor with shape `[nrows+1]`. Must not be empty,
and must be sorted in ascending order. `row_splits[0]` must be zero and
`row_splits[-1]` must be `nvals`.
row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
empty, and must be sorted in ascending order. `row_splits[0]` must be
zero and `row_splits[-1]` must be `nvals`.
name: A name prefix for the RaggedTensor (optional).

Returns:
@ -397,8 +402,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if isinstance(row_splits, (list, tuple)) and not row_splits:
raise ValueError("row_splits tensor may not be empty.")
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
values = convert_to_tensor_or_ragged_tensor(values, name="values")
row_splits = ops.convert_to_tensor(row_splits, dtypes.int64, "row_splits")
values, row_splits = cls._convert_values_and_row_partition(
values, row_splits, "row_splits")
row_splits.shape.assert_has_rank(1)
return cls(values=values, row_splits=row_splits, internal=True)
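Because the factories now route through `_convert_values_and_row_partition`, the splits dtype follows the tensor you pass in (a sketch; assumes eager execution):

```python
import tensorflow as tf

splits32 = tf.constant([0, 2, 5], dtype=tf.int32)
rt = tf.RaggedTensor.from_row_splits([1, 2, 3, 4, 5], splits32)
print(rt.row_splits.dtype)  # tf.int32

# Plain Python lists still default to int64, via the
# preferred_dtype=dtypes.int64 conversion in _convert_values_and_row_partition.
rt64 = tf.RaggedTensor.from_row_splits([1, 2, 3, 4, 5], [0, 2, 5])
print(rt64.row_splits.dtype)  # tf.int64
```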

@ -415,7 +420,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
values: A potentially ragged tensor with shape `[nvals, ...]`.
row_lengths: A 1-D int64 tensor with shape `[nrows]`. Must be
row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be
nonnegative. `sum(row_lengths)` must be `nvals`.
name: A name prefix for the RaggedTensor (optional).

@ -432,9 +437,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
```
"""
with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
values = convert_to_tensor_or_ragged_tensor(values, name="values")
row_lengths = ops.convert_to_tensor(row_lengths, dtypes.int64,
"row_lengths")
values, row_lengths = cls._convert_values_and_row_partition(
values, row_lengths, "row_lengths")
row_lengths.shape.assert_has_rank(1)
row_limits = math_ops.cumsum(row_lengths)
row_splits = array_ops.concat([[0], row_limits], axis=0)
@ -452,9 +456,9 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
values: A potentially ragged tensor with shape `[nvals, ...]`.
row_starts: A 1-D int64 tensor with shape `[nrows]`. Must be nonnegative
and sorted in ascending order. If `nrows>0`, then `row_starts[0]` must
be zero.
row_starts: A 1-D integer tensor with shape `[nrows]`. Must be
nonnegative and sorted in ascending order. If `nrows>0`, then
`row_starts[0]` must be zero.
name: A name prefix for the RaggedTensor (optional).

Returns:
@ -470,10 +474,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
```
"""
with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
values = convert_to_tensor_or_ragged_tensor(values, name="values")
row_starts = ops.convert_to_tensor(row_starts, dtypes.int64, "row_starts")
values, row_starts = cls._convert_values_and_row_partition(
values, row_starts, "row_starts")
row_starts.shape.assert_has_rank(1)
nvals = array_ops.shape(values, out_type=dtypes.int64)[:1]
nvals = array_ops.shape(values, out_type=row_starts.dtype)[:1]
row_splits = array_ops.concat([row_starts, nvals], axis=0)
return cls(values=values, row_splits=row_splits, internal=True)

@ -485,7 +489,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
values: A potentially ragged tensor with shape `[nvals, ...]`.
row_limits: A 1-D int64 tensor with shape `[nrows]`. Must be sorted in
row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in
ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`.
name: A name prefix for the RaggedTensor (optional).

@ -502,10 +506,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
```
"""
with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
values = convert_to_tensor_or_ragged_tensor(values, name="values")
row_limits = ops.convert_to_tensor(row_limits, dtypes.int64, "row_limits")
values, row_limits = cls._convert_values_and_row_partition(
values, row_limits, "row_limits")
row_limits.shape.assert_has_rank(1)
zero = array_ops.zeros([1], dtypes.int64)
zero = array_ops.zeros([1], row_limits.dtype)
row_splits = array_ops.concat([zero, row_limits], axis=0)
return cls(values=values, row_splits=row_splits, internal=True)

@ -527,9 +531,9 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
flat_values: A potentially ragged tensor.
nested_value_rowids: A list of 1-D int64 tensors. The `i`th tensor is
nested_value_rowids: A list of 1-D integer tensors. The `i`th tensor is
used as the `value_rowids` for the `i`th ragged dimension.
nested_nrows: A list of int64 scalars. The `i`th scalar is used as the
nested_nrows: A list of integer scalars. The `i`th scalar is used as the
`nrows` for the `i`th ragged dimension.
name: A name prefix for the RaggedTensor (optional).

@ -573,8 +577,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
flat_values: A potentially ragged tensor.
nested_row_splits: A list of 1-D int64 tensors. The `i`th tensor is used
as the `row_splits` for the `i`th ragged dimension.
nested_row_splits: A list of 1-D integer tensors. The `i`th tensor is
used as the `row_splits` for the `i`th ragged dimension.
name: A name prefix for the RaggedTensor (optional).

Returns:
@ -603,8 +607,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):

Args:
flat_values: A potentially ragged tensor.
nested_row_lengths: A list of 1-D int64 tensors. The `i`th tensor is used
as the `row_lengths` for the `i`th ragged dimension.
nested_row_lengths: A list of 1-D integer tensors. The `i`th tensor is
used as the `row_lengths` for the `i`th ragged dimension.
name: A name prefix for the RaggedTensor (optional).

Returns:
@ -619,6 +623,50 @@ class RaggedTensor(composite_tensor.CompositeTensor):
result = cls.from_row_lengths(result, lengths)
return result

@classmethod
def _convert_values_and_row_partition(cls, values, partition, name):
"""Converts `values` and `partition` to Tensors.

If `values` is a `RaggedTensor`, then converts `values` and `partition`
to have compatible row-partitioning dtypes. In particular, if any of the
row partitioning tensors are `int64`, then all of the other row
partitioning tensors will be cast to `int64` (if auto_cast_partition_dtype()
is true) or an error will be raised (if auto_cast_partition_dtype() is
false).

Args:
values: The `values` for the `RaggedTensor` being constructed.
partition: A row-partitioning tensor for the `RaggedTensor` being
constructed. I.e., one of: row_splits, row_lengths, row_starts,
row_limits, value_rowids.
name: The name of the row-partitioning tensor.

Returns:
A tuple (values, partition).
"""
if isinstance(values, RaggedTensor):
if isinstance(partition, ops.Tensor):
if partition.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("%s must have dtype int32 or int64" % name)
if values.row_splits.dtype != partition.dtype:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("dtype mismatch: %s (%s) vs values.row_splits (%s)"
% (name, partition.dtype, values.row_splits.dtype))
partition = math_ops.cast(partition, dtypes.int64)
values = values.with_row_splits_dtype(dtypes.int64)
else:
partition = ops.convert_to_tensor(partition, values.row_splits.dtype,
name=name)
else:
values = ops.convert_to_tensor(values, name="values")
partition = ops.convert_to_tensor(
partition, preferred_dtype=dtypes.int64,
name=name)
if partition.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("%s must have dtype int32 or int64" % name)

return (values, partition)
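A sketch of the mismatch path (illustrative values; the outcome depends on `ragged_config.auto_cast_partition_dtype()`, which this commit introduces):

```python
import tensorflow as tf

inner = tf.RaggedTensor.from_row_splits(
    [1, 2, 3], tf.constant([0, 1, 3], dtype=tf.int32))
# Wrapping int32-partitioned values with int64 splits trips the dtype check
# (or auto-casts everything to int64 when the config flag is enabled).
try:
  tf.RaggedTensor.from_row_splits(inner, tf.constant([0, 2], dtype=tf.int64))
except ValueError as e:
  print(e)  # dtype mismatch: row_splits (...) vs values.row_splits (...)
```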

#=============================================================================
# Accessors
#=============================================================================
@ -696,7 +744,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.

Returns:
A 1-D `int64` `Tensor` with shape `[self.nrows+1]`.
A 1-D integer `Tensor` with shape `[self.nrows+1]`.
The returned tensor is non-empty, and is sorted in ascending order.
`self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
`self.values.shape[0]`.
@ -752,7 +800,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
* `value_splits = rt.values.nested_row_splits` otherwise.

Returns:
A `tuple` of 1-D `int64` `Tensor`s.
A `tuple` of 1-D integer `Tensor`s.

#### Example:

@ -785,7 +833,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional).

Returns:
A 1-D `int64` `Tensor` with shape `self.values.shape[:1]`.
A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
The returned tensor is nonnegative, and is sorted in ascending order.

#### Example:
@ -803,13 +851,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
with ops.name_scope(name, "RaggedValueRowIds", [self]):
return segment_id_ops.row_splits_to_segment_ids(self.row_splits)

def nrows(self, out_type=dtypes.int64, name=None):
def nrows(self, out_type=None, name=None):
"""Returns the number of rows in this ragged tensor.

I.e., the size of the outermost dimension of the tensor.

Args:
out_type: `dtype` for the returned tensor.
out_type: `dtype` for the returned tensor. Defaults to
`self.row_splits.dtype`.
name: A name prefix for the returned tensor (optional).

Returns:
@ -824,7 +873,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"""
if self._cached_nrows is not None:
return self._cached_nrows

if out_type is None:
out_type = self._row_splits.dtype
else:
out_type = dtypes.as_dtype(out_type)
with ops.name_scope(name, "RaggedNRows", [self]):
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
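With the new default, `nrows()` matches the partition dtype instead of always returning int64 (a sketch, eager mode):

```python
import tensorflow as tf

rt = tf.RaggedTensor.from_row_splits(
    [1, 2, 3, 4], tf.constant([0, 1, 4], dtype=tf.int32))
print(rt.nrows().dtype)                   # tf.int32 (follows row_splits)
print(rt.nrows(out_type=tf.int64).dtype)  # tf.int64 (explicit override)
```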

@ -838,7 +890,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional).

Returns:
A 1-D Tensor of int64 with shape `[nrows]`.
A 1-D integer Tensor with shape `[nrows]`.
The returned tensor is nonnegative, and is sorted in ascending order.

#### Example:
@ -863,7 +915,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional).

Returns:
A 1-D Tensor of int64 with shape `[nrows]`.
A 1-D integer Tensor with shape `[nrows]`.
The returned tensor is nonnegative, and is sorted in ascending order.

#### Example:
@ -890,7 +942,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensor (optional).

Returns:
A potentially ragged Tensor of int64 with shape `self.shape[:axis]`.
A potentially ragged integer Tensor with shape `self.shape[:axis]`.

Raises:
ValueError: If `axis` is out of bounds.
@ -917,9 +969,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
elif isinstance(self.values, RaggedTensor):
return self.with_values(self.values.row_lengths(axis - 1))
else:
shape = array_ops.shape(self.values, out_type=dtypes.int64)
shape = array_ops.shape(self.values, out_type=self._row_splits.dtype)
return self.with_values(
array_ops.ones(shape[:axis - 1], dtypes.int64) * shape[axis - 1])
array_ops.ones(shape[:axis - 1], self._row_splits.dtype) *
shape[axis - 1])

def nested_row_lengths(self, name=None):
"""Returns a tuple containing the row_lengths for all ragged dimensions.
@ -931,7 +984,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name: A name prefix for the returned tensors (optional).

Returns:
A `tuple` of 1-D `int64` `Tensors`. The length of the tuple is equal to
A `tuple` of 1-D integer `Tensors`. The length of the tuple is equal to
`self.ragged_rank`.
"""
with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
@ -942,7 +995,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
rt = rt.values
return tuple(rt_nested_row_lengths)

def bounding_shape(self, axis=None, name=None):
def bounding_shape(self, axis=None, name=None, out_type=None):
"""Returns the tight bounding box shape for this `RaggedTensor`.

Args:
@ -950,13 +1003,15 @@ class RaggedTensor(composite_tensor.CompositeTensor):
bounding box for. If not specified, then the full bounding box is
returned.
name: A name prefix for the returned tensor (optional).
out_type: `dtype` for the returned tensor. Defaults to
`self.row_splits.dtype`.

Returns:
An int64 `Tensor`. If `axis` is not specified, then `output`
is a vector with `output.shape=[self.shape.ndims]`. If `axis` is a
scalar, then the `output` is a scalar. If `axis` is a vector, then
`output` is a vector, where `output[i]` is the bounding size for
dimension `axis[i]`.
An integer `Tensor` (`dtype=self.row_splits.dtype`). If `axis` is not
specified, then `output` is a vector with
`output.shape=[self.shape.ndims]`. If `axis` is a scalar, then the
`output` is a scalar. If `axis` is a vector, then `output` is a vector,
where `output[i]` is the bounding size for dimension `axis[i]`.

#### Example:
```python
@ -965,6 +1020,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):
[5, 4]
```
"""
if out_type is None:
out_type = self._row_splits.dtype
else:
out_type = dtypes.as_dtype(out_type)
with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
nested_splits = self.nested_row_splits
rt_flat_values = self.flat_values
@ -972,12 +1031,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Optimized special cases for when axis=0 or axis=1:
if isinstance(axis, int):
if axis == 0:
return array_ops.shape(nested_splits[0], out_type=dtypes.int64)[0] - 1
return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1
elif axis == 1:
return math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)

splits_shape = array_ops.shape(self.row_splits, out_type=dtypes.int64)
flat_values_shape = array_ops.shape(rt_flat_values, out_type=dtypes.int64)
splits_shape = array_ops.shape(self.row_splits, out_type=out_type)
flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type)

ragged_dimensions = array_ops.stack([splits_shape[0] - 1] + [
math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
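A usage sketch for the new `out_type` argument (values illustrative; eager mode):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
print(rt.bounding_shape())  # [3 3], in the row_splits dtype (int64 by default)
print(rt.bounding_shape(out_type=tf.int32).dtype)  # tf.int32
```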
@ -1009,6 +1068,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"""
new_values.shape.with_rank_at_least(1)
self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
if (isinstance(new_values, RaggedTensor) and
self._row_splits.dtype != new_values.row_splits.dtype):
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("self and new_values have mismatched row_splits "
"dtypes; use RaggedTensor.with_row_splits_dtype() to "
"convert them to compatible dtypes.")
new_values = new_values.with_row_splits_dtype(dtypes.int64)
return self.with_row_splits_dtype(dtypes.int64).with_values(new_values)
return RaggedTensor(
new_values,
self._row_splits,
@ -1038,6 +1105,43 @@ class RaggedTensor(composite_tensor.CompositeTensor):
else:
return self.with_values(self.values.with_flat_values(new_values))

def with_row_splits_dtype(self, dtype):
"""Returns a copy of this RaggedTensor with the given `row_splits` dtype.

For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
nested `RaggedTensor` objects are cast to the given dtype.

Args:
dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`.

Returns:
A copy of this RaggedTensor, with the `row_splits` cast to the given
type.
"""
dtype = dtypes.as_dtype(dtype)
if dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("dtype must be int32 or int64")
if self._row_splits.dtype == dtype:
return self

row_splits = math_ops.cast(self._row_splits, dtype)

values = self._values
if isinstance(values, RaggedTensor):
values = values.with_row_splits_dtype(dtype)
cached_row_lengths = self._cached_row_lengths
if cached_row_lengths is not None:
cached_row_lengths = math_ops.cast(cached_row_lengths, dtype)
cached_value_rowids = self._cached_value_rowids
if cached_value_rowids is not None:
cached_value_rowids = math_ops.cast(cached_value_rowids, dtype)
cached_nrows = self._cached_nrows
if cached_nrows is not None:
cached_nrows = math_ops.cast(cached_nrows, dtype)

return RaggedTensor(values, row_splits, cached_row_lengths,
cached_value_rowids, cached_nrows, internal=True)

#=============================================================================
# Tensor Type Conversions
#=============================================================================
@ -1048,7 +1152,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
lengths=None,
padding=None,
ragged_rank=1,
name=None):
name=None,
row_splits_dtype=dtypes.int64):
"""Converts a `tf.Tensor` into a `RaggedTensor`.

The set of absent/default values may be specified using a vector of lengths
@ -1096,6 +1201,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
ragged_rank: Integer specifying the ragged rank for the returned
`RaggedTensor`. Must be greater than zero.
name: A name prefix for the returned tensors (optional).
row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
tensor. One of `tf.int32` or `tf.int64`.

Returns:
A `RaggedTensor` with the specified `ragged_rank`. The shape of the
@ -1103,6 +1210,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Raises:
ValueError: If both `lengths` and `padding` are specified.
"""
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if lengths is not None and padding is not None:
raise ValueError("Specify lengths or padding, but not both")
if not isinstance(ragged_rank, int):
@ -1114,7 +1222,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
tensor = ops.convert_to_tensor(tensor, name="tensor")
tensor.shape.with_rank_at_least(ragged_rank + 1)
input_shape = array_ops.shape(tensor, out_type=dtypes.int64)
input_shape = array_ops.shape(tensor, out_type=row_splits_dtype)
ncols = input_shape[1]

# Handle ragged_rank>1 via recursion:
@ -1125,12 +1233,14 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if ragged_rank > 1:
# Flatten `tensor` to eliminate all but the last ragged dimension.
new_shape = array_ops.concat([
constant_op.constant([-1], dtypes.int64), input_shape[ragged_rank:]
constant_op.constant([-1], row_splits_dtype),
input_shape[ragged_rank:]
],
axis=0)
flattened = array_ops.reshape(tensor, new_shape)
# Recursively convert the flattened tensor.
values = cls.from_tensor(flattened, lengths, padding)
values = cls.from_tensor(flattened, lengths, padding,
row_splits_dtype=row_splits_dtype)
# The total number of elements in each dimension. E.g., if
# input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
dim_size = math_ops.cumprod(input_shape)
@ -1167,12 +1277,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
has_default.set_shape(tensor_shape.TensorShape([None, None]))
has_default.set_shape(tensor.shape[:2])

# Use has_default it to find the length of each row: for each
# Use has_default to find the length of each row: for each
# non-default item in a row, calculate the length that the row needs to
# have to include that item; and then take the max of those values
# (across each row).
has_nondefault = math_ops.logical_not(has_default)
has_nondefault = math_ops.cast(has_nondefault, dtypes.int64)
has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype)
length_for_nondefault_value = (
has_nondefault * array_ops.expand_dims(
math_ops.range(1, ncols + 1), 0))
@ -1198,13 +1308,13 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# paddings), then use those to construct splits; and then use masking
# to get the corresponding values.
lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
dtypes.int64)
row_splits_dtype)
lengths.shape.assert_has_rank(1)
lengths = math_ops.minimum(lengths, ncols)
lengths = math_ops.maximum(lengths, 0)
limits = math_ops.cumsum(lengths)
splits = array_ops.concat(
[array_ops.zeros([1], dtypes.int64), limits], axis=0)
[array_ops.zeros([1], row_splits_dtype), limits], axis=0)
mask = array_ops.sequence_mask(lengths, maxlen=ncols)
values = array_ops.boolean_mask(tensor, mask)
return cls.from_row_splits(values, splits)
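A sketch of `from_tensor` with the new argument (eager mode; the padding value is illustrative):

```python
import tensorflow as tf

dense = tf.constant([[1, 2, 0], [3, 0, 0]])
rt = tf.RaggedTensor.from_tensor(dense, padding=0, row_splits_dtype=tf.int32)
print(rt)                   # <tf.RaggedTensor [[1, 2], [3]]>
print(rt.row_splits.dtype)  # tf.int32
```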
@ -1267,9 +1377,10 @@ class RaggedTensor(composite_tensor.CompositeTensor):

# Get the expected dense shape ([nrows, ncols] + value_shape).
rt_row_lengths = [self.row_splits[1:] - self.row_splits[:-1]]
nrows = array_ops.shape(self.row_splits, out_type=dtypes.int64)[0] - 1
nrows = array_ops.shape(self.row_splits,
out_type=self._row_splits.dtype)[0] - 1
ncols = math_ops.maximum(math_ops.reduce_max(rt_row_lengths), 0)
values_shape = array_ops.shape(values, out_type=dtypes.int64)
values_shape = array_ops.shape(values, out_type=self._row_splits.dtype)
value_shape = values_shape[1:]
nvals = values_shape[0]

@ -1305,7 +1416,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
return array_ops.gather(values_and_default, indices)

@classmethod
def from_sparse(cls, st_input, name=None):
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
"""Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.

Each row of the `output` `RaggedTensor` will contain the explicit values
@ -1327,6 +1438,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
Args:
st_input: The sparse tensor to convert. Must have rank 2.
name: A name prefix for the returned tensors (optional).
row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
tensor. One of `tf.int32` or `tf.int64`.

Returns:
A `RaggedTensor` with the same values as `st_input`.
@ -1336,6 +1449,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
ValueError: If the number of dimensions in `st_input` is not known
statically, or is not two.
"""
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if not sparse_tensor.is_sparse(st_input):
raise TypeError("Expected SparseTensor, got %s" % type(st_input).__name__)
with ops.name_scope(name, "RaggedFromSparse", [st_input]):
@ -1360,8 +1474,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Treat sparse row indices as segment ids to generate a splits tensor
# that we can pair with the sparse tensor values. (Ignore sparse column
# indices.)
segment_ids = st_input.indices[:, 0]
num_segments = st_input.dense_shape[0]
segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
return cls.from_value_rowids(st_input.values, segment_ids, num_segments)
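And a matching sketch for `from_sparse` (eager mode; values illustrative):

```python
import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
                     values=[1, 2, 3], dense_shape=[3, 2])
rt = tf.RaggedTensor.from_sparse(st, row_splits_dtype=tf.int32)
print(rt)                   # <tf.RaggedTensor [[1, 2], [], [3]]>
print(rt.row_splits.dtype)  # tf.int32
```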

def to_sparse(self, name=None):
@ -1518,6 +1632,50 @@ def is_ragged(value):
(RaggedTensor, ragged_tensor_value.RaggedTensorValue))


def match_row_splits_dtypes(*tensors, **kwargs):
"""Return a copy of `tensors` with row_splits all having the same dtype.

Args:
*tensors: A list of Tensors or RaggedTensors.
**kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
where `dtype` is the data type used by row-splits, and `tensors` is the
converted list of `Tensors` and `RaggedTensors`.
Returns:
The converted list of `Tensors` and `RaggedTensors`.
"""
return_dtype = kwargs.pop("return_dtype", False)
if kwargs:
raise ValueError("Unexpected keyword args %r" % kwargs)

has_int32 = False
has_int64 = False
for tensor in tensors:
if isinstance(tensor, RaggedTensor):
if tensor.row_splits.dtype == dtypes.int32:
has_int32 = True
else:
has_int64 = True

if has_int32 and has_int64:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
"use RaggedTensor.with_row_splits_dtype() to convert "
"them to compatible dtypes.")
dtype = dtypes.int64
tensors = tuple(t.with_row_splits_dtype(dtypes.int64)
if isinstance(t, RaggedTensor) else t for t in tensors)

elif has_int32:
dtype = dtypes.int32
else:
dtype = dtypes.int64

if return_dtype:
return (dtype, tensors)
else:
return tensors
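A sketch of the helper's behavior (internal API; the mixed int32/int64 case raises unless `auto_cast_partition_dtype()` is enabled):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor

a = tf.ragged.constant([[1], [2, 3]]).with_row_splits_dtype(tf.int32)
b = tf.ragged.constant([[4, 5], [6]])  # int64 splits
# With auto-cast enabled, both come back with int64 row_splits.
a2, b2 = ragged_tensor.match_row_splits_dtypes(a, b)
print(a2.row_splits.dtype, b2.row_splits.dtype)
```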


#===============================================================================
# Convert value -> tensor
#===============================================================================
@ -1606,18 +1764,23 @@ class RaggedTensorType(object):
`RaggedTensor`.
"""

def __init__(self, dtype, ragged_rank):
def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64):
"""Initializes a RaggedTensorType object.

Args:
dtype: data type of the `RaggedTensor`'s inner values.
ragged_rank: ragged_rank of the declared `RaggedTensor`.
row_splits_dtype: data type for the `RaggedTensor`'s row splits.
One of: `tf.int32` or `tf.int64`.
"""
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
self._dtype = dtype
self._ragged_rank = ragged_rank
self._row_splits_dtype = row_splits_dtype

dtype = property(lambda self: self._dtype)
ragged_rank = property(lambda self: self._ragged_rank)
row_splits_dtype = property(lambda self: self._row_splits_dtype)


#===============================================================================

@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
@ -82,7 +83,8 @@ class RaggedTensorDynamicShape(object):
`[[[1, 2], [3]], [[4, 5]]]` | 2 | `2, (2, 1), (2, 1, 2)` |
"""

def __init__(self, partitioned_dim_sizes, inner_dim_sizes):
def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
dim_size_dtype=None):
"""Creates a RaggedTensorDynamicShape.

Args:
@ -96,16 +98,19 @@ class RaggedTensorDynamicShape(object):
number of inner dimensions. `inner_dim_sizes[n]` is the size of all
slices across the `n`th inner dimension (which is the
`(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor.
dim_size_dtype: dtype for dimension sizes. If not specified, then it
is chosen based on the dtypes of `partitioned_dim_sizes` and
`inner_dim_sizes`.
"""
assert isinstance(partitioned_dim_sizes, (list, tuple))

with ops.name_scope(None, 'RaggedTensorDynamicShape',
(partitioned_dim_sizes, inner_dim_sizes)):
partitioned_dim_sizes = tuple(
ragged_util.convert_to_int_tensor(
size, dtype=dtypes.int64, name='partitioned_dimension_size')
for size in partitioned_dim_sizes)
inner_dim_sizes = ragged_util.convert_to_int_tensor(
inner_dim_sizes, dtype=dtypes.int64, name='inner_dim_sizes')
ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
for (i, size) in enumerate(partitioned_dim_sizes))
inner_dim_sizes = ops.convert_to_tensor(
inner_dim_sizes, name='inner_dim_sizes')

# Validate shapes.
if partitioned_dim_sizes:
@ -120,6 +125,22 @@ class RaggedTensorDynamicShape(object):
raise ValueError('innermost partitioned dimension must be ragged')
inner_dim_sizes.shape.assert_has_rank(1)

# Convert dimension size tensors to a single dtype.
if dim_size_dtype is None:
dim_size_dtypes = set([p.dtype for p in partitioned_dim_sizes
if p.shape.ndims == 1])
if not dim_size_dtypes:
dim_size_dtype = dtypes.int64
elif len(dim_size_dtypes) == 1:
dim_size_dtype = dim_size_dtypes.pop()
else:
if not ragged_config.auto_cast_partition_dtype():
raise ValueError('partitioned_dim_sizes must have matching dtypes')
dim_size_dtype = dtypes.int64
partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
for p in partitioned_dim_sizes)
inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

self._partitioned_dim_sizes = partitioned_dim_sizes
self._inner_dim_sizes = inner_dim_sizes

@ -137,7 +158,7 @@ class RaggedTensorDynamicShape(object):
ragged.

Args:
dim_sizes: List of int64 scalars or vectors.
dim_sizes: List of int32 or int64 scalars or vectors.

Returns:
A RaggedTensorDynamicShape.
@ -145,8 +166,8 @@ class RaggedTensorDynamicShape(object):
with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
[dim_sizes]):
dim_sizes = tuple(
ragged_util.convert_to_int_tensor(
size, dtype=dtypes.int64, name='dim_sizes') for size in dim_sizes)
ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
name='dim_sizes') for size in dim_sizes)
# Split the dimensions into partitioned & inner dimensions.
inner_split = 0
for dim, dim_size in enumerate(dim_sizes):
@ -158,7 +179,7 @@ class RaggedTensorDynamicShape(object):
dim_sizes[inner_split:])

@classmethod
def from_tensor(cls, rt_input):
def from_tensor(cls, rt_input, dim_size_dtype=None):
"""Constructs a ragged shape for a potentially ragged tensor."""
with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
@ -169,7 +190,8 @@ class RaggedTensorDynamicShape(object):
(rt_input.nrows(),) + rt_input.nested_row_lengths())
return RaggedTensorDynamicShape(
partitioned_dim_sizes,
array_ops.shape(rt_input.flat_values)[1:])
array_ops.shape(rt_input.flat_values)[1:],
dim_size_dtype=dim_size_dtype)

def dimension_size(self, axis):
"""Returns the size of slices across the specified dimension."""
@ -231,6 +253,11 @@ class RaggedTensorDynamicShape(object):
"""The number of inner dimensions, or `None` if not statically known."""
return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])

@property
def dim_size_dtype(self):
"""DType used by this shape for dimension sizes."""
return self._inner_dim_sizes.dtype

def broadcast_to_rank(self, rank):
"""Adds leading size-1 dimensions to broadcast `self` to the given rank.

@ -260,7 +287,8 @@ class RaggedTensorDynamicShape(object):
return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
else:
inner_dims = array_ops.concat(
[array_ops.ones([dims_to_add], dtypes.int64), self.inner_dim_sizes],
[array_ops.ones([dims_to_add], self.dim_size_dtype),
self.inner_dim_sizes],
axis=0)
return RaggedTensorDynamicShape([], inner_dims)

@ -290,7 +318,7 @@ class RaggedTensorDynamicShape(object):
A `RaggedTensorDynamicShape`.
"""
lengths = ragged_util.convert_to_int_tensor(
lengths, name='lengths', dtype=dtypes.int64)
lengths, name='lengths', dtype=self.dim_size_dtype)
# Check whether lengths is a scalar (for uniform dimensions) or
# vector (for ragged dimensions).
if lengths.shape.ndims is None:
@ -347,7 +375,7 @@ class RaggedTensorDynamicShape(object):
def num_slices_in_dimension(self, axis):
"""Returns the total number of slices across the indicated dimension."""
if axis < 0:
return constant_op.constant(1, dtype=dtypes.int64)
return constant_op.constant(1, dtype=self.dim_size_dtype)
elif self.is_ragged(axis):
return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
else:
@ -365,7 +393,7 @@ class RaggedTensorDynamicShape(object):
splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
else:
splits = math_ops.range(
array_ops.size(lengths, out_type=dtypes.int64) + 1)
array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
repeats = lengths

partitioned_sizes.append(lengths)
@ -404,6 +432,15 @@ class RaggedTensorDynamicShape(object):
inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

def with_dim_size_dtype(self, dtype):
if dtype not in (dtypes.int32, dtypes.int64):
raise ValueError('dtype must be int32 or int64')
if self.dim_size_dtype == dtype:
return self
return RaggedTensorDynamicShape(
[math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
math_ops.cast(self._inner_dim_sizes, dtype))
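A sketch of converting a shape's dimension-size dtype (internal API; the `ragged_tensor_shape` module path is an assumption based on where this class lives):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor_shape

shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
    tf.ragged.constant([[1, 2], [3]]))
shape32 = shape.with_dim_size_dtype(tf.int32)
print(shape32.dim_size_dtype)  # tf.int32
```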
|
||||
|
||||
|
||||
def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.
@ -479,6 +516,17 @@ def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):

def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
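A hedged illustration of the failure mode this check guards against, using only public APIs (the exact error text may differ from what is shown):

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf

a = tf.RaggedTensor.from_row_splits([1, 2, 3], row_splits=[0, 2, 3])  # int64 splits
b = tf.RaggedTensor.from_row_splits(
    [10, 20, 30], row_splits=tf.constant([0, 2, 3], dtype=tf.int32))  # int32 splits
try:
  _ = a + b  # broadcasting sees mismatched partition dtypes
except ValueError as e:
  print(e)
_ = a + b.with_row_splits_dtype(tf.int64)  # cast one operand first; this works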
@ -500,7 +548,8 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  if ragged_tensor.is_ragged(rt_input):
    nrows = rt_input.nrows()
  else:
    nrows = array_ops.shape(rt_input, out_type=dtypes.int64)[0]
    nrows = array_ops.shape(rt_input,
                            out_type=dst_shape.dim_size_dtype)[0]
  rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows])

  # Add ragged dimensions to match dst_shape.
@ -509,11 +558,13 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_conversion_ops.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff))
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_conversion_ops.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1)
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform. We can do
  # these all at once, since they're independent of one another.
@ -541,21 +592,24 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input


def _ragged_tile_axis(rt_input, axis, repeats):
def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tile a dimension of a RaggedTensor to match a ragged shape."""
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_conversion_ops.from_tensor(rt_input, ragged_rank=1)
    rt_input = ragged_conversion_ops.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)

  if axis > 1:
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats))
        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                          row_splits_dtype))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
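The `row_splits_dtype` argument threaded through above is also exposed on the public factory methods (the changed argspecs appear in the API goldens later in this diff). A small sketch, assuming a build with this change:

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf

dense = tf.constant([[1, 2, 0], [3, 0, 0]])
rt64 = tf.RaggedTensor.from_tensor(dense, padding=0)  # row_splits defaults to int64
rt32 = tf.RaggedTensor.from_tensor(dense, padding=0, row_splits_dtype=tf.int32)
print(rt64.row_splits.dtype)  # <dtype: 'int64'>
print(rt32.row_splits.dtype)  # <dtype: 'int32'>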
@ -32,8 +32,8 @@ from tensorflow.python.platform import googletest


@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorBoundingShapeOp(ragged_test_util.RaggedTensorTestCase,
                                  parameterized.TestCase):
class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
                            parameterized.TestCase):

  def assertShapeEq(self, x, y):
    assert isinstance(x, RaggedTensorDynamicShape)
@ -181,7 +181,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
    rt_value = ragged_tensor_value.RaggedTensorValue(values, splits)
    self.assertEqual(rt_value.row_splits.dtype, np.int64)
    self.assertEqual(rt_value.shape, (5, None))
    self.assertEqual(len(rt_value.nested_row_splits), 1)
    self.assertLen(rt_value.nested_row_splits, 1)
    self.assertAllEqual(splits, rt_value.row_splits)
    self.assertAllEqual(values, rt_value.values)
    self.assertAllEqual(splits, rt_value.nested_row_splits[0])
@ -193,7 +193,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
                                                row_splits=splits2)
    self.assertEqual(rt_value.row_splits.dtype, np.int64)
    self.assertEqual(rt_value.shape, (2, None, None))
    self.assertEqual(len(rt_value.nested_row_splits), 2)
    self.assertLen(rt_value.nested_row_splits, 2)
    self.assertAllEqual(splits2, rt_value.row_splits)
    self.assertAllEqual(splits, rt_value.values.row_splits)
    self.assertAllEqual(splits2, rt_value.nested_row_splits[0])
@ -1078,15 +1078,17 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
    values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
    row_splits = [0, 2, 5, 6, 6, 7]
    rt = RaggedTensor.from_row_splits(values, row_splits)
    splits_type = 'int64'
    if context.executing_eagerly():
      expected_str = '<tf.RaggedTensor {}>'.format([[b'a', b'b'],
                                                    [b'c', b'd', b'e'], [b'f'],
                                                    [], [b'g']])
      expected_repr = (
          'tf.RaggedTensor(values=tf.Tensor([{}], shape=(7,), dtype=string), '
          'row_splits=tf.Tensor([{}], shape=(6,), dtype=int64))'.format(
          'row_splits=tf.Tensor([{}], shape=(6,), dtype={}))'.format(
              ' '.join(repr(x) for x in values),
              ' '.join(repr(x) for x in row_splits)))
              ' '.join(repr(x) for x in row_splits),
              splits_type))
      self.assertEqual(str(rt), expected_str)
      self.assertEqual(repr(rt), expected_repr)
    else:
@ -1094,7 +1096,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
          'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
          'shape=(7,), dtype=string), row_splits='
          'Tensor("RaggedFromRowSplits/row_splits:0", '
          'shape=(6,), dtype=int64))')
          'shape=(6,), dtype={}))').format(splits_type)
      self.assertEqual(repr(rt), expected_repr)
      self.assertEqual(str(rt), expected_repr)

@ -1145,7 +1147,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
    rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]])
    with self.test_session() as session:
      result = session.run({'rt1': rt1, 'rt2': rt2})
      self.assertCountEqual(sorted(result.keys()), ['rt1', 'rt2'])
      self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
      self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
      self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])

@ -1166,15 +1168,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
    rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]])

    with self.test_session() as session:
      result = session.run({
          'rt1': rt1,
          'rt2': rt2
      },
                           feed_dict={
                               rt1: rt1_feed_val,
                               rt2: rt2_feed_val
                           })
      self.assertCountEqual(sorted(result.keys()), ['rt1', 'rt2'])
      fetches = {'rt1': rt1, 'rt2': rt2}
      feeds = {rt1: rt1_feed_val, rt2: rt2_feed_val}
      result = session.run(fetches, feed_dict=feeds)
      self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
      self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
      self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
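These tests pin down the user-visible contract: the default splits dtype stays int64, while derived quantities such as `nrows()` now follow the splits dtype rather than a hard-coded int64 (the `nrows` argspec change appears in the goldens below). A sketch:

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf

rt = tf.RaggedTensor.from_row_splits([b'a', b'b', b'c'], row_splits=[0, 2, 3])
print(rt.row_splits.dtype)  # int64: the default is unchanged
print(rt.nrows().dtype)     # int64: follows rt.row_splits.dtype

rt32 = rt.with_row_splits_dtype(tf.int32)
print(rt32.nrows().dtype)   # int32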
@ -38,13 +38,17 @@ class RaggedTensorValue(object):

    Args:
      values: A numpy array of any type and shape; or a RaggedTensorValue.
      row_splits: A 1-D int64 numpy array.
      row_splits: A 1-D int32 or int64 numpy array.
    """
    if not (isinstance(row_splits, (np.ndarray, np.generic)) and
            row_splits.dtype == np.int64 and row_splits.ndim == 1):
      raise TypeError("row_splits must be a 1D int64 numpy array")
            row_splits.dtype in (np.int64, np.int32) and row_splits.ndim == 1):
      raise TypeError("row_splits must be a 1D int32 or int64 numpy array")
    if not isinstance(values, (np.ndarray, np.generic, RaggedTensorValue)):
      raise TypeError("values must be a numpy array or a RaggedTensorValue")
    if (isinstance(values, RaggedTensorValue) and
        row_splits.dtype != values.row_splits.dtype):
      raise ValueError("row_splits and values.row_splits must have "
                       "the same dtype")
    self._values = values
    self._row_splits = row_splits
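For illustration, a minimal sketch of the relaxed check against the internal `ragged_tensor_value` module (not public API; assumes a build with this change):

# --- illustrative sketch, not part of this commit ---
import numpy as np
from tensorflow.python.ops.ragged import ragged_tensor_value

# int32 splits are now accepted alongside int64.
rtv = ragged_tensor_value.RaggedTensorValue(
    values=np.array([1, 2, 3, 4]),
    row_splits=np.array([0, 2, 4], dtype=np.int32))
print(rtv.row_splits.dtype)  # int32

# Nesting a RaggedTensorValue whose splits dtype disagrees now raises ValueError.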
@ -191,7 +191,6 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):

    g1, g2 = gradients_impl.gradients(st.values,
                                      [rt1.flat_values, rt2.flat_values])
    print(g1, g2)
    self.assertRaggedEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
    self.assertRaggedEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
@ -268,7 +268,7 @@ def repeat_ranges(params, splits, repeats):
  else:
    # Optimization: we can just call repeat once, and then slice the result.
    repeated_splits = repeat(splits, repeats, axis=0)
    n_splits = array_ops.shape(repeated_splits, out_type=dtypes.int64)[0]
    n_splits = array_ops.shape(repeated_splits, out_type=repeats.dtype)[0]
    repeated_starts = repeated_splits[:n_splits - repeats]
    repeated_limits = repeated_splits[repeats:]
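The repeat-then-slice optimization in that hunk is easiest to see on concrete numbers; a plain-numpy illustration (hypothetical values, chosen by the editor, not from this change):

# --- illustrative sketch, not part of this commit ---
import numpy as np

splits = np.array([0, 3, 5])   # two ranges: [0, 3) and [3, 5)
repeats = 2                    # each range is emitted twice

repeated_splits = np.repeat(splits, repeats)   # [0 0 3 3 5 5]
n_splits = repeated_splits.shape[0]            # 6
starts = repeated_splits[:n_splits - repeats]  # [0 0 3 3]
limits = repeated_splits[repeats:]             # [3 3 5 5]
# Pairing starts with limits yields each range `repeats` times:
# (0, 3), (0, 3), (3, 5), (3, 5)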
@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -108,6 +107,7 @@ def where(condition, x=None, y=None, name=None):
  else:
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
    condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
    return _elementwise_where(condition, x, y)


@ -145,6 +145,7 @@ def _coordinate_where(condition):
  selected_coords = _coordinate_where(condition.values)

  # Convert the first index in each coordinate to a row index and column index.
  condition = condition.with_row_splits_dtype(selected_coords.dtype)
  first_index = selected_coords[:, 0]
  selected_rows = array_ops.gather(condition.value_rowids(), first_index)
  selected_row_starts = array_ops.gather(condition.row_splits, selected_rows)
@ -158,9 +159,8 @@ def _coordinate_where(condition):
      axis=1)


def _nrows(rt_input, out_type=dtypes.int64, name=None):
def _nrows(rt_input):
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    return rt_input.nrows(out_type=out_type, name=name)
    return rt_input.nrows()
  else:
    with ops.name_scope(name, 'RaggedNRows', [rt_input]):
      return array_ops.shape(rt_input, out_type=out_type)[0]
    return array_ops.shape(rt_input)[0]
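Elementwise `where` now reconciles its operands' partition dtypes up front via `match_row_splits_dtypes`. A sketch against the internal module (editor's illustration; whether public `tf.where` dispatches here depends on the build):

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_where_op

condition = tf.ragged.constant([[True, False], [True]])
x = tf.ragged.constant([[1, 2], [3]])
y = tf.ragged.constant([[-1, -2], [-3]])
result = ragged_where_op.where(condition, x, y)
print(result)  # [[1, -2], [3]]
# If one operand carried int32 splits, match_row_splits_dtypes would either
# unify the dtypes (with auto-casting enabled) or raise a ValueError.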
@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids")
def row_splits_to_segment_ids(splits, name=None):
def row_splits_to_segment_ids(splits, name=None, out_type=None):
  """Generates the segmentation corresponding to a RaggedTensor `row_splits`.

  Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if
@ -43,22 +43,32 @@ def row_splits_to_segment_ids(splits, name=None):
  ```

  Args:
    splits: A sorted 1-D int64 Tensor. `splits[0]` must be zero.
    splits: A sorted 1-D integer Tensor. `splits[0]` must be zero.
    name: A name prefix for the returned tensor (optional).
    out_type: The dtype for the return value. Defaults to `splits.dtype`,
      or `tf.int64` if `splits` does not have a dtype.

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[splits[-1]]`
    A sorted 1-D integer Tensor, with `shape=[splits[-1]]`

  Raises:
    ValueError: If `splits` is invalid.
  """
  with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
    splits = ops.convert_to_tensor(splits, dtype=dtypes.int64, name="splits")
    splits = ops.convert_to_tensor(
        splits, name="splits",
        preferred_dtype=dtypes.int64)
    if splits.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("splits must have dtype int32 or int64")
    splits.shape.assert_has_rank(1)
    if tensor_shape.dimension_value(splits.shape[0]) == 0:
      raise ValueError("Invalid row_splits: []")
    if out_type is None:
      out_type = splits.dtype
    else:
      out_type = dtypes.as_dtype(out_type)
    row_lengths = splits[1:] - splits[:-1]
    nrows = array_ops.shape(splits, out_type=dtypes.int64)[-1] - 1
    nrows = array_ops.shape(splits, out_type=out_type)[-1] - 1
    indices = math_ops.range(nrows)
    return ragged_util.repeat(indices, repeats=row_lengths, axis=0)

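A short usage sketch of the new `out_type` plumbing on the public endpoint (output values are what the docstring implies):

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf

splits = tf.constant([0, 3, 3, 5], dtype=tf.int32)
ids = tf.ragged.row_splits_to_segment_ids(splits)
print(ids)          # [0, 0, 0, 2, 2]; dtype follows splits, so int32
ids64 = tf.ragged.row_splits_to_segment_ids(splits, out_type=tf.int64)
print(ids64.dtype)  # int64, requested explicitly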
@ -66,7 +76,8 @@ def row_splits_to_segment_ids(splits, name=None):
# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits")
def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None):
def segment_ids_to_row_splits(segment_ids, num_segments=None,
                              out_type=None, name=None):
  """Generates the RaggedTensor `row_splits` corresponding to a segmentation.

  Returns an integer vector `splits`, where `splits[0] = 0` and
@ -81,24 +92,39 @@ def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None):
    segment_ids: A 1-D integer Tensor.
    num_segments: A scalar integer indicating the number of segments. Defaults
      to `max(segment_ids) + 1` (or zero if `segment_ids` is empty).
    out_type: The dtype for the return value. Defaults to `segment_ids.dtype`,
      or `tf.int64` if `segment_ids` does not have a dtype.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D int64 Tensor, with `shape=[num_segments + 1]`.
    A sorted 1-D integer Tensor, with `shape=[num_segments + 1]`.
  """
  if out_type is None:
    if isinstance(segment_ids, ops.Tensor):
      out_type = segment_ids.dtype
    elif isinstance(num_segments, ops.Tensor):
      out_type = num_segments.dtype
    else:
      out_type = dtypes.int64
  else:
    out_type = dtypes.as_dtype(out_type)
  with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name:
    segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids")
    # Note: we cast int64 tensors to int32, since bincount currently only
    # supports int32 inputs.
    segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids",
                                                    dtype=dtypes.int32)
    segment_ids.shape.assert_has_rank(1)
    if num_segments is not None:
      num_segments = ragged_util.convert_to_int_tensor(num_segments,
                                                       "num_segments")
                                                       "num_segments",
                                                       dtype=dtypes.int32)
      num_segments.shape.assert_has_rank(0)

    row_lengths = math_ops.bincount(
        segment_ids,
        minlength=num_segments,
        maxlength=num_segments,
        dtype=dtypes.int64)
        dtype=out_type)
    splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)

    # Update shape information, if possible.
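And the inverse direction, which round-trips with the sketch above (same assumptions):

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf

segment_ids = tf.constant([0, 0, 0, 2, 2], dtype=tf.int32)
splits = tf.ragged.segment_ids_to_row_splits(segment_ids)
print(splits)          # [0, 3, 3, 5]; dtype follows segment_ids, so int32
splits64 = tf.ragged.segment_ids_to_row_splits(segment_ids, out_type=tf.int64)
print(splits64.dtype)  # int64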
@ -37,7 +37,7 @@ tf_class {
  }
  member_method {
    name: "bounding_shape"
    argspec: "args=[\'self\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
    argspec: "args=[\'self\', \'axis\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "consumers"
@ -73,11 +73,11 @@ tf_class {
  }
  member_method {
    name: "from_sparse"
    argspec: "args=[\'cls\', \'st_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'cls\', \'st_input\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "from_tensor"
    argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\'], "
    argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "from_value_rowids"
@ -89,7 +89,7 @@ tf_class {
  }
  member_method {
    name: "nrows"
    argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
    argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "row_lengths"
@ -123,6 +123,10 @@ tf_class {
    name: "with_flat_values"
    argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "with_row_splits_dtype"
    argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "with_values"
    argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"

@ -6,11 +6,11 @@ tf_module {
  }
  member_method {
    name: "constant"
    argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "constant_value"
    argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
    argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'int64\'], "
  }
  member_method {
    name: "map_flat_values"
@ -22,14 +22,14 @@ tf_module {
  }
  member_method {
    name: "range"
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], "
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "row_splits_to_segment_ids"
    argspec: "args=[\'splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'splits\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "segment_ids_to_row_splits"
    argspec: "args=[\'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
    argspec: "args=[\'segment_ids\', \'num_segments\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
}

@ -2662,7 +2662,7 @@ tf_module {
  }
  member_method {
    name: "RaggedRange"
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "RaggedTensorFromVariant"
@ -4302,11 +4302,11 @@ tf_module {
  }
  member_method {
    name: "UnicodeDecode"
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "UnicodeDecodeWithOffsets"
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "UnicodeEncode"

@ -37,7 +37,7 @@ tf_class {
  }
  member_method {
    name: "bounding_shape"
    argspec: "args=[\'self\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
    argspec: "args=[\'self\', \'axis\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "consumers"
@ -73,11 +73,11 @@ tf_class {
  }
  member_method {
    name: "from_sparse"
    argspec: "args=[\'cls\', \'st_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'cls\', \'st_input\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "from_tensor"
    argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\'], "
    argspec: "args=[\'cls\', \'tensor\', \'lengths\', \'padding\', \'ragged_rank\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "from_value_rowids"
@ -89,7 +89,7 @@ tf_class {
  }
  member_method {
    name: "nrows"
    argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
    argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "row_lengths"
@ -123,6 +123,10 @@ tf_class {
    name: "with_flat_values"
    argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "with_row_splits_dtype"
    argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "with_values"
    argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"

@ -2,7 +2,7 @@ path: "tensorflow.ragged"
tf_module {
  member_method {
    name: "constant"
    argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'pylist\', \'dtype\', \'ragged_rank\', \'inner_shape\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "map_flat_values"
@ -10,14 +10,14 @@ tf_module {
  }
  member_method {
    name: "range"
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], "
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'dtype\', \'name\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
  }
  member_method {
    name: "row_splits_to_segment_ids"
    argspec: "args=[\'splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'splits\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "segment_ids_to_row_splits"
    argspec: "args=[\'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
    argspec: "args=[\'segment_ids\', \'num_segments\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
}

@ -2662,7 +2662,7 @@ tf_module {
  }
  member_method {
    name: "RaggedRange"
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'starts\', \'limits\', \'deltas\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "RaggedTensorFromVariant"
@ -4302,11 +4302,11 @@ tf_module {
  }
  member_method {
    name: "UnicodeDecode"
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "UnicodeDecodeWithOffsets"
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
    argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "UnicodeEncode"