Use template specialization instead of overloaded methods. This is a more appropriate tool here. NFC
PiperOrigin-RevId: 158292035
This commit is contained in:
parent
55f987692a
commit
ba656b2611
@ -41,13 +41,7 @@ class ColumnInterface {
|
||||
virtual int64 FeatureCount(int64 batch) const = 0;
|
||||
|
||||
// Returns the fingerprint of nth feature from the specified batch.
|
||||
InternalType Feature(int64 batch, int64 n) const {
|
||||
InternalType not_used = InternalType();
|
||||
return DoFeature(batch, n, not_used);
|
||||
}
|
||||
|
||||
virtual InternalType DoFeature(int64 batch, int64 n,
|
||||
InternalType not_used) const = 0;
|
||||
virtual InternalType Feature(int64 batch, int64 n) const = 0;
|
||||
|
||||
virtual ~ColumnInterface() {}
|
||||
};
|
||||
@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
|
||||
return feature_counts_[batch];
|
||||
}
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return Fingerprint64(values_.vec<string>().data()[start + n]);
|
||||
return values_.vec<int64>().data()[start + n];
|
||||
}
|
||||
|
||||
// InternalType is string or StringPiece when using StringCrosser.
|
||||
string DoFeature(int64 batch, int64 n, string not_used) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return values_.vec<string>().data()[start + n];
|
||||
return std::to_string(values_.vec<int64>().data()[start + n]);
|
||||
}
|
||||
|
||||
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
return values_.vec<string>().data()[start + n];
|
||||
}
|
||||
InternalType Feature(int64 batch, int64 n) const override;
|
||||
|
||||
~SparseTensorColumn() override {}
|
||||
|
||||
@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
|
||||
std::vector<int64> feature_start_indices_;
|
||||
};
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
template <>
|
||||
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return Fingerprint64(values_.vec<string>().data()[start + n]);
|
||||
return values_.vec<int64>().data()[start + n];
|
||||
}
|
||||
|
||||
// InternalType is string or StringPiece when using StringCrosser.
|
||||
template <>
|
||||
string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return values_.vec<string>().data()[start + n];
|
||||
return std::to_string(values_.vec<int64>().data()[start + n]);
|
||||
}
|
||||
|
||||
template <>
|
||||
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
|
||||
int64 n) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
return values_.vec<string>().data()[start + n];
|
||||
}
|
||||
|
||||
// A column that is backed by a dense tensor.
|
||||
template <typename InternalType>
|
||||
class DenseTensorColumn : public ColumnInterface<InternalType> {
|
||||
@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
|
||||
|
||||
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
|
||||
if (DT_STRING == tensor_.dtype())
|
||||
return Fingerprint64(tensor_.matrix<string>()(batch, n));
|
||||
return tensor_.matrix<int64>()(batch, n);
|
||||
}
|
||||
|
||||
// Internal type is string or StringPiece when using StringCrosser.
|
||||
string DoFeature(int64 batch, int64 n, string not_used) const {
|
||||
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
|
||||
return std::to_string(tensor_.matrix<int64>()(batch, n));
|
||||
}
|
||||
|
||||
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
|
||||
return tensor_.matrix<string>()(batch, n);
|
||||
}
|
||||
InternalType Feature(int64 batch, int64 n) const override;
|
||||
|
||||
~DenseTensorColumn() override {}
|
||||
|
||||
@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
|
||||
const Tensor& tensor_;
|
||||
};
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
template <>
|
||||
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
|
||||
if (DT_STRING == tensor_.dtype())
|
||||
return Fingerprint64(tensor_.matrix<string>()(batch, n));
|
||||
return tensor_.matrix<int64>()(batch, n);
|
||||
}
|
||||
|
||||
// Internal type is string or StringPiece when using StringCrosser.
|
||||
template <>
|
||||
string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
|
||||
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
|
||||
return std::to_string(tensor_.matrix<int64>()(batch, n));
|
||||
}
|
||||
|
||||
template <>
|
||||
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
|
||||
int64 n) const {
|
||||
return tensor_.matrix<string>()(batch, n);
|
||||
}
|
||||
|
||||
// Updates Output tensors with sparse crosses.
|
||||
template <typename OutType>
|
||||
class OutputUpdater {
|
||||
|
@ -41,13 +41,7 @@ class ColumnInterface {
|
||||
virtual int64 FeatureCount(int64 batch) const = 0;
|
||||
|
||||
// Returns the fingerprint of nth feature from the specified batch.
|
||||
InternalType Feature(int64 batch, int64 n) const {
|
||||
InternalType not_used = InternalType();
|
||||
return DoFeature(batch, n, not_used);
|
||||
}
|
||||
|
||||
virtual InternalType DoFeature(int64 batch, int64 n,
|
||||
InternalType not_used) const = 0;
|
||||
virtual InternalType Feature(int64 batch, int64 n) const = 0;
|
||||
|
||||
virtual ~ColumnInterface() {}
|
||||
};
|
||||
@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
|
||||
return feature_counts_[batch];
|
||||
}
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return Fingerprint64(values_.vec<string>().data()[start + n]);
|
||||
return values_.vec<int64>().data()[start + n];
|
||||
}
|
||||
|
||||
// InternalType is string or StringPiece when using StringCrosser.
|
||||
string DoFeature(int64 batch, int64 n, string not_used) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return values_.vec<string>().data()[start + n];
|
||||
return std::to_string(values_.vec<int64>().data()[start + n]);
|
||||
}
|
||||
|
||||
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
return values_.vec<string>().data()[start + n];
|
||||
}
|
||||
InternalType Feature(int64 batch, int64 n) const override;
|
||||
|
||||
~SparseTensorColumn() override {}
|
||||
|
||||
@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
|
||||
std::vector<int64> feature_start_indices_;
|
||||
};
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
template <>
|
||||
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return Fingerprint64(values_.vec<string>().data()[start + n]);
|
||||
return values_.vec<int64>().data()[start + n];
|
||||
}
|
||||
|
||||
// InternalType is string or StringPiece when using StringCrosser.
|
||||
template <>
|
||||
string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
if (DT_STRING == values_.dtype())
|
||||
return values_.vec<string>().data()[start + n];
|
||||
return std::to_string(values_.vec<int64>().data()[start + n]);
|
||||
}
|
||||
|
||||
template <>
|
||||
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
|
||||
int64 n) const {
|
||||
const int64 start = feature_start_indices_[batch];
|
||||
return values_.vec<string>().data()[start + n];
|
||||
}
|
||||
|
||||
// A column that is backed by a dense tensor.
|
||||
template <typename InternalType>
|
||||
class DenseTensorColumn : public ColumnInterface<InternalType> {
|
||||
@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
|
||||
|
||||
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
|
||||
if (DT_STRING == tensor_.dtype())
|
||||
return Fingerprint64(tensor_.matrix<string>()(batch, n));
|
||||
return tensor_.matrix<int64>()(batch, n);
|
||||
}
|
||||
|
||||
// Internal type is string or StringPiece when using StringCrosser.
|
||||
string DoFeature(int64 batch, int64 n, string not_used) const {
|
||||
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
|
||||
return std::to_string(tensor_.matrix<int64>()(batch, n));
|
||||
}
|
||||
|
||||
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
|
||||
return tensor_.matrix<string>()(batch, n);
|
||||
}
|
||||
InternalType Feature(int64 batch, int64 n) const override;
|
||||
|
||||
~DenseTensorColumn() override {}
|
||||
|
||||
@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
|
||||
const Tensor& tensor_;
|
||||
};
|
||||
|
||||
// InternalType is int64 only when using HashCrosser.
|
||||
template <>
|
||||
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
|
||||
if (DT_STRING == tensor_.dtype())
|
||||
return Fingerprint64(tensor_.matrix<string>()(batch, n));
|
||||
return tensor_.matrix<int64>()(batch, n);
|
||||
}
|
||||
|
||||
// Internal type is string or StringPiece when using StringCrosser.
|
||||
template <>
|
||||
string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
|
||||
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
|
||||
return std::to_string(tensor_.matrix<int64>()(batch, n));
|
||||
}
|
||||
|
||||
template <>
|
||||
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
|
||||
int64 n) const {
|
||||
return tensor_.matrix<string>()(batch, n);
|
||||
}
|
||||
|
||||
// Updates Output tensors with sparse crosses.
|
||||
template <typename OutType>
|
||||
class OutputUpdater {
|
||||
|
Loading…
Reference in New Issue
Block a user