Use arg 'axis' instead of deprecated arg 'dim'

This commit is contained in:
Robert Herbig 2019-09-16 22:35:50 -04:00
parent 376e283836
commit bb69697254

View File

@ -216,11 +216,11 @@ class KMeans(object):
output = []
if not inputs_normalized:
with ops.colocate_with(clusters, ignore_existing=True):
clusters = nn_impl.l2_normalize(clusters, dim=1)
clusters = nn_impl.l2_normalize(clusters, axis=1)
for inp in inputs:
with ops.colocate_with(inp, ignore_existing=True):
if not inputs_normalized:
inp = nn_impl.l2_normalize(inp, dim=1)
inp = nn_impl.l2_normalize(inp, axis=1)
output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
return output
@ -251,7 +251,7 @@ class KMeans(object):
# TODO(ands): Support COSINE distance in nearest_neighbors and remove
# this.
with ops.colocate_with(clusters, ignore_existing=True):
clusters = nn_impl.l2_normalize(clusters, dim=1)
clusters = nn_impl.l2_normalize(clusters, axis=1)
for inp, score in zip(inputs, scores):
with ops.colocate_with(inp, ignore_existing=True):
(indices, distances) = gen_clustering_ops.nearest_neighbors(