Make the tf.name_scope blocks around the factor and weight variables configurable. By default the variables are not scoped.
PiperOrigin-RevId: 198759754
This commit is contained in:
parent
fdf4d0813d
commit
519189837b
@ -197,7 +197,8 @@ class WALSModel(object):
|
||||
row_weights=1,
|
||||
col_weights=1,
|
||||
use_factors_weights_cache=True,
|
||||
use_gramian_cache=True):
|
||||
use_gramian_cache=True,
|
||||
use_scoped_vars=False):
|
||||
"""Creates model for WALS matrix factorization.
|
||||
|
||||
Args:
|
||||
@ -239,6 +240,8 @@ class WALSModel(object):
|
||||
weights cache to take effect.
|
||||
use_gramian_cache: When True, the Gramians will be cached on the workers
|
||||
before the updates start. Defaults to True.
|
||||
use_scoped_vars: When True, the factor and weight vars will also be nested
|
||||
in a tf.name_scope.
|
||||
"""
|
||||
self._input_rows = input_rows
|
||||
self._input_cols = input_cols
|
||||
@ -251,18 +254,36 @@ class WALSModel(object):
|
||||
regularization * linalg_ops.eye(self._n_components)
|
||||
if regularization is not None else None)
|
||||
assert (row_weights is None) == (col_weights is None)
|
||||
self._row_weights = WALSModel._create_weights(
|
||||
row_weights, self._input_rows, self._num_row_shards, "row_weights")
|
||||
self._col_weights = WALSModel._create_weights(
|
||||
col_weights, self._input_cols, self._num_col_shards, "col_weights")
|
||||
self._use_factors_weights_cache = use_factors_weights_cache
|
||||
self._use_gramian_cache = use_gramian_cache
|
||||
self._row_factors = self._create_factors(
|
||||
self._input_rows, self._n_components, self._num_row_shards, row_init,
|
||||
"row_factors")
|
||||
self._col_factors = self._create_factors(
|
||||
self._input_cols, self._n_components, self._num_col_shards, col_init,
|
||||
"col_factors")
|
||||
|
||||
if use_scoped_vars:
|
||||
with ops.name_scope("row_weights"):
|
||||
self._row_weights = WALSModel._create_weights(
|
||||
row_weights, self._input_rows, self._num_row_shards, "row_weights")
|
||||
with ops.name_scope("col_weights"):
|
||||
self._col_weights = WALSModel._create_weights(
|
||||
col_weights, self._input_cols, self._num_col_shards, "col_weights")
|
||||
with ops.name_scope("row_factors"):
|
||||
self._row_factors = self._create_factors(
|
||||
self._input_rows, self._n_components, self._num_row_shards,
|
||||
row_init, "row_factors")
|
||||
with ops.name_scope("col_factors"):
|
||||
self._col_factors = self._create_factors(
|
||||
self._input_cols, self._n_components, self._num_col_shards,
|
||||
col_init, "col_factors")
|
||||
else:
|
||||
self._row_weights = WALSModel._create_weights(
|
||||
row_weights, self._input_rows, self._num_row_shards, "row_weights")
|
||||
self._col_weights = WALSModel._create_weights(
|
||||
col_weights, self._input_cols, self._num_col_shards, "col_weights")
|
||||
self._row_factors = self._create_factors(
|
||||
self._input_rows, self._n_components, self._num_row_shards, row_init,
|
||||
"row_factors")
|
||||
self._col_factors = self._create_factors(
|
||||
self._input_cols, self._n_components, self._num_col_shards, col_init,
|
||||
"col_factors")
|
||||
|
||||
self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
|
||||
self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
|
||||
with ops.name_scope("row_prepare_gramian"):
|
||||
@ -313,37 +334,36 @@ class WALSModel(object):
|
||||
@classmethod
|
||||
def _create_factors(cls, rows, cols, num_shards, init, name):
|
||||
"""Helper function to create row and column factors."""
|
||||
with ops.name_scope(name):
|
||||
if callable(init):
|
||||
init = init()
|
||||
if isinstance(init, list):
|
||||
assert len(init) == num_shards
|
||||
elif isinstance(init, str) and init == "random":
|
||||
pass
|
||||
elif num_shards == 1:
|
||||
init = [init]
|
||||
sharded_matrix = []
|
||||
sizes = cls._shard_sizes(rows, num_shards)
|
||||
assert len(sizes) == num_shards
|
||||
if callable(init):
|
||||
init = init()
|
||||
if isinstance(init, list):
|
||||
assert len(init) == num_shards
|
||||
elif isinstance(init, str) and init == "random":
|
||||
pass
|
||||
elif num_shards == 1:
|
||||
init = [init]
|
||||
sharded_matrix = []
|
||||
sizes = cls._shard_sizes(rows, num_shards)
|
||||
assert len(sizes) == num_shards
|
||||
|
||||
def make_initializer(i, size):
|
||||
def make_initializer(i, size):
|
||||
|
||||
def initializer():
|
||||
if init == "random":
|
||||
return random_ops.random_normal([size, cols])
|
||||
else:
|
||||
return init[i]
|
||||
def initializer():
|
||||
if init == "random":
|
||||
return random_ops.random_normal([size, cols])
|
||||
else:
|
||||
return init[i]
|
||||
|
||||
return initializer
|
||||
return initializer
|
||||
|
||||
for i, size in enumerate(sizes):
|
||||
var_name = "%s_shard_%d" % (name, i)
|
||||
var_init = make_initializer(i, size)
|
||||
sharded_matrix.append(
|
||||
variable_scope.variable(
|
||||
var_init, dtype=dtypes.float32, name=var_name))
|
||||
for i, size in enumerate(sizes):
|
||||
var_name = "%s_shard_%d" % (name, i)
|
||||
var_init = make_initializer(i, size)
|
||||
sharded_matrix.append(
|
||||
variable_scope.variable(
|
||||
var_init, dtype=dtypes.float32, name=var_name))
|
||||
|
||||
return sharded_matrix
|
||||
return sharded_matrix
|
||||
|
||||
@classmethod
|
||||
def _create_weights(cls, wt_init, num_wts, num_shards, name):
|
||||
@ -384,26 +404,25 @@ class WALSModel(object):
|
||||
sizes = cls._shard_sizes(num_wts, num_shards)
|
||||
assert len(sizes) == num_shards
|
||||
|
||||
with ops.name_scope(name):
|
||||
def make_wt_initializer(i, size):
|
||||
def make_wt_initializer(i, size):
|
||||
|
||||
def initializer():
|
||||
if init_mode == "scalar":
|
||||
return wt_init * array_ops.ones([size])
|
||||
else:
|
||||
return wt_init[i]
|
||||
def initializer():
|
||||
if init_mode == "scalar":
|
||||
return wt_init * array_ops.ones([size])
|
||||
else:
|
||||
return wt_init[i]
|
||||
|
||||
return initializer
|
||||
return initializer
|
||||
|
||||
sharded_weight = []
|
||||
for i, size in enumerate(sizes):
|
||||
var_name = "%s_shard_%d" % (name, i)
|
||||
var_init = make_wt_initializer(i, size)
|
||||
sharded_weight.append(
|
||||
variable_scope.variable(
|
||||
var_init, dtype=dtypes.float32, name=var_name))
|
||||
sharded_weight = []
|
||||
for i, size in enumerate(sizes):
|
||||
var_name = "%s_shard_%d" % (name, i)
|
||||
var_init = make_wt_initializer(i, size)
|
||||
sharded_weight.append(
|
||||
variable_scope.variable(
|
||||
var_init, dtype=dtypes.float32, name=var_name))
|
||||
|
||||
return sharded_weight
|
||||
return sharded_weight
|
||||
|
||||
@staticmethod
|
||||
def _create_gramian(n_components, name):
|
||||
|
Loading…
Reference in New Issue
Block a user