Remove potentially large python constant from TensorForest graph.

Change: 136825477
This commit is contained in:
A. Unique TensorFlower 2016-10-21 05:26:29 -08:00 committed by TensorFlower Gardener
parent 00526b6f86
commit 3a5b605d8e

View File

@ -34,7 +34,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@ -164,8 +163,10 @@ class TreeTrainingVariables(object):
name=self.get_tree_name('end_of_tree', tree_num),
dtype=dtypes.int32,
initializer=constant_op.constant([1]))
self.start_epoch = tf_variables.Variable(
[0] * (params.max_nodes), name='start_epoch')
self.start_epoch = variable_scope.get_variable(
name=self.get_tree_name('start_epoch', tree_num),
dtype=dtypes.int32, shape=[params.max_nodes],
initializer=init_ops.constant_initializer(0))
if training:
self.node_to_accumulator_map = variable_scope.get_variable(