Pass shape into variable creation if rank is known
PiperOrigin-RevId: 247957852
This commit is contained in:
parent
2b4a08db39
commit
822404edd2
@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -139,7 +140,8 @@ def make_variable(name,
|
|||||||
|
|
||||||
# TODO(apassos,rohanj) figure out how to remove collections from here so we
|
# TODO(apassos,rohanj) figure out how to remove collections from here so we
|
||||||
# can remove the V1.
|
# can remove the V1.
|
||||||
v = tf_variables.VariableV1(
|
variable_shape = tensor_shape.TensorShape(shape)
|
||||||
|
return tf_variables.VariableV1(
|
||||||
initial_value=init_val,
|
initial_value=init_val,
|
||||||
name=name,
|
name=name,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
@ -150,8 +152,8 @@ def make_variable(name,
|
|||||||
use_resource=use_resource,
|
use_resource=use_resource,
|
||||||
collections=collections,
|
collections=collections,
|
||||||
synchronization=synchronization,
|
synchronization=synchronization,
|
||||||
aggregation=aggregation)
|
aggregation=aggregation,
|
||||||
return v
|
shape=variable_shape if variable_shape.rank else None)
|
||||||
|
|
||||||
|
|
||||||
def get_default_graph_uid_map():
|
def get_default_graph_uid_map():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user