diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 219473153bd..72df272af89 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -41,13 +41,7 @@ class ColumnInterface { virtual int64 FeatureCount(int64 batch) const = 0; // Returns the fingerprint of nth feature from the specified batch. - InternalType Feature(int64 batch, int64 n) const { - InternalType not_used = InternalType(); - return DoFeature(batch, n, not_used); - } - - virtual InternalType DoFeature(int64 batch, int64 n, - InternalType not_used) const = 0; + virtual InternalType Feature(int64 batch, int64 n) const = 0; virtual ~ColumnInterface() {} }; @@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface { return feature_counts_[batch]; } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); - return values_.vec().data()[start + n]; - } - - // InternalType is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; - return std::to_string(values_.vec().data()[start + n]); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; - } + InternalType Feature(int64 batch, int64 n) const override; ~SparseTensorColumn() override {} @@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface { std::vector feature_start_indices_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return Fingerprint64(values_.vec().data()[start + n]); + return values_.vec().data()[start + n]; +} + +// InternalType is string or StringPiece when using StringCrosser. +template <> +string SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return values_.vec().data()[start + n]; + return std::to_string(values_.vec().data()[start + n]); +} + +template <> +StringPiece SparseTensorColumn::Feature(int64 batch, + int64 n) const { + const int64 start = feature_start_indices_[batch]; + return values_.vec().data()[start + n]; +} + // A column that is backed by a dense tensor. template class DenseTensorColumn : public ColumnInterface { @@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface { int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); - return tensor_.matrix()(batch, n); - } - - // Internal type is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); - return std::to_string(tensor_.matrix()(batch, n)); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - return tensor_.matrix()(batch, n); - } + InternalType Feature(int64 batch, int64 n) const override; ~DenseTensorColumn() override {} @@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface { const Tensor& tensor_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) + return Fingerprint64(tensor_.matrix()(batch, n)); + return tensor_.matrix()(batch, n); +} + +// Internal type is string or StringPiece when using StringCrosser. +template <> +string DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); + return std::to_string(tensor_.matrix()(batch, n)); +} + +template <> +StringPiece DenseTensorColumn::Feature(int64 batch, + int64 n) const { + return tensor_.matrix()(batch, n); +} + // Updates Output tensors with sparse crosses. template class OutputUpdater { diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index ed93caad331..c7bf250fad7 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -41,13 +41,7 @@ class ColumnInterface { virtual int64 FeatureCount(int64 batch) const = 0; // Returns the fingerprint of nth feature from the specified batch. - InternalType Feature(int64 batch, int64 n) const { - InternalType not_used = InternalType(); - return DoFeature(batch, n, not_used); - } - - virtual InternalType DoFeature(int64 batch, int64 n, - InternalType not_used) const = 0; + virtual InternalType Feature(int64 batch, int64 n) const = 0; virtual ~ColumnInterface() {} }; @@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface { return feature_counts_[batch]; } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); - return values_.vec().data()[start + n]; - } - - // InternalType is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; - return std::to_string(values_.vec().data()[start + n]); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; - } + InternalType Feature(int64 batch, int64 n) const override; ~SparseTensorColumn() override {} @@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface { std::vector feature_start_indices_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return Fingerprint64(values_.vec().data()[start + n]); + return values_.vec().data()[start + n]; +} + +// InternalType is string or StringPiece when using StringCrosser. +template <> +string SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return values_.vec().data()[start + n]; + return std::to_string(values_.vec().data()[start + n]); +} + +template <> +StringPiece SparseTensorColumn::Feature(int64 batch, + int64 n) const { + const int64 start = feature_start_indices_[batch]; + return values_.vec().data()[start + n]; +} + // A column that is backed by a dense tensor. template class DenseTensorColumn : public ColumnInterface { @@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface { int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); - return tensor_.matrix()(batch, n); - } - - // Internal type is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); - return std::to_string(tensor_.matrix()(batch, n)); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - return tensor_.matrix()(batch, n); - } + InternalType Feature(int64 batch, int64 n) const override; ~DenseTensorColumn() override {} @@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface { const Tensor& tensor_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) + return Fingerprint64(tensor_.matrix()(batch, n)); + return tensor_.matrix()(batch, n); +} + +// Internal type is string or StringPiece when using StringCrosser. +template <> +string DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); + return std::to_string(tensor_.matrix()(batch, n)); +} + +template <> +StringPiece DenseTensorColumn::Feature(int64 batch, + int64 n) const { + return tensor_.matrix()(batch, n); +} + // Updates Output tensors with sparse crosses. template class OutputUpdater {