diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py index efeadd3c3c9..fdb834f46b6 100644 --- a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py +++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes @@ -422,8 +424,7 @@ def load_embedding_initializer(ckpt_path, # TODO(b/25671353): This should be kept in sync with the stddev used by # feature_column.py's _EmbeddingColumn. initializer = init_ops.truncated_normal_initializer( - stddev=1.0 / - math_ops.sqrt(math_ops.cast(embedding_dim, dtypes.float32))) + stddev=1.0 / math.sqrt(embedding_dim)) return load_and_remap_matrix_initializer( ckpt_path=ckpt_path,