Adds StringToHashBucketFast Op to TF android selective registration builds, by split out the fast op into its own .h/.cc
PiperOrigin-RevId: 342980949 Change-Id: Idbba0fe1d2b4e433dc5cee100ec233396cfadce1
This commit is contained in:
parent
8117e74787
commit
df7522f0a5
@ -5042,7 +5042,14 @@ STRING_DEPS = [
|
||||
|
||||
tf_kernel_library(
|
||||
name = "string_to_hash_bucket_op",
|
||||
prefix = "string_to_hash_bucket_op",
|
||||
srcs = [
|
||||
"string_to_hash_bucket_fast_op.cc",
|
||||
"string_to_hash_bucket_op.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"string_to_hash_bucket_fast_op.h",
|
||||
"string_to_hash_bucket_op.h",
|
||||
],
|
||||
deps = STRING_DEPS,
|
||||
)
|
||||
|
||||
@ -5978,6 +5985,7 @@ filegroup(
|
||||
"stateless_random_ops_v2.h",
|
||||
"string_util.h",
|
||||
"string_to_hash_bucket_op.h",
|
||||
"string_to_hash_bucket_fast_op.h",
|
||||
"tensor_array.h",
|
||||
"tensor_list.h",
|
||||
"tile_functor.h",
|
||||
@ -6242,6 +6250,7 @@ filegroup(
|
||||
"string_split_op.cc",
|
||||
"string_strip_op.cc",
|
||||
"string_to_hash_bucket_op.cc",
|
||||
"string_to_hash_bucket_fast_op.cc",
|
||||
"substr_op.cc",
|
||||
"tensor_array.cc",
|
||||
"tensor_array_ops.cc",
|
||||
|
||||
25
tensorflow/core/kernels/string_to_hash_bucket_fast_op.cc
Normal file
25
tensorflow/core/kernels/string_to_hash_bucket_fast_op.cc
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/string_to_hash_bucket_fast_op.h"
|
||||
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
|
||||
StringToHashBucketOp<Fingerprint64>);
|
||||
|
||||
} // namespace tensorflow
|
||||
66
tensorflow/core/kernels/string_to_hash_bucket_fast_op.h
Normal file
66
tensorflow/core/kernels/string_to_hash_bucket_fast_op.h
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_FAST_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_FAST_OP_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <uint64 hash(StringPiece)>
|
||||
class StringToHashBucketOp : public OpKernel {
|
||||
public:
|
||||
explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
|
||||
const auto& input_flat = input_tensor->flat<tstring>();
|
||||
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("output", input_tensor->shape(),
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<int64>();
|
||||
|
||||
typedef decltype(input_flat.size()) Index;
|
||||
for (Index i = 0; i < input_flat.size(); ++i) {
|
||||
const uint64 input_hash = hash(input_flat(i));
|
||||
const uint64 bucket_id = input_hash % num_buckets_;
|
||||
// The number of buckets is always in the positive range of int64 so is
|
||||
// the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
|
||||
// safe.
|
||||
output_flat(i) = static_cast<int64>(bucket_id);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int64 num_buckets_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_FAST_OP_H_
|
||||
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/string_to_hash_bucket_op.h"
|
||||
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/strong_hash.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -62,9 +61,6 @@ class LegacyStringToHashBucketOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
|
||||
LegacyStringToHashBucketOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
|
||||
StringToHashBucketOp<Fingerprint64>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketStrong").Device(DEVICE_CPU),
|
||||
StringToKeyedHashBucketOp<StrongKeyedHash>);
|
||||
|
||||
|
||||
@ -26,41 +26,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <uint64 hash(StringPiece)>
|
||||
class StringToHashBucketOp : public OpKernel {
|
||||
public:
|
||||
explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
|
||||
const auto& input_flat = input_tensor->flat<tstring>();
|
||||
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("output", input_tensor->shape(),
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<int64>();
|
||||
|
||||
typedef decltype(input_flat.size()) Index;
|
||||
for (Index i = 0; i < input_flat.size(); ++i) {
|
||||
const uint64 input_hash = hash(input_flat(i));
|
||||
const uint64 bucket_id = input_hash % num_buckets_;
|
||||
// The number of buckets is always in the positive range of int64 so is
|
||||
// the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
|
||||
// safe.
|
||||
output_flat(i) = static_cast<int64>(bucket_id);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int64 num_buckets_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
|
||||
};
|
||||
|
||||
template <uint64 hash(const uint64 (&)[2], const string&)>
|
||||
class StringToKeyedHashBucketOp : public OpKernel {
|
||||
public:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user