Pass shape into variable creation if rank is known
PiperOrigin-RevId: 247957852
This commit is contained in:
parent
2b4a08db39
commit
822404edd2
@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -139,7 +140,8 @@ def make_variable(name,
|
||||
|
||||
# TODO(apassos,rohanj) figure out how to remove collections from here so we
|
||||
# can remove the V1.
|
||||
v = tf_variables.VariableV1(
|
||||
variable_shape = tensor_shape.TensorShape(shape)
|
||||
return tf_variables.VariableV1(
|
||||
initial_value=init_val,
|
||||
name=name,
|
||||
trainable=trainable,
|
||||
@ -150,8 +152,8 @@ def make_variable(name,
|
||||
use_resource=use_resource,
|
||||
collections=collections,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation)
|
||||
return v
|
||||
aggregation=aggregation,
|
||||
shape=variable_shape if variable_shape.rank else None)
|
||||
|
||||
|
||||
def get_default_graph_uid_map():
|
||||
|
Loading…
x
Reference in New Issue
Block a user