Use template specialization instead of overloaded methods. This is a more appropriate tool here. NFC

PiperOrigin-RevId: 158292035
This commit is contained in:
A. Unique TensorFlower 2017-06-07 11:18:16 -07:00 committed by TensorFlower Gardener
parent 55f987692a
commit ba656b2611
2 changed files with 98 additions and 86 deletions

View File

@ -41,13 +41,7 @@ class ColumnInterface {
virtual int64 FeatureCount(int64 batch) const = 0;
// Returns the fingerprint of nth feature from the specified batch.
InternalType Feature(int64 batch, int64 n) const {
InternalType not_used = InternalType();
return DoFeature(batch, n, not_used);
}
virtual InternalType DoFeature(int64 batch, int64 n,
InternalType not_used) const = 0;
virtual InternalType Feature(int64 batch, int64 n) const = 0;
virtual ~ColumnInterface() {}
};
@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
return feature_counts_[batch];
}
// InternalType is int64 only when using HashCrosser.
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return Fingerprint64(values_.vec<string>().data()[start + n]);
return values_.vec<int64>().data()[start + n];
}
// InternalType is string or StringPiece when using StringCrosser.
string DoFeature(int64 batch, int64 n, string not_used) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return values_.vec<string>().data()[start + n];
return std::to_string(values_.vec<int64>().data()[start + n]);
}
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
const int64 start = feature_start_indices_[batch];
return values_.vec<string>().data()[start + n];
}
InternalType Feature(int64 batch, int64 n) const override;
~SparseTensorColumn() override {}
@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
std::vector<int64> feature_start_indices_;
};
// InternalType is int64 only when using HashCrosser.
template <>
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return Fingerprint64(values_.vec<string>().data()[start + n]);
return values_.vec<int64>().data()[start + n];
}
// InternalType is string or StringPiece when using StringCrosser.
template <>
string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return values_.vec<string>().data()[start + n];
return std::to_string(values_.vec<int64>().data()[start + n]);
}
template <>
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
int64 n) const {
const int64 start = feature_start_indices_[batch];
return values_.vec<string>().data()[start + n];
}
// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
// InternalType is int64 only when using HashCrosser.
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
if (DT_STRING == tensor_.dtype())
return Fingerprint64(tensor_.matrix<string>()(batch, n));
return tensor_.matrix<int64>()(batch, n);
}
// Internal type is string or StringPiece when using StringCrosser.
string DoFeature(int64 batch, int64 n, string not_used) const {
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
return std::to_string(tensor_.matrix<int64>()(batch, n));
}
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
return tensor_.matrix<string>()(batch, n);
}
InternalType Feature(int64 batch, int64 n) const override;
~DenseTensorColumn() override {}
@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
const Tensor& tensor_;
};
// InternalType is int64 only when using HashCrosser.
template <>
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
if (DT_STRING == tensor_.dtype())
return Fingerprint64(tensor_.matrix<string>()(batch, n));
return tensor_.matrix<int64>()(batch, n);
}
// Internal type is string or StringPiece when using StringCrosser.
template <>
string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
return std::to_string(tensor_.matrix<int64>()(batch, n));
}
template <>
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
int64 n) const {
return tensor_.matrix<string>()(batch, n);
}
// Updates Output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {

View File

@ -41,13 +41,7 @@ class ColumnInterface {
virtual int64 FeatureCount(int64 batch) const = 0;
// Returns the fingerprint of nth feature from the specified batch.
InternalType Feature(int64 batch, int64 n) const {
InternalType not_used = InternalType();
return DoFeature(batch, n, not_used);
}
virtual InternalType DoFeature(int64 batch, int64 n,
InternalType not_used) const = 0;
virtual InternalType Feature(int64 batch, int64 n) const = 0;
virtual ~ColumnInterface() {}
};
@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
return feature_counts_[batch];
}
// InternalType is int64 only when using HashCrosser.
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return Fingerprint64(values_.vec<string>().data()[start + n]);
return values_.vec<int64>().data()[start + n];
}
// InternalType is string or StringPiece when using StringCrosser.
string DoFeature(int64 batch, int64 n, string not_used) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return values_.vec<string>().data()[start + n];
return std::to_string(values_.vec<int64>().data()[start + n]);
}
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
const int64 start = feature_start_indices_[batch];
return values_.vec<string>().data()[start + n];
}
InternalType Feature(int64 batch, int64 n) const override;
~SparseTensorColumn() override {}
@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
std::vector<int64> feature_start_indices_;
};
// InternalType is int64 only when using HashCrosser.
template <>
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return Fingerprint64(values_.vec<string>().data()[start + n]);
return values_.vec<int64>().data()[start + n];
}
// InternalType is string or StringPiece when using StringCrosser.
template <>
string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return values_.vec<string>().data()[start + n];
return std::to_string(values_.vec<int64>().data()[start + n]);
}
template <>
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
int64 n) const {
const int64 start = feature_start_indices_[batch];
return values_.vec<string>().data()[start + n];
}
// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
// InternalType is int64 only when using HashCrosser.
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
if (DT_STRING == tensor_.dtype())
return Fingerprint64(tensor_.matrix<string>()(batch, n));
return tensor_.matrix<int64>()(batch, n);
}
// Internal type is string or StringPiece when using StringCrosser.
string DoFeature(int64 batch, int64 n, string not_used) const {
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
return std::to_string(tensor_.matrix<int64>()(batch, n));
}
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
return tensor_.matrix<string>()(batch, n);
}
InternalType Feature(int64 batch, int64 n) const override;
~DenseTensorColumn() override {}
@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
const Tensor& tensor_;
};
// InternalType is int64 only when using HashCrosser.
template <>
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
if (DT_STRING == tensor_.dtype())
return Fingerprint64(tensor_.matrix<string>()(batch, n));
return tensor_.matrix<int64>()(batch, n);
}
// Internal type is string or StringPiece when using StringCrosser.
template <>
string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
return std::to_string(tensor_.matrix<int64>()(batch, n));
}
template <>
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
int64 n) const {
return tensor_.matrix<string>()(batch, n);
}
// Updates Output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {