Use the safe sparse tensor API that returns errors rather than crashing
in all TensorFlow core kernels.

PiperOrigin-RevId: 204782675
commit 814f9ccd9b
parent c654856e38
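The change applies one mechanical pattern across every touched kernel: direct
construction of `sparse::SparseTensor` (whose validation could crash the
process on malformed indices, values, or shape) is replaced by the
`sparse::SparseTensor::Create` factory, which reports failure through a
`Status` that the kernel propagates with `OP_REQUIRES_OK` or
`TF_RETURN_IF_ERROR`. A minimal sketch of the pattern; `BuildFromInputs` and
its arguments are hypothetical stand-ins for the real kernel inputs:

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

// Hypothetical helper; in the real kernels, indices/values/shape come from
// OpKernelContext inputs.
void BuildFromInputs(OpKernelContext* ctx, const Tensor& indices,
                     const Tensor& values, const TensorShape& shape,
                     const std::vector<int64>& order) {
  // Old pattern (crashes on invalid input):
  //   sparse::SparseTensor st(indices, values, shape, order);
  // New pattern (invalid input becomes an error status on ctx):
  sparse::SparseTensor st;
  OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(indices, values, shape,
                                                   order, &st));
}

}  // namespace tensorflow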
Changed files:
  tensorflow/contrib/boosted_trees/lib/utils/
  tensorflow/core/kernels/
    ctc_loss_op.cc
    data/
    deserialize_sparse_string_op.cc
    edit_distance_op.cc
    reshape_util.cc
    sdca_internal.cc
    sdca_internal.h
    serialize_sparse_op.cc
    set_kernels.cc
    sparse_concat_op.cc
    sparse_reduce_op.cc
    sparse_reorder_op.cc
    sparse_slice_grad_op.cc
    sparse_slice_op.cc
    sparse_softmax_op.cc
    sparse_split_op.cc
    sparse_tensors_map_ops.cc
    sparse_to_dense_op.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
 #include "tensorflow/contrib/boosted_trees/lib/utils/macros.h"
 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 namespace boosted_trees {
@@ -96,9 +97,11 @@ Status BatchFeatures::Initialize(
         "Sparse float feature shape incompatible with batch size."));
     auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)});
     auto order_dims = sparse::SparseTensor::VarDimArray({0, 1});
-    sparse_float_feature_columns_.emplace_back(sparse_float_feature_indices,
-                                               sparse_float_feature_values,
-                                               tensor_shape, order_dims);
+    sparse::SparseTensor sparse_tensor;
+    TF_RETURN_IF_ERROR(sparse::SparseTensor::Create(
+        sparse_float_feature_indices, sparse_float_feature_values, tensor_shape,
+        order_dims, &sparse_tensor));
+    sparse_float_feature_columns_.push_back(std::move(sparse_tensor));
   }
 
   // Read sparse int features.
@@ -136,9 +139,11 @@ Status BatchFeatures::Initialize(
         "Sparse int feature shape incompatible with batch size."));
     auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)});
     auto order_dims = sparse::SparseTensor::VarDimArray({0, 1});
-    sparse_int_feature_columns_.emplace_back(sparse_int_feature_indices,
-                                             sparse_int_feature_values,
-                                             tensor_shape, order_dims);
+    sparse::SparseTensor sparse_tensor;
+    TF_RETURN_IF_ERROR(sparse::SparseTensor::Create(
+        sparse_int_feature_indices, sparse_int_feature_values, tensor_shape,
+        order_dims, &sparse_tensor));
+    sparse_int_feature_columns_.push_back(std::move(sparse_tensor));
   }
   return Status::OK();
 }
@@ -43,27 +43,35 @@ TEST_F(ExamplesIterableTest, Iterate) {
       test::AsTensor<int64>({0, 0, 2, 0, 3, 0, 4, 0}, {4, 2});
   auto sparse_float_values1 = test::AsTensor<float>({-3.0f, 0.0f, 5.0f, 0.0f});
   auto sparse_float_shape1 = TensorShape({8, 1});
-  sparse::SparseTensor sparse_float_tensor1(
-      sparse_float_indices1, sparse_float_values1, sparse_float_shape1);
+  sparse::SparseTensor sparse_float_tensor1;
+  TF_ASSERT_OK(
+      sparse::SparseTensor::Create(sparse_float_indices1, sparse_float_values1,
+                                   sparse_float_shape1, &sparse_float_tensor1));
   auto sparse_float_indices2 = test::AsTensor<int64>(
       {0, 1, 1, 0, 2, 1, 3, 0, 4, 1, 5, 0, 5, 1, 7, 0}, {8, 2});
   auto sparse_float_values2 =
       test::AsTensor<float>({1.f, 4.0f, 3.f, 7.0f, 4.3f, 9.0f, 0.8f, -4.0f});
   auto sparse_float_shape2 = TensorShape({8, 2});
-  sparse::SparseTensor sparse_float_tensor2(
-      sparse_float_indices2, sparse_float_values2, sparse_float_shape2);
+  sparse::SparseTensor sparse_float_tensor2;
+  TF_ASSERT_OK(
+      sparse::SparseTensor::Create(sparse_float_indices2, sparse_float_values2,
+                                   sparse_float_shape2, &sparse_float_tensor2));
   auto sparse_int_indices1 =
       test::AsTensor<int64>({0, 0, 0, 1, 1, 0, 3, 0, 3, 1, 7, 0}, {6, 2});
   auto sparse_int_values1 = test::AsTensor<int64>({1, 8, 0, 2, 0, 5});
   auto sparse_int_shape1 = TensorShape({8, 2});
-  sparse::SparseTensor sparse_int_tensor1(
-      sparse_int_indices1, sparse_int_values1, sparse_int_shape1);
+  sparse::SparseTensor sparse_int_tensor1;
+  TF_ASSERT_OK(
+      sparse::SparseTensor::Create(sparse_int_indices1, sparse_int_values1,
+                                   sparse_int_shape1, &sparse_int_tensor1));
   auto sparse_int_indices2 =
       test::AsTensor<int64>({1, 0, 2, 0, 3, 0, 4, 0}, {4, 2});
   auto sparse_int_values2 = test::AsTensor<int64>({7, 13, 4, 0});
   auto sparse_int_shape2 = TensorShape({8, 1});
-  sparse::SparseTensor sparse_int_tensor2(
-      sparse_int_indices2, sparse_int_values2, sparse_int_shape2);
+  sparse::SparseTensor sparse_int_tensor2;
+  TF_ASSERT_OK(
+      sparse::SparseTensor::Create(sparse_int_indices2, sparse_int_values2,
+                                   sparse_int_shape2, &sparse_int_tensor2));
 
   auto validate_example_features = [](int64 example_idx,
                                       const Example& example) {
@@ -100,8 +100,10 @@ class CTCLossOp : public OpKernel {
 
     TensorShape labels_shape({batch_size, max_label_len});
     std::vector<int64> order{0, 1};
-    sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
-                                   labels_shape, order);
+    sparse::SparseTensor labels_sp;
+    OP_REQUIRES_OK(
+        ctx, sparse::SparseTensor::Create(*labels_indices, *labels_values,
+                                          labels_shape, order, &labels_sp));
 
     Status labels_sp_valid = labels_sp.IndicesValid();
     OP_REQUIRES(ctx, labels_sp_valid.ok(),
@@ -252,10 +252,12 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
       previous_batch_index = next_batch_index;
     }
     gtl::InlinedVector<int64, 8> std_order(dense_shape->NumElements(), 0);
-    sparse::SparseTensor sparse_tensor(
-        *indices, *values, TensorShape(dense_shape->vec<int64>()), std_order);
-
-    *output = new Dataset<T>(ctx, sparse_tensor);
+    sparse::SparseTensor tensor;
+    OP_REQUIRES_OK(
+        ctx, sparse::SparseTensor::Create(
+                 *indices, *values, TensorShape(dense_shape->vec<int64>()),
+                 std_order, &tensor));
+    *output = new Dataset<T>(ctx, std::move(tensor));
   }
 
  private:
@@ -165,7 +165,10 @@ class DeserializeSparseOp : public OpKernel {
     std::vector<SparseTensor> tensors;
     tensors.reserve(num_sparse_tensors);
     for (int i = 0; i < num_sparse_tensors; ++i) {
-      tensors.emplace_back(indices[i], values[i], shape, std_order);
+      SparseTensor tensor;
+      OP_REQUIRES_OK(context, SparseTensor::Create(indices[i], values[i], shape,
+                                                   std_order, &tensor));
+      tensors.push_back(std::move(tensor));
     }
 
     gtl::optional<SparseTensor> maybe_output;
@@ -133,10 +133,15 @@ class EditDistanceOp : public OpKernel {
     std::vector<int64> sorted_order(truth_st_shape.dims());
     std::iota(sorted_order.begin(), sorted_order.end(), 0);
 
-    sparse::SparseTensor hypothesis(*hypothesis_indices, *hypothesis_values,
-                                    hypothesis_st_shape, sorted_order);
-    sparse::SparseTensor truth(*truth_indices, *truth_values, truth_st_shape,
-                               sorted_order);
+    sparse::SparseTensor hypothesis;
+    OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(
+                            *hypothesis_indices, *hypothesis_values,
+                            hypothesis_st_shape, sorted_order, &hypothesis));
+
+    sparse::SparseTensor truth;
+    OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(
+                            *truth_indices, *truth_values, truth_st_shape,
+                            sorted_order, &truth));
 
     // Group dims 0, 1, ..., RANK - 1. The very last dim is assumed
     // to store the variable length sequences.
@@ -28,7 +28,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/util/sparse/sparse_tensor.h"
 
 namespace tensorflow {
 
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sdca_internal.h"
 
+#include <limits>
 #include <numeric>
 #include <random>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -43,8 +43,6 @@ limitations under the License.
 #include "tensorflow/core/lib/random/distribution_sampler.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/util/guarded_philox_random.h"
-#include "tensorflow/core/util/sparse/group_iterator.h"
-#include "tensorflow/core/util/sparse/sparse_tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
@@ -190,8 +190,10 @@ class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
     TensorShape tensor_input_shape(input_shape->vec<int64>());
     gtl::InlinedVector<int64, 8> std_order(rank);
     std::iota(std_order.begin(), std_order.end(), 0);
-    SparseTensor input_st(*input_indices, *input_values, tensor_input_shape,
-                          std_order);
+    SparseTensor input_st;
+    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
+                                                 tensor_input_shape, std_order,
+                                                 &input_st));
 
     auto input_shape_t = input_shape->vec<int64>();
     const int64 N = input_shape_t(0);
@@ -63,9 +63,9 @@ Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
 
 // Build `SparseTensor` from indices, values, and shape in inputs
 // [base_index, base_index + 3), and validate its rank and indices.
-sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx,
-                                             const int32 base_index,
-                                             bool validate_indices) {
+Status SparseTensorFromContext(OpKernelContext* ctx, const int32 base_index,
+                               bool validate_indices,
+                               sparse::SparseTensor* tensor) {
   // Assume row-major order.
   const TensorShape shape =
       TensorShape(ctx->input(base_index + 2).vec<int64>());
@@ -73,13 +73,8 @@ sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx,
   std::vector<int64> order(shape.dims());
   std::iota(order.begin(), order.end(), 0);
 
-  const sparse::SparseTensor st(ctx->input(base_index),
-                                ctx->input(base_index + 1), shape, order);
-  if (validate_indices) {
-    Status s = st.IndicesValid();
-    if (!s.ok()) ctx->SetStatus(s);
-  }
-  return st;
+  return sparse::SparseTensor::Create(
+      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);
 }
 
 // TODO(ptucker): CheckGroup is just a sanity check on the result of
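Aside: the `set_kernels.cc` helper changes shape, not just error handling. The
old `SparseTensorFromContext` ran `IndicesValid()` itself when
`validate_indices` was set, stashing any error via `ctx->SetStatus`; the new
version simply forwards the `Status` from `SparseTensor::Create` (the flag is
unused in the body shown above), and the callers below validate indices
explicitly. A hedged sketch of the resulting caller pattern, using the names
from the hunks that follow:

  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
  OP_REQUIRES_OK(ctx, set_st.IndicesValid());  // validation is now explicit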
@@ -253,11 +248,13 @@ class SetSizeOp : public OpKernel {
 
 template <typename T>
 void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
-  const sparse::SparseTensor set_st =
-      SparseTensorFromContext(ctx, 0, validate_indices_);
+  sparse::SparseTensor set_st;
+  OP_REQUIRES_OK(ctx,
+                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
+  OP_REQUIRES_OK(ctx, set_st.IndicesValid());
 
-  // Output shape is same as input except for last dimension, which reduces to
-  // the set size of values along that dimension.
+  // Output shape is same as input except for last dimension, which reduces
+  // to the set size of values along that dimension.
   ShapeArray output_shape;
   OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
   const auto output_strides = Strides(output_shape);
@@ -484,8 +481,10 @@ void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
 template <typename T>
 void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
   const Tensor& set1_t = ctx->input(0);
-  const sparse::SparseTensor set2_st =
-      SparseTensorFromContext(ctx, 1, validate_indices_);
+  sparse::SparseTensor set2_st;
+  OP_REQUIRES_OK(ctx,
+                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
+  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());
   // The following should stay in sync with `_dense_to_sparse_shape` shape
   // assertions in python/ops/set_ops.py, and `SetShapeFn` for
   // `DenseToSparseSetOperation` in ops/set_ops.cc.
@@ -597,10 +596,15 @@ const std::vector<int64> GROUP_ITER_END;
 // with the same first n-1 dimensions in set1 and set2.
 template <typename T>
 void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
-  const sparse::SparseTensor set1_st =
-      SparseTensorFromContext(ctx, 0, validate_indices_);
-  const sparse::SparseTensor set2_st =
-      SparseTensorFromContext(ctx, 3, validate_indices_);
+  sparse::SparseTensor set1_st;
+  OP_REQUIRES_OK(ctx,
+                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));
+  OP_REQUIRES_OK(ctx, set1_st.IndicesValid());
+
+  sparse::SparseTensor set2_st;
+  OP_REQUIRES_OK(ctx,
+                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));
+
   // The following should stay in sync with `_sparse_to_sparse_shape` shape
   // assertions in python/ops/set_ops.py, and `SetShapeFn` for
   // `SparseToSparseSetOperation` in ops/set_ops.cc.
@@ -124,9 +124,12 @@ class SparseConcatOp : public OpKernel {
     std::vector<sparse::SparseTensor> sp_inputs;
     for (int i = 0; i < N; ++i) {
       const TensorShape current_shape(shapes[i].vec<int64>());
-      sp_inputs.emplace_back(tensor::DeepCopy(inds[i]),
-                             tensor::DeepCopy(vals[i]), current_shape,
-                             std_order);
+      sparse::SparseTensor tensor;
+      OP_REQUIRES_OK(context,
+                     sparse::SparseTensor::Create(
+                         tensor::DeepCopy(inds[i]), tensor::DeepCopy(vals[i]),
+                         current_shape, std_order, &tensor));
+      sp_inputs.push_back(std::move(tensor));
       sp_inputs[i].Reorder<T>(concat_order);
     }
 
@@ -172,8 +172,10 @@ class SparseReduceOp : public OpKernel {
     // making deep copies here. Remove this if/when we change Reorder()'s
     // semantics.
     const auto shape_vec = shape_t->vec<int64>();
-    SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
-                    TensorShape(shape_vec));
+    SparseTensor sp;
+    OP_REQUIRES_OK(ctx, SparseTensor::Create(
+        tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+        TensorShape(shape_vec), &sp));
     ReduceDetails reduction = SparseTensorReduceHelper(
         sp, reduction_axes_t->flat<int32>(), keep_dims_);
 
@@ -260,8 +262,10 @@ class SparseReduceSparseOp : public OpKernel {
 
     OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
 
-    SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
-                    TensorShape(shape_t->vec<int64>()));
+    SparseTensor sp;
+    OP_REQUIRES_OK(ctx, SparseTensor::Create(tensor::DeepCopy(*indices_t),
+                                             tensor::DeepCopy(*values_t),
+                                             TensorShape(shape_t->vec<int64>()), &sp));
     ReduceDetails reduction = SparseTensorReduceHelper(
         sp, reduction_axes_t->flat<int32>(), keep_dims_);
 
@@ -60,16 +60,21 @@ class SparseReorderOp : public OpKernel {
     std::iota(std_order.begin(), std_order.end(), 0);
 
     // Check if the sparse tensor is already ordered correctly
-    sparse::SparseTensor input_sp(input_ind, input_val, input_shape, std_order);
+    sparse::SparseTensor input_sp;
+    OP_REQUIRES_OK(
+        context, sparse::SparseTensor::Create(input_ind, input_val, input_shape,
+                                              std_order, &input_sp));
 
     if (input_sp.IndicesValid().ok()) {
       context->set_output(0, input_sp.indices());
       context->set_output(1, input_sp.values());
     } else {
       // Deep-copy the input Tensors, then reorder in-place
-      sparse::SparseTensor reordered_sp(tensor::DeepCopy(input_ind),
-                                        tensor::DeepCopy(input_val),
-                                        input_shape);
+      sparse::SparseTensor reordered_sp;
+      OP_REQUIRES_OK(context,
+                     sparse::SparseTensor::Create(tensor::DeepCopy(input_ind),
+                                                  tensor::DeepCopy(input_val),
+                                                  input_shape, &reordered_sp));
       reordered_sp.Reorder<T>(std_order);
       context->set_output(0, reordered_sp.indices());
       context->set_output(1, reordered_sp.values());
@@ -18,7 +18,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/util/sparse/sparse_tensor.h"
 
 namespace tensorflow {
 
@@ -66,8 +66,11 @@ class SparseSliceOp : public OpKernel {
                     "Expected size to be a vector of length ", input_dims,
                     " but got length ", input_size.NumElements()));
 
-    sparse::SparseTensor sparse_tensor(input_indices, input_values,
-                                       TensorShape(input_shape.vec<int64>()));
+    sparse::SparseTensor sparse_tensor;
+    OP_REQUIRES_OK(context,
+                   sparse::SparseTensor::Create(
+                       input_indices, input_values,
+                       TensorShape(input_shape.vec<int64>()), &sparse_tensor));
 
     const gtl::ArraySlice<int64> start(input_start.flat<int64>().data(),
                                        input_dims);
@@ -69,8 +69,11 @@ class SparseSoftmaxOp : public OpKernel {
 
     const int nnz = static_cast<int>(indices_t->dim_size(0));
    const int rank = static_cast<int>(indices_t->dim_size(1));
-    SparseTensor st(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
-                    TensorShape(shape_t->flat<int64>()));
+    SparseTensor st;
+    OP_REQUIRES_OK(
+        context, SparseTensor::Create(
+                     tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+                     TensorShape(shape_t->flat<int64>()), &st));
 
     Tensor *output_values = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({nnz}),
@@ -63,10 +63,16 @@ class SparseSplitOp : public OpKernel {
                     input_shape.vec<int64>()(split_dim), "), got ",
                     num_split_));
 
-    sparse::SparseTensor sparse_tensor(input_indices, input_values,
-                                       TensorShape(input_shape.vec<int64>()));
-    const std::vector<sparse::SparseTensor> outputs =
-        sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_);
+    sparse::SparseTensor sparse_tensor;
+    OP_REQUIRES_OK(context,
+                   sparse::SparseTensor::Create(
+                       input_indices, input_values,
+                       TensorShape(input_shape.vec<int64>()), &sparse_tensor));
+
+    std::vector<sparse::SparseTensor> outputs;
+    OP_REQUIRES_OK(context,
+                   sparse::SparseTensor::Split<T>(sparse_tensor, split_dim,
+                                                  num_split_, &outputs));
 
     for (int slice_index = 0; slice_index < num_split_; ++slice_index) {
       context->set_output(slice_index, outputs[slice_index].indices());
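Note that this hunk changes a second API besides `Create`: the static
`sparse::SparseTensor::Split<T>` helper now returns a `Status` and writes its
slices through an output parameter instead of returning a vector by value. A
sketch of the new call shape, with `st` and `num_split` as hypothetical
placeholders for the kernel's tensor and attribute:

  std::vector<sparse::SparseTensor> slices;
  OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(st, /*split_dim=*/0,
                                                         num_split, &slices));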
@@ -93,8 +93,9 @@ class SparseTensorsMap : public ResourceBase {
       const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx);
       const Tensor* values = sp_iter->second.values.AccessTensor(ctx);
       const auto& shape = sp_iter->second.shape;
-      sparse_tensors->emplace_back(*ix, *values, shape);
-
+      SparseTensor tensor;
+      TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
+      sparse_tensors->push_back(std::move(tensor));
       sp_tensors_.erase(sp_iter);
     }
   }
@@ -195,7 +196,9 @@ class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
                    TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(),
                                                input_shape->NumElements(),
                                                &input_shape_object));
-    SparseTensor st(*input_indices, *input_values, input_shape_object);
+    SparseTensor st;
+    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
+                                                 input_shape_object, &st));
     int64 handle;
     OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));
 
@@ -253,8 +256,10 @@ class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
     TensorShape tensor_input_shape(input_shape->vec<int64>());
     gtl::InlinedVector<int64, 8> std_order(rank);
     std::iota(std_order.begin(), std_order.end(), 0);
-    SparseTensor input_st(*input_indices, *input_values, tensor_input_shape,
-                          std_order);
+    SparseTensor input_st;
+    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
+                                                 tensor_input_shape, std_order,
+                                                 &input_st));
 
     auto input_shape_t = input_shape->vec<int64>();
     const int64 N = input_shape_t(0);
@@ -300,7 +305,10 @@ class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
         output_values_t(i) = values(i);
       }
 
-      SparseTensor st_i(output_indices, output_values, output_shape);
+      SparseTensor st_i;
+      OP_REQUIRES_OK(context,
+                     SparseTensor::Create(output_indices, output_values,
+                                          output_shape, &st_i));
       int64 handle;
       OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
       sparse_handles_t(b) = handle;
@@ -311,7 +319,9 @@ class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
     if (visited.size() < N) {
       Tensor empty_indices(DT_INT64, {0, rank - 1});
       Tensor empty_values(DataTypeToEnum<T>::value, {0});
-      SparseTensor empty_st(empty_indices, empty_values, output_shape);
+      SparseTensor empty_st;
+      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
+                                                   output_shape, &empty_st));
 
       for (int64 b = 0; b < N; ++b) {
         // We skipped this batch entry.
@@ -466,13 +476,15 @@ class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
     std::vector<SparseTensor> tensors_to_concat;
     tensors_to_concat.reserve(N);
     for (int i = 0; i < N; ++i) {
-      tensors_to_concat.emplace_back(std::move(indices_to_concat[i]),
-                                     std::move(values_to_concat[i]),
-                                     preconcat_shape, std_order);
+      SparseTensor tensor;
+      OP_REQUIRES_OK(context,
+                     SparseTensor::Create(std::move(indices_to_concat[i]),
+                                          std::move(values_to_concat[i]),
+                                          preconcat_shape, std_order, &tensor));
+      tensors_to_concat.push_back(std::move(tensor));
     }
 
-    SparseTensor output(SparseTensor::Concat<T>(tensors_to_concat));
-
+    auto output = SparseTensor::Concat<T>(tensors_to_concat);
     Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));
 
     std::copy_n(output.shape().data(), output.dims(),
@@ -119,8 +119,10 @@ class SparseToDense : public OpKernel {
     // Assume SparseTensor is lexicographically sorted.
     gtl::InlinedVector<int64, 8> order(output->shape().dims());
     std::iota(order.begin(), order.end(), 0);
-    sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(),
-                            order);
+    sparse::SparseTensor st;
+    OP_REQUIRES_OK(c,
+                   sparse::SparseTensor::Create(indices_shaped, sparse_values_b,
+                                                output->shape(), order, &st));
 
     if (validate_indices_) {
       OP_REQUIRES_OK(c, st.IndicesValid());
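Taken together, malformed sparse inputs that previously tripped checks inside
the `SparseTensor` constructor, aborting the process, should now surface as an
ordinary error status. A hedged illustration (the tensors and shapes below are
made up; `Create` is expected to reject indices that are not an int64 matrix of
shape [nnz, rank]):

  // Inside a kernel or test, in namespace tensorflow:
  Tensor bad_indices(DT_INT64, TensorShape({4}));  // rank 1, not [nnz, rank]
  Tensor values(DT_FLOAT, TensorShape({4}));
  sparse::SparseTensor st;
  Status s = sparse::SparseTensor::Create(bad_indices, values,
                                          TensorShape({8, 2}), &st);
  if (!s.ok()) {
    LOG(ERROR) << s;  // e.g. an InvalidArgument status, instead of a crash
  }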