Improve the Cardinality function and validate the input count

This commit is contained in:
Fei Hu 2019-04-22 19:42:13 -07:00
parent 1fa2e34d11
commit 6702143f14

View File

@ -63,7 +63,15 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
int64 Cardinality() const override { return input_->Cardinality(); }
int64 Cardinality() const override {
if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
return kInfiniteCardinality;
} else if (input_->Cardinality() == kUnknownCardinality) {
return kUnknownCardinality;
} else {
return input_->Cardinality() * count_;
}
}
protected:
template <class T>
@ -645,6 +653,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
OP_REQUIRES(ctx, count > 0 || count == -1,
errors::InvalidArgument(
"count must be greater than zero or equal to -1."));
// By TensorFlow convention, if both seeds are 0, then shuffling should be
// seeded non-deterministically.
if (seed == 0 && seed2 == 0) {