Making the tf.name_scope blocks related to the factor and weight vars configurable. By default they will not be scoped.

PiperOrigin-RevId: 198759754
Authored by A. Unique TensorFlower on 2018-05-31 12:16:54 -07:00; committed by TensorFlower Gardener
parent fdf4d0813d
commit 519189837b
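
For context, a minimal usage sketch of the new flag (assuming TF 1.x with tf.contrib available; the argument values below are made up, only use_scoped_vars comes from this change):

import tensorflow as tf

# With use_scoped_vars=True the factor and weight shard variables are created
# inside tf.name_scope blocks ("row_weights", "col_factors", ...); with the
# default False they keep their previous, unscoped names.
model = tf.contrib.factorization.WALSModel(
    input_rows=1000,
    input_cols=500,
    n_components=10,
    use_scoped_vars=True)  # new argument added by this change

# Inspect the shard variable names; the "row_factors/..." style prefix only
# appears when use_scoped_vars=True.
for var in model.row_factors + model.col_factors:
  print(var.op.name)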


@@ -197,7 +197,8 @@ class WALSModel(object):
                row_weights=1,
                col_weights=1,
                use_factors_weights_cache=True,
-               use_gramian_cache=True):
+               use_gramian_cache=True,
+               use_scoped_vars=False):
     """Creates model for WALS matrix factorization.

     Args:
@@ -239,6 +240,8 @@ class WALSModel(object):
         weights cache to take effect.
       use_gramian_cache: When True, the Gramians will be cached on the workers
         before the updates start. Defaults to True.
+      use_scoped_vars: When True, the factor and weight vars will also be nested
+        in a tf.name_scope.
     """
     self._input_rows = input_rows
     self._input_cols = input_cols
@@ -251,18 +254,36 @@ class WALSModel(object):
         regularization * linalg_ops.eye(self._n_components)
         if regularization is not None else None)
     assert (row_weights is None) == (col_weights is None)
-    self._row_weights = WALSModel._create_weights(
-        row_weights, self._input_rows, self._num_row_shards, "row_weights")
-    self._col_weights = WALSModel._create_weights(
-        col_weights, self._input_cols, self._num_col_shards, "col_weights")
     self._use_factors_weights_cache = use_factors_weights_cache
     self._use_gramian_cache = use_gramian_cache
-    self._row_factors = self._create_factors(
-        self._input_rows, self._n_components, self._num_row_shards, row_init,
-        "row_factors")
-    self._col_factors = self._create_factors(
-        self._input_cols, self._n_components, self._num_col_shards, col_init,
-        "col_factors")
+
+    if use_scoped_vars:
+      with ops.name_scope("row_weights"):
+        self._row_weights = WALSModel._create_weights(
+            row_weights, self._input_rows, self._num_row_shards, "row_weights")
+      with ops.name_scope("col_weights"):
+        self._col_weights = WALSModel._create_weights(
+            col_weights, self._input_cols, self._num_col_shards, "col_weights")
+      with ops.name_scope("row_factors"):
+        self._row_factors = self._create_factors(
+            self._input_rows, self._n_components, self._num_row_shards,
+            row_init, "row_factors")
+      with ops.name_scope("col_factors"):
+        self._col_factors = self._create_factors(
+            self._input_cols, self._n_components, self._num_col_shards,
+            col_init, "col_factors")
+    else:
+      self._row_weights = WALSModel._create_weights(
+          row_weights, self._input_rows, self._num_row_shards, "row_weights")
+      self._col_weights = WALSModel._create_weights(
+          col_weights, self._input_cols, self._num_col_shards, "col_weights")
+      self._row_factors = self._create_factors(
+          self._input_rows, self._n_components, self._num_row_shards, row_init,
+          "row_factors")
+      self._col_factors = self._create_factors(
+          self._input_cols, self._n_components, self._num_col_shards, col_init,
+          "col_factors")
+
     self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
     self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
     with ops.name_scope("row_prepare_gramian"):
@@ -313,37 +334,36 @@ class WALSModel(object):
   @classmethod
   def _create_factors(cls, rows, cols, num_shards, init, name):
     """Helper function to create row and column factors."""
-    with ops.name_scope(name):
-      if callable(init):
-        init = init()
-      if isinstance(init, list):
-        assert len(init) == num_shards
-      elif isinstance(init, str) and init == "random":
-        pass
-      elif num_shards == 1:
-        init = [init]
-      sharded_matrix = []
-      sizes = cls._shard_sizes(rows, num_shards)
-      assert len(sizes) == num_shards
+    if callable(init):
+      init = init()
+    if isinstance(init, list):
+      assert len(init) == num_shards
+    elif isinstance(init, str) and init == "random":
+      pass
+    elif num_shards == 1:
+      init = [init]
+    sharded_matrix = []
+    sizes = cls._shard_sizes(rows, num_shards)
+    assert len(sizes) == num_shards

-      def make_initializer(i, size):
+    def make_initializer(i, size):

-        def initializer():
-          if init == "random":
-            return random_ops.random_normal([size, cols])
-          else:
-            return init[i]
+      def initializer():
+        if init == "random":
+          return random_ops.random_normal([size, cols])
+        else:
+          return init[i]

-        return initializer
+      return initializer

-      for i, size in enumerate(sizes):
-        var_name = "%s_shard_%d" % (name, i)
-        var_init = make_initializer(i, size)
-        sharded_matrix.append(
-            variable_scope.variable(
-                var_init, dtype=dtypes.float32, name=var_name))
+    for i, size in enumerate(sizes):
+      var_name = "%s_shard_%d" % (name, i)
+      var_init = make_initializer(i, size)
+      sharded_matrix.append(
+          variable_scope.variable(
+              var_init, dtype=dtypes.float32, name=var_name))

-      return sharded_matrix
+    return sharded_matrix

   @classmethod
   def _create_weights(cls, wt_init, num_wts, num_shards, name):
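
The make_initializer helper above exists so each shard's initializer captures its own index and size at definition time; a small illustration of the same pattern outside TensorFlow (all names here are illustrative, not part of the library):

import numpy as np

def make_initializer(i, size, cols, init):
  # Returning a dedicated closure binds the current i and size. A lambda
  # written directly in the loop below would capture the loop variables by
  # reference, so every shard would end up using the final i and size.
  def initializer():
    if init == "random":
      return np.random.normal(size=(size, cols))
    return init[i]
  return initializer

sizes = [3, 3, 4]  # e.g. what _shard_sizes(10, 3) might return
shards = [make_initializer(i, size, cols=2, init="random")()
          for i, size in enumerate(sizes)]
print([s.shape for s in shards])  # [(3, 2), (3, 2), (4, 2)]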
@@ -384,26 +404,25 @@ class WALSModel(object):
     sizes = cls._shard_sizes(num_wts, num_shards)
     assert len(sizes) == num_shards

-    with ops.name_scope(name):
-      def make_wt_initializer(i, size):
+    def make_wt_initializer(i, size):

-        def initializer():
-          if init_mode == "scalar":
-            return wt_init * array_ops.ones([size])
-          else:
-            return wt_init[i]
+      def initializer():
+        if init_mode == "scalar":
+          return wt_init * array_ops.ones([size])
+        else:
+          return wt_init[i]

-        return initializer
+      return initializer

-      sharded_weight = []
-      for i, size in enumerate(sizes):
-        var_name = "%s_shard_%d" % (name, i)
-        var_init = make_wt_initializer(i, size)
-        sharded_weight.append(
-            variable_scope.variable(
-                var_init, dtype=dtypes.float32, name=var_name))
+    sharded_weight = []
+    for i, size in enumerate(sizes):
+      var_name = "%s_shard_%d" % (name, i)
+      var_init = make_wt_initializer(i, size)
+      sharded_weight.append(
+          variable_scope.variable(
+              var_init, dtype=dtypes.float32, name=var_name))

-      return sharded_weight
+    return sharded_weight

   @staticmethod
   def _create_gramian(n_components, name):
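
_create_weights follows the same structure but supports two initialization modes, a scalar broadcast to every shard or an explicit per-shard list; a hedged TF 1.x sketch of those two modes (the helper name and shard sizes below are made up for illustration):

import tensorflow as tf

def weight_shards(wt_init, sizes, name="weights"):
  # Mirrors the two modes in _create_weights: a scalar becomes
  # wt_init * ones([size]) for each shard, a list supplies one value per shard.
  init_mode = "list" if isinstance(wt_init, list) else "scalar"
  shards = []
  for i, size in enumerate(sizes):
    if init_mode == "scalar":
      value = wt_init * tf.ones([size])
    else:
      value = wt_init[i]
    shards.append(tf.Variable(value, name="%s_shard_%d" % (name, i)))
  return shards

# Scalar mode: every row weight is 1.0.
scalar_shards = weight_shards(1.0, sizes=[3, 2])
# List mode: explicit per-shard weight vectors.
list_shards = weight_shards([tf.constant([0.5, 2.0, 1.0]),
                             tf.constant([1.0, 1.0])], sizes=[3, 2])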