Remove potentially large python constant from TensorForest graph.
Change: 136825477
This commit is contained in:
parent
00526b6f86
commit
3a5b605d8e
@ -34,7 +34,6 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
@ -164,8 +163,10 @@ class TreeTrainingVariables(object):
|
||||
name=self.get_tree_name('end_of_tree', tree_num),
|
||||
dtype=dtypes.int32,
|
||||
initializer=constant_op.constant([1]))
|
||||
self.start_epoch = tf_variables.Variable(
|
||||
[0] * (params.max_nodes), name='start_epoch')
|
||||
self.start_epoch = variable_scope.get_variable(
|
||||
name=self.get_tree_name('start_epoch', tree_num),
|
||||
dtype=dtypes.int32, shape=[params.max_nodes],
|
||||
initializer=init_ops.constant_initializer(0))
|
||||
|
||||
if training:
|
||||
self.node_to_accumulator_map = variable_scope.get_variable(
|
||||
|
Loading…
Reference in New Issue
Block a user