From 6702143f14d1fbf21a3d556ac50c29f96c3fd024 Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Mon, 22 Apr 2019 19:42:13 -0700 Subject: [PATCH] Improve the Cardinality function and validate the input count --- tensorflow/core/kernels/data/shuffle_dataset_op.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index f426e3cc465..287a7c946c0 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -63,7 +63,15 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { return input_->output_shapes(); } - int64 Cardinality() const override { return input_->Cardinality(); } + int64 Cardinality() const override { + if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) { + return kInfiniteCardinality; + } else if (input_->Cardinality() == kUnknownCardinality) { + return kUnknownCardinality; + } else { + return input_->Cardinality() * count_; + } + } protected: template @@ -645,6 +653,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count)); + OP_REQUIRES(ctx, count > 0 || count == -1, + errors::InvalidArgument( + "count must be greater than zero or equal to -1.")); + // By TensorFlow convention, if both seeds are 0, then shuffling should be // seeded non-deterministically. if (seed == 0 && seed2 == 0) {