Move clustering ops to core.
PiperOrigin-RevId: 228808275
This commit is contained in:
parent
3fc2b09b60
commit
578bd3a276
@ -28,7 +28,6 @@ tf_custom_op_py_library(
|
|||||||
"python/ops/wals.py",
|
"python/ops/wals.py",
|
||||||
],
|
],
|
||||||
dso = [
|
dso = [
|
||||||
":python/ops/_clustering_ops.so",
|
|
||||||
":python/ops/_factorization_ops.so",
|
":python/ops/_factorization_ops.so",
|
||||||
],
|
],
|
||||||
kernels = [
|
kernels = [
|
||||||
@ -38,12 +37,12 @@ tf_custom_op_py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":factorization_ops_test_utils_py",
|
":factorization_ops_test_utils_py",
|
||||||
":gen_clustering_ops",
|
|
||||||
":gen_factorization_ops",
|
":gen_factorization_ops",
|
||||||
"//tensorflow/contrib/framework:framework_py",
|
"//tensorflow/contrib/framework:framework_py",
|
||||||
"//tensorflow/contrib/util:util_py",
|
"//tensorflow/contrib/util:util_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:check_ops",
|
"//tensorflow/python:check_ops",
|
||||||
|
"//tensorflow/python:clustering_ops_gen",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
"//tensorflow/python:embedding_ops",
|
"//tensorflow/python:embedding_ops",
|
||||||
@ -77,17 +76,6 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ops
|
|
||||||
tf_custom_op_library(
|
|
||||||
name = "python/ops/_clustering_ops.so",
|
|
||||||
srcs = [
|
|
||||||
"ops/clustering_ops.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/contrib/factorization/kernels:clustering_ops",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_custom_op_library(
|
tf_custom_op_library(
|
||||||
name = "python/ops/_factorization_ops.so",
|
name = "python/ops/_factorization_ops.so",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -100,26 +88,16 @@ tf_custom_op_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_gen_op_libs([
|
tf_gen_op_libs([
|
||||||
"clustering_ops",
|
|
||||||
"factorization_ops",
|
"factorization_ops",
|
||||||
])
|
])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "all_ops",
|
name = "all_ops",
|
||||||
deps = [
|
deps = [
|
||||||
":clustering_ops_op_lib",
|
|
||||||
":factorization_ops_op_lib",
|
":factorization_ops_op_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
|
||||||
name = "gen_clustering_ops",
|
|
||||||
out = "python/ops/gen_clustering_ops.py",
|
|
||||||
deps = [
|
|
||||||
":clustering_ops_op_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "gen_factorization_ops",
|
name = "gen_factorization_ops",
|
||||||
out = "python/ops/gen_factorization_ops.py",
|
out = "python/ops/gen_factorization_ops.py",
|
||||||
|
@ -11,7 +11,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "all_kernels",
|
name = "all_kernels",
|
||||||
deps = [
|
deps = [
|
||||||
":clustering_ops",
|
|
||||||
":masked_matmul_ops",
|
":masked_matmul_ops",
|
||||||
":wals_solver_ops",
|
":wals_solver_ops",
|
||||||
"@protobuf_archive//:protobuf_headers",
|
"@protobuf_archive//:protobuf_headers",
|
||||||
@ -29,17 +28,6 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "clustering_ops",
|
|
||||||
srcs = ["clustering_ops.cc"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework_headers_lib",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
"@protobuf_archive//:protobuf_headers",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "masked_matmul_ops",
|
name = "masked_matmul_ops",
|
||||||
srcs = ["masked_matmul_ops.cc"],
|
srcs = ["masked_matmul_ops.cc"],
|
||||||
@ -51,19 +39,3 @@ cc_library(
|
|||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
|
||||||
name = "clustering_ops_test",
|
|
||||||
srcs = ["clustering_ops_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":clustering_ops",
|
|
||||||
"//tensorflow/contrib/factorization:clustering_ops_op_lib",
|
|
||||||
"//tensorflow/core:core_cpu",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core:testlib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
@ -1,91 +0,0 @@
|
|||||||
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
|
||||||
// use this file except in compliance with the License. You may obtain a copy
|
|
||||||
// of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
// License for the specific language governing permissions and limitations under
|
|
||||||
// the License.
|
|
||||||
// ==============================================================================
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
|
||||||
#include "tensorflow/core/framework/op.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
REGISTER_OP("KmeansPlusPlusInitialization")
|
|
||||||
.Input("points: float32")
|
|
||||||
.Input("num_to_sample: int64")
|
|
||||||
.Input("seed: int64")
|
|
||||||
.Input("num_retries_per_sample: int64")
|
|
||||||
.Output("samples: float32")
|
|
||||||
.SetShapeFn(shape_inference::UnknownShape)
|
|
||||||
.Doc(R"(
|
|
||||||
Selects num_to_sample rows of input using the KMeans++ criterion.
|
|
||||||
|
|
||||||
Rows of points are assumed to be input points. One row is selected at random.
|
|
||||||
Subsequent rows are sampled with probability proportional to the squared L2
|
|
||||||
distance from the nearest row selected thus far till num_to_sample rows have
|
|
||||||
been sampled.
|
|
||||||
|
|
||||||
points: Matrix of shape (n, d). Rows are assumed to be input points.
|
|
||||||
num_to_sample: Scalar. The number of rows to sample. This value must not be
|
|
||||||
larger than n.
|
|
||||||
seed: Scalar. Seed for initializing the random number generator.
|
|
||||||
num_retries_per_sample: Scalar. For each row that is sampled, this parameter
|
|
||||||
specifies the number of additional points to draw from the current
|
|
||||||
distribution before selecting the best. If a negative value is specified, a
|
|
||||||
heuristic is used to sample O(log(num_to_sample)) additional points.
|
|
||||||
samples: Matrix of shape (num_to_sample, d). The sampled rows.
|
|
||||||
)");
|
|
||||||
|
|
||||||
REGISTER_OP("KMC2ChainInitialization")
|
|
||||||
.Input("distances: float32")
|
|
||||||
.Input("seed: int64")
|
|
||||||
.Output("index: int64")
|
|
||||||
.SetShapeFn(shape_inference::ScalarShape)
|
|
||||||
.Doc(R"(
|
|
||||||
Returns the index of a data point that should be added to the seed set.
|
|
||||||
|
|
||||||
Entries in distances are assumed to be squared distances of candidate points to
|
|
||||||
the already sampled centers in the seed set. The op constructs one Markov chain
|
|
||||||
of the k-MC^2 algorithm and returns the index of one candidate point to be added
|
|
||||||
as an additional cluster center.
|
|
||||||
|
|
||||||
distances: Vector with squared distances to the closest previously sampled
|
|
||||||
cluster center for each candidate point.
|
|
||||||
seed: Scalar. Seed for initializing the random number generator.
|
|
||||||
index: Scalar with the index of the sampled point.
|
|
||||||
)");
|
|
||||||
|
|
||||||
REGISTER_OP("NearestNeighbors")
|
|
||||||
.Input("points: float32")
|
|
||||||
.Input("centers: float32")
|
|
||||||
.Input("k: int64")
|
|
||||||
.Output("nearest_center_indices: int64")
|
|
||||||
.Output("nearest_center_distances: float32")
|
|
||||||
.SetShapeFn(shape_inference::UnknownShape)
|
|
||||||
.Doc(R"(
|
|
||||||
Selects the k nearest centers for each point.
|
|
||||||
|
|
||||||
Rows of points are assumed to be input points. Rows of centers are assumed to be
|
|
||||||
the list of candidate centers. For each point, the k centers that have least L2
|
|
||||||
distance to it are computed.
|
|
||||||
|
|
||||||
points: Matrix of shape (n, d). Rows are assumed to be input points.
|
|
||||||
centers: Matrix of shape (m, d). Rows are assumed to be centers.
|
|
||||||
k: Scalar. Number of nearest centers to return for each point. If k is larger
|
|
||||||
than m, then only m centers are returned.
|
|
||||||
nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the
|
|
||||||
indices of the centers closest to the corresponding point, ordered by
|
|
||||||
increasing distance.
|
|
||||||
nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the
|
|
||||||
squared L2 distance to the corresponding center in nearest_center_indices.
|
|
||||||
)");
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
@ -18,28 +18,23 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.factorization.python.ops import gen_clustering_ops
|
|
||||||
# go/tf-wildcard-import
|
|
||||||
# pylint: disable=wildcard-import
|
|
||||||
from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import *
|
|
||||||
# pylint: enable=wildcard-import
|
|
||||||
from tensorflow.contrib.util import loader
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import gen_clustering_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_impl
|
from tensorflow.python.ops import nn_impl
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops.embedding_ops import embedding_lookup
|
from tensorflow.python.ops.embedding_ops import embedding_lookup
|
||||||
from tensorflow.python.platform import resource_loader
|
# go/tf-wildcard-import
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
_clustering_ops = loader.load_op_library(
|
from tensorflow.python.ops.gen_clustering_ops import *
|
||||||
resource_loader.get_path_to_datafile('_clustering_ops.so'))
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\)
|
# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\)
|
||||||
# which is the square root of the sum of the absolute squares of the elements
|
# which is the square root of the sum of the absolute squares of the elements
|
||||||
|
@ -1074,6 +1074,7 @@ tf_gen_op_libs(
|
|||||||
"tensor_forest_ops",
|
"tensor_forest_ops",
|
||||||
"candidate_sampling_ops",
|
"candidate_sampling_ops",
|
||||||
"checkpoint_ops",
|
"checkpoint_ops",
|
||||||
|
"clustering_ops",
|
||||||
"collective_ops",
|
"collective_ops",
|
||||||
"control_flow_ops",
|
"control_flow_ops",
|
||||||
"ctc_ops",
|
"ctc_ops",
|
||||||
@ -1228,6 +1229,7 @@ cc_library(
|
|||||||
":tensor_forest_ops_op_lib",
|
":tensor_forest_ops_op_lib",
|
||||||
":candidate_sampling_ops_op_lib",
|
":candidate_sampling_ops_op_lib",
|
||||||
":checkpoint_ops_op_lib",
|
":checkpoint_ops_op_lib",
|
||||||
|
":clustering_ops_op_lib",
|
||||||
":collective_ops_op_lib",
|
":collective_ops_op_lib",
|
||||||
":control_flow_ops_op_lib",
|
":control_flow_ops_op_lib",
|
||||||
":ctc_ops_op_lib",
|
":ctc_ops_op_lib",
|
||||||
@ -1382,6 +1384,7 @@ cc_library(
|
|||||||
"//tensorflow/core/kernels:tensor_forest_ops",
|
"//tensorflow/core/kernels:tensor_forest_ops",
|
||||||
"//tensorflow/core/kernels:candidate_sampler_ops",
|
"//tensorflow/core/kernels:candidate_sampler_ops",
|
||||||
"//tensorflow/core/kernels:checkpoint_ops",
|
"//tensorflow/core/kernels:checkpoint_ops",
|
||||||
|
"//tensorflow/core/kernels:clustering_ops",
|
||||||
"//tensorflow/core/kernels:collective_ops",
|
"//tensorflow/core/kernels:collective_ops",
|
||||||
"//tensorflow/core/kernels:control_flow_ops",
|
"//tensorflow/core/kernels:control_flow_ops",
|
||||||
"//tensorflow/core/kernels:ctc_ops",
|
"//tensorflow/core/kernels:ctc_ops",
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "KMC2ChainInitialization"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "distances"
|
||||||
|
description: <<END
|
||||||
|
Vector with squared distances to the closest previously sampled cluster center
|
||||||
|
for each candidate point.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "seed"
|
||||||
|
description: <<END
|
||||||
|
Scalar. Seed for initializing the random number generator.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "index"
|
||||||
|
description: <<END
|
||||||
|
Scalar with the index of the sampled point.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Returns the index of a data point that should be added to the seed set."
|
||||||
|
description: <<END
|
||||||
|
Entries in distances are assumed to be squared distances of candidate points to
|
||||||
|
the already sampled centers in the seed set. The op constructs one Markov chain
|
||||||
|
of the k-MC^2 algorithm and returns the index of one candidate point to be added
|
||||||
|
as an additional cluster center.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,44 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "KmeansPlusPlusInitialization"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "points"
|
||||||
|
description: <<END
|
||||||
|
Matrix of shape (n, d). Rows are assumed to be input points.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "num_to_sample"
|
||||||
|
description: <<END
|
||||||
|
Scalar. The number of rows to sample. This value must not be larger than n.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "seed"
|
||||||
|
description: <<END
|
||||||
|
Scalar. Seed for initializing the random number generator.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "num_retries_per_sample"
|
||||||
|
description: <<END
|
||||||
|
Scalar. For each row that is sampled, this parameter
|
||||||
|
specifies the number of additional points to draw from the current
|
||||||
|
distribution before selecting the best. If a negative value is specified, a
|
||||||
|
heuristic is used to sample O(log(num_to_sample)) additional points.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "samples"
|
||||||
|
description: <<END
|
||||||
|
Matrix of shape (num_to_sample, d). The sampled rows.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Selects num_to_sample rows of input using the KMeans++ criterion."
|
||||||
|
description: <<END
|
||||||
|
Rows of points are assumed to be input points. One row is selected at random.
|
||||||
|
Subsequent rows are sampled with probability proportional to the squared L2
|
||||||
|
distance from the nearest row selected thus far till num_to_sample rows have
|
||||||
|
been sampled.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,43 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "NearestNeighbors"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "points"
|
||||||
|
description: <<END
|
||||||
|
Matrix of shape (n, d). Rows are assumed to be input points.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "centers"
|
||||||
|
description: <<END
|
||||||
|
Matrix of shape (m, d). Rows are assumed to be centers.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "k"
|
||||||
|
description: <<END
|
||||||
|
Number of nearest centers to return for each point. If k is larger than m, then
|
||||||
|
only m centers are returned.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "nearest_center_indices"
|
||||||
|
description: <<END
|
||||||
|
Matrix of shape (n, min(m, k)). Each row contains the indices of the centers
|
||||||
|
closest to the corresponding point, ordered by increasing distance.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "nearest_center_distances"
|
||||||
|
description: <<END
|
||||||
|
Matrix of shape (n, min(m, k)). Each row contains the squared L2 distance to the
|
||||||
|
corresponding center in nearest_center_indices.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Selects the k nearest centers for each point."
|
||||||
|
description: <<END
|
||||||
|
Rows of points are assumed to be input points. Rows of centers are assumed to be
|
||||||
|
the list of candidate centers. For each point, the k centers that have least L2
|
||||||
|
distance to it are computed.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "KMC2ChainInitialization"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "KmeansPlusPlusInitialization"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "NearestNeighbors"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -152,6 +152,33 @@ tf_kernel_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "clustering_ops",
|
||||||
|
prefix = "clustering_ops",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:clustering_ops_op_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "clustering_ops_test",
|
||||||
|
srcs = ["clustering_ops_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":clustering_ops",
|
||||||
|
"//tensorflow/core:clustering_ops_op_lib",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "collective_ops",
|
name = "collective_ops",
|
||||||
prefix = "collective_ops",
|
prefix = "collective_ops",
|
||||||
|
@ -392,7 +392,7 @@ class NearestNeighborsOp : public OpKernel {
|
|||||||
for (; start < limit; ++start) {
|
for (; start < limit; ++start) {
|
||||||
const int64 start_row = num_points * start / num_units;
|
const int64 start_row = num_points * start / num_units;
|
||||||
const int64 limit_row = num_points * (start + 1) / num_units;
|
const int64 limit_row = num_points * (start + 1) / num_units;
|
||||||
CHECK_LE(limit_row, num_points);
|
DCHECK_LE(limit_row, num_points);
|
||||||
const int64 num_rows = limit_row - start_row;
|
const int64 num_rows = limit_row - start_row;
|
||||||
auto points_shard = points.middleRows(start_row, num_rows);
|
auto points_shard = points.middleRows(start_row, num_rows);
|
||||||
const Eigen::VectorXf points_half_squared_norm =
|
const Eigen::VectorXf points_half_squared_norm =
|
||||||
@ -430,7 +430,7 @@ class NearestNeighborsOp : public OpKernel {
|
|||||||
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
|
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
|
||||||
const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
|
const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
|
||||||
const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
|
const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
|
||||||
CHECK_LE(k, centers.rows());
|
DCHECK_LE(k, centers.rows());
|
||||||
if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
|
if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
|
||||||
FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
|
FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
|
||||||
centers_half_squared_norm,
|
centers_half_squared_norm,
|
||||||
@ -451,7 +451,7 @@ class NearestNeighborsOp : public OpKernel {
|
|||||||
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
|
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
|
||||||
Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
|
Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
|
||||||
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
|
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
|
||||||
CHECK_LE(k, centers.rows());
|
DCHECK_LE(k, centers.rows());
|
||||||
const int64 num_points = points.rows();
|
const int64 num_points = points.rows();
|
||||||
const MatrixXfRowMajor inner_product = points * centers.transpose();
|
const MatrixXfRowMajor inner_product = points * centers.transpose();
|
||||||
// Find nearest neighbors.
|
// Find nearest neighbors.
|
||||||
@ -500,8 +500,8 @@ class NearestNeighborsOp : public OpKernel {
|
|||||||
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
|
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
|
||||||
const int64 num_points = points.rows();
|
const int64 num_points = points.rows();
|
||||||
const int64 num_centers = centers.rows();
|
const int64 num_centers = centers.rows();
|
||||||
CHECK_LE(k, num_centers);
|
DCHECK_LE(k, num_centers);
|
||||||
CHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
|
DCHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
|
||||||
// Store nearest neighbors with first block of centers directly into the
|
// Store nearest neighbors with first block of centers directly into the
|
||||||
// output matrices.
|
// output matrices.
|
||||||
int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
|
int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
|
43
tensorflow/core/ops/clustering_ops.cc
Normal file
43
tensorflow/core/ops/clustering_ops.cc
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||||
|
// use this file except in compliance with the License. You may obtain a copy
|
||||||
|
// of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations under
|
||||||
|
// the License.
|
||||||
|
// ==============================================================================
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_OP("KmeansPlusPlusInitialization")
|
||||||
|
.Input("points: float32")
|
||||||
|
.Input("num_to_sample: int64")
|
||||||
|
.Input("seed: int64")
|
||||||
|
.Input("num_retries_per_sample: int64")
|
||||||
|
.Output("samples: float32")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
|
REGISTER_OP("KMC2ChainInitialization")
|
||||||
|
.Input("distances: float32")
|
||||||
|
.Input("seed: int64")
|
||||||
|
.Output("index: int64")
|
||||||
|
.SetShapeFn(shape_inference::ScalarShape);
|
||||||
|
|
||||||
|
REGISTER_OP("NearestNeighbors")
|
||||||
|
.Input("points: float32")
|
||||||
|
.Input("centers: float32")
|
||||||
|
.Input("k: int64")
|
||||||
|
.Output("nearest_center_indices: int64")
|
||||||
|
.Output("nearest_center_distances: float32")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -1680,6 +1680,14 @@ tf_gen_op_wrapper_private_py(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_gen_op_wrapper_private_py(
|
||||||
|
name = "clustering_ops_gen",
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:clustering_ops_op_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
name = "collective_ops_gen",
|
name = "collective_ops_gen",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user