Improve the Cardinality function and validate the input count
This commit is contained in:
parent
1fa2e34d11
commit
6702143f14
@ -63,7 +63,15 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
|||||||
return input_->output_shapes();
|
return input_->output_shapes();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
int64 Cardinality() const override {
|
||||||
|
if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
|
||||||
|
return kInfiniteCardinality;
|
||||||
|
} else if (input_->Cardinality() == kUnknownCardinality) {
|
||||||
|
return kUnknownCardinality;
|
||||||
|
} else {
|
||||||
|
return input_->Cardinality() * count_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
template <class T>
|
template <class T>
|
||||||
@ -645,6 +653,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
|||||||
int64 count;
|
int64 count;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
|
||||||
|
|
||||||
|
OP_REQUIRES(ctx, count > 0 || count == -1,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"count must be greater than zero or equal to -1."));
|
||||||
|
|
||||||
// By TensorFlow convention, if both seeds are 0, then shuffling should be
|
// By TensorFlow convention, if both seeds are 0, then shuffling should be
|
||||||
// seeded non-deterministically.
|
// seeded non-deterministically.
|
||||||
if (seed == 0 && seed2 == 0) {
|
if (seed == 0 && seed2 == 0) {
|
||||||
|
Loading…
Reference in New Issue
Block a user