diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 931fb5f65e5..b419ca7341e 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -25,7 +25,6 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.ops import array_ops @@ -140,7 +139,7 @@ def make_variable(name, # TODO(apassos,rohanj) figure out how to remove collections from here so we # can remove the V1. - return tf_variables.VariableV1( + v = tf_variables.VariableV1( initial_value=init_val, name=name, trainable=trainable, @@ -151,8 +150,8 @@ def make_variable(name, use_resource=use_resource, collections=collections, synchronization=synchronization, - aggregation=aggregation, - shape=tensor_shape.TensorShape(shape) if shape else None) + aggregation=aggregation) + return v def get_default_graph_uid_map():