Pass shape into variable creation if rank is known

PiperOrigin-RevId: 247957852
This commit is contained in:
Gaurav Jain 2019-05-13 10:04:04 -07:00 committed by TensorFlower Gardener
parent 2b4a08db39
commit 822404edd2

View File

@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
@ -139,7 +140,8 @@ def make_variable(name,
# TODO(apassos,rohanj) figure out how to remove collections from here so we
# can remove the V1.
v = tf_variables.VariableV1(
variable_shape = tensor_shape.TensorShape(shape)
return tf_variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
@ -150,8 +152,8 @@ def make_variable(name,
use_resource=use_resource,
collections=collections,
synchronization=synchronization,
aggregation=aggregation)
return v
aggregation=aggregation,
shape=variable_shape if variable_shape.rank else None)
def get_default_graph_uid_map():