Pass in the variable shape if already known.

PiperOrigin-RevId: 247729180
This commit is contained in:
Gaurav Jain 2019-05-10 23:33:46 -07:00 committed by TensorFlower Gardener
parent 1673f77327
commit 0e983f8e0b

View File

@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
@ -139,7 +140,7 @@ def make_variable(name,
# TODO(apassos,rohanj) figure out how to remove collections from here so we
# can remove the V1.
v = tf_variables.VariableV1(
return tf_variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
@ -150,8 +151,8 @@ def make_variable(name,
use_resource=use_resource,
collections=collections,
synchronization=synchronization,
aggregation=aggregation)
return v
aggregation=aggregation,
shape=tensor_shape.TensorShape(shape) if shape else None)
def get_default_graph_uid_map():