Fix to remove TF op usage outside of the initializer fn (because the initializer fn is executed deferred, this prevents graph-mismatch issues).
PiperOrigin-RevId: 157620177
parent e8d17ea8c1
commit ee05b8b690
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops
 from tensorflow.contrib.util import loader
 from tensorflow.python.framework import dtypes
@@ -422,8 +424,7 @@ def load_embedding_initializer(ckpt_path,
   # TODO(b/25671353): This should be kept in sync with the stddev used by
   # feature_column.py's _EmbeddingColumn.
   initializer = init_ops.truncated_normal_initializer(
-      stddev=1.0 /
-      math_ops.sqrt(math_ops.cast(embedding_dim, dtypes.float32)))
+      stddev=1.0 / math.sqrt(embedding_dim))
 
   return load_and_remap_matrix_initializer(
       ckpt_path=ckpt_path,
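For context, a rough sketch of the failure mode this commit avoids, written against the TF 1.x graph-mode API of the tensorflow.contrib era; the variable names and shapes are illustrative, not taken from the library. Building a TF op (here the sqrt/cast that computed the stddev) while the caller's graph is active means that op belongs to the caller's graph, but the initializer function runs later, inside whatever graph the variable is actually created in, so the two can mismatch. Computing the stddev as a plain Python float with math.sqrt keeps all op creation inside the deferred initializer call.

import math

import tensorflow as tf

embedding_dim = 16

# Problematic pattern: this creates a Sqrt op in whatever graph is active
# right now (the default graph), before any initializer has been called.
stddev_tensor = tf.sqrt(tf.cast(embedding_dim, tf.float32))

with tf.Graph().as_default():
  # The variable and its deferred initializer live in this new graph, so
  # mixing in the tensor from the other graph fails with a graph mismatch.
  try:
    bad = tf.get_variable(
        "emb_bad", shape=[10, embedding_dim],
        initializer=tf.truncated_normal_initializer(stddev=stddev_tensor))
  except ValueError as err:
    print("graph mismatch:", err)

  # Pattern used by the fix: stddev is a plain Python float, so no op is
  # created until the initializer itself runs in the correct graph.
  ok = tf.get_variable(
      "emb_ok", shape=[10, embedding_dim],
      initializer=tf.truncated_normal_initializer(
          stddev=1.0 / math.sqrt(embedding_dim)))

Keeping everything the returned initializer needs as plain Python values is what makes it safe to invoke from any graph, which is the point of the one-line change to load_embedding_initializer above.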