diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 0578ae8b0f4..6748a765623 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -68,11 +68,12 @@ def _create_global_step(graph): return variable_scope.get_variable( ops.GraphKeys.GLOBAL_STEP, shape=[], - dtype=dtypes.int64, + dtype=dtypes.int32, initializer=init_ops.zeros_initializer(), trainable=False, use_resource=True, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) + collections=[ops.GraphKeys.GLOBAL_VARIABLES, + ops.GraphKeys.GLOBAL_STEP]) def _sync_variables_ops():