Making the tf.name_scope blocks related to the factor and weight vars configurable. By default they will not be scoped.
PiperOrigin-RevId: 198759754
This commit is contained in:
parent
fdf4d0813d
commit
519189837b
@ -197,7 +197,8 @@ class WALSModel(object):
|
|||||||
row_weights=1,
|
row_weights=1,
|
||||||
col_weights=1,
|
col_weights=1,
|
||||||
use_factors_weights_cache=True,
|
use_factors_weights_cache=True,
|
||||||
use_gramian_cache=True):
|
use_gramian_cache=True,
|
||||||
|
use_scoped_vars=False):
|
||||||
"""Creates model for WALS matrix factorization.
|
"""Creates model for WALS matrix factorization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -239,6 +240,8 @@ class WALSModel(object):
|
|||||||
weights cache to take effect.
|
weights cache to take effect.
|
||||||
use_gramian_cache: When True, the Gramians will be cached on the workers
|
use_gramian_cache: When True, the Gramians will be cached on the workers
|
||||||
before the updates start. Defaults to True.
|
before the updates start. Defaults to True.
|
||||||
|
use_scoped_vars: When True, the factor and weight vars will also be nested
|
||||||
|
in a tf.name_scope.
|
||||||
"""
|
"""
|
||||||
self._input_rows = input_rows
|
self._input_rows = input_rows
|
||||||
self._input_cols = input_cols
|
self._input_cols = input_cols
|
||||||
@ -251,18 +254,36 @@ class WALSModel(object):
|
|||||||
regularization * linalg_ops.eye(self._n_components)
|
regularization * linalg_ops.eye(self._n_components)
|
||||||
if regularization is not None else None)
|
if regularization is not None else None)
|
||||||
assert (row_weights is None) == (col_weights is None)
|
assert (row_weights is None) == (col_weights is None)
|
||||||
self._row_weights = WALSModel._create_weights(
|
|
||||||
row_weights, self._input_rows, self._num_row_shards, "row_weights")
|
|
||||||
self._col_weights = WALSModel._create_weights(
|
|
||||||
col_weights, self._input_cols, self._num_col_shards, "col_weights")
|
|
||||||
self._use_factors_weights_cache = use_factors_weights_cache
|
self._use_factors_weights_cache = use_factors_weights_cache
|
||||||
self._use_gramian_cache = use_gramian_cache
|
self._use_gramian_cache = use_gramian_cache
|
||||||
self._row_factors = self._create_factors(
|
|
||||||
self._input_rows, self._n_components, self._num_row_shards, row_init,
|
if use_scoped_vars:
|
||||||
"row_factors")
|
with ops.name_scope("row_weights"):
|
||||||
self._col_factors = self._create_factors(
|
self._row_weights = WALSModel._create_weights(
|
||||||
self._input_cols, self._n_components, self._num_col_shards, col_init,
|
row_weights, self._input_rows, self._num_row_shards, "row_weights")
|
||||||
"col_factors")
|
with ops.name_scope("col_weights"):
|
||||||
|
self._col_weights = WALSModel._create_weights(
|
||||||
|
col_weights, self._input_cols, self._num_col_shards, "col_weights")
|
||||||
|
with ops.name_scope("row_factors"):
|
||||||
|
self._row_factors = self._create_factors(
|
||||||
|
self._input_rows, self._n_components, self._num_row_shards,
|
||||||
|
row_init, "row_factors")
|
||||||
|
with ops.name_scope("col_factors"):
|
||||||
|
self._col_factors = self._create_factors(
|
||||||
|
self._input_cols, self._n_components, self._num_col_shards,
|
||||||
|
col_init, "col_factors")
|
||||||
|
else:
|
||||||
|
self._row_weights = WALSModel._create_weights(
|
||||||
|
row_weights, self._input_rows, self._num_row_shards, "row_weights")
|
||||||
|
self._col_weights = WALSModel._create_weights(
|
||||||
|
col_weights, self._input_cols, self._num_col_shards, "col_weights")
|
||||||
|
self._row_factors = self._create_factors(
|
||||||
|
self._input_rows, self._n_components, self._num_row_shards, row_init,
|
||||||
|
"row_factors")
|
||||||
|
self._col_factors = self._create_factors(
|
||||||
|
self._input_cols, self._n_components, self._num_col_shards, col_init,
|
||||||
|
"col_factors")
|
||||||
|
|
||||||
self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
|
self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
|
||||||
self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
|
self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
|
||||||
with ops.name_scope("row_prepare_gramian"):
|
with ops.name_scope("row_prepare_gramian"):
|
||||||
@ -313,37 +334,36 @@ class WALSModel(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _create_factors(cls, rows, cols, num_shards, init, name):
|
def _create_factors(cls, rows, cols, num_shards, init, name):
|
||||||
"""Helper function to create row and column factors."""
|
"""Helper function to create row and column factors."""
|
||||||
with ops.name_scope(name):
|
if callable(init):
|
||||||
if callable(init):
|
init = init()
|
||||||
init = init()
|
if isinstance(init, list):
|
||||||
if isinstance(init, list):
|
assert len(init) == num_shards
|
||||||
assert len(init) == num_shards
|
elif isinstance(init, str) and init == "random":
|
||||||
elif isinstance(init, str) and init == "random":
|
pass
|
||||||
pass
|
elif num_shards == 1:
|
||||||
elif num_shards == 1:
|
init = [init]
|
||||||
init = [init]
|
sharded_matrix = []
|
||||||
sharded_matrix = []
|
sizes = cls._shard_sizes(rows, num_shards)
|
||||||
sizes = cls._shard_sizes(rows, num_shards)
|
assert len(sizes) == num_shards
|
||||||
assert len(sizes) == num_shards
|
|
||||||
|
|
||||||
def make_initializer(i, size):
|
def make_initializer(i, size):
|
||||||
|
|
||||||
def initializer():
|
def initializer():
|
||||||
if init == "random":
|
if init == "random":
|
||||||
return random_ops.random_normal([size, cols])
|
return random_ops.random_normal([size, cols])
|
||||||
else:
|
else:
|
||||||
return init[i]
|
return init[i]
|
||||||
|
|
||||||
return initializer
|
return initializer
|
||||||
|
|
||||||
for i, size in enumerate(sizes):
|
for i, size in enumerate(sizes):
|
||||||
var_name = "%s_shard_%d" % (name, i)
|
var_name = "%s_shard_%d" % (name, i)
|
||||||
var_init = make_initializer(i, size)
|
var_init = make_initializer(i, size)
|
||||||
sharded_matrix.append(
|
sharded_matrix.append(
|
||||||
variable_scope.variable(
|
variable_scope.variable(
|
||||||
var_init, dtype=dtypes.float32, name=var_name))
|
var_init, dtype=dtypes.float32, name=var_name))
|
||||||
|
|
||||||
return sharded_matrix
|
return sharded_matrix
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _create_weights(cls, wt_init, num_wts, num_shards, name):
|
def _create_weights(cls, wt_init, num_wts, num_shards, name):
|
||||||
@ -384,26 +404,25 @@ class WALSModel(object):
|
|||||||
sizes = cls._shard_sizes(num_wts, num_shards)
|
sizes = cls._shard_sizes(num_wts, num_shards)
|
||||||
assert len(sizes) == num_shards
|
assert len(sizes) == num_shards
|
||||||
|
|
||||||
with ops.name_scope(name):
|
def make_wt_initializer(i, size):
|
||||||
def make_wt_initializer(i, size):
|
|
||||||
|
|
||||||
def initializer():
|
def initializer():
|
||||||
if init_mode == "scalar":
|
if init_mode == "scalar":
|
||||||
return wt_init * array_ops.ones([size])
|
return wt_init * array_ops.ones([size])
|
||||||
else:
|
else:
|
||||||
return wt_init[i]
|
return wt_init[i]
|
||||||
|
|
||||||
return initializer
|
return initializer
|
||||||
|
|
||||||
sharded_weight = []
|
sharded_weight = []
|
||||||
for i, size in enumerate(sizes):
|
for i, size in enumerate(sizes):
|
||||||
var_name = "%s_shard_%d" % (name, i)
|
var_name = "%s_shard_%d" % (name, i)
|
||||||
var_init = make_wt_initializer(i, size)
|
var_init = make_wt_initializer(i, size)
|
||||||
sharded_weight.append(
|
sharded_weight.append(
|
||||||
variable_scope.variable(
|
variable_scope.variable(
|
||||||
var_init, dtype=dtypes.float32, name=var_name))
|
var_init, dtype=dtypes.float32, name=var_name))
|
||||||
|
|
||||||
return sharded_weight
|
return sharded_weight
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_gramian(n_components, name):
|
def _create_gramian(n_components, name):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user