From f8f5dca1cf098388363eaa4bd0c6b39ac441b994 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Fri, 1 Mar 2019 13:34:22 -0800 Subject: [PATCH] Fallback to RefVariable for Kmeans. PiperOrigin-RevId: 236368697 --- tensorflow/python/ops/clustering_ops.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/ops/clustering_ops.py b/tensorflow/python/ops/clustering_ops.py index 1ca375c314d..2d0edad378d 100644 --- a/tensorflow/python/ops/clustering_ops.py +++ b/tensorflow/python/ops/clustering_ops.py @@ -288,29 +288,34 @@ class KMeans(object): """ init_value = array_ops.constant([], dtype=dtypes.float32) cluster_centers = variable_scope.variable( - init_value, name=CLUSTERS_VAR_NAME, validate_shape=False) + init_value, name=CLUSTERS_VAR_NAME, validate_shape=False, + use_resource=False) cluster_centers_initialized = variable_scope.variable( - False, dtype=dtypes.bool, name='initialized') + False, dtype=dtypes.bool, name='initialized', use_resource=False) if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: # Copy of cluster centers actively updated each step according to # mini-batch update rule. cluster_centers_updated = variable_scope.variable( - init_value, name='clusters_updated', validate_shape=False) + init_value, name='clusters_updated', validate_shape=False, + use_resource=False) # How many steps till we copy the updated clusters to cluster_centers. update_in_steps = variable_scope.variable( self._mini_batch_steps_per_iteration, dtype=dtypes.int64, - name='update_in_steps') + name='update_in_steps', + use_resource=False) # Count of points assigned to cluster_centers_updated. cluster_counts = variable_scope.variable( - array_ops.zeros([num_clusters], dtype=dtypes.int64)) + array_ops.zeros([num_clusters], dtype=dtypes.int64), + use_resource=False) else: cluster_centers_updated = cluster_centers update_in_steps = None cluster_counts = ( - variable_scope.variable( - array_ops.ones([num_clusters], dtype=dtypes.int64)) + variable_scope.variable( # pylint:disable=g-long-ternary + array_ops.ones([num_clusters], dtype=dtypes.int64), + use_resource=False) if self._use_mini_batch else None) return (cluster_centers, cluster_centers_initialized, cluster_counts, cluster_centers_updated, update_in_steps)