From 0e983f8e0b421b2bab0eaee7c454f0b27a0d9dc2 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Fri, 10 May 2019 23:33:46 -0700 Subject: [PATCH] Pass in the variable shape if already known. PiperOrigin-RevId: 247729180 --- tensorflow/python/keras/engine/base_layer_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index b419ca7341e..931fb5f65e5 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -25,6 +25,7 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.ops import array_ops @@ -139,7 +140,7 @@ def make_variable(name, # TODO(apassos,rohanj) figure out how to remove collections from here so we # can remove the V1. - v = tf_variables.VariableV1( + return tf_variables.VariableV1( initial_value=init_val, name=name, trainable=trainable, @@ -150,8 +151,8 @@ def make_variable(name, use_resource=use_resource, collections=collections, synchronization=synchronization, - aggregation=aggregation) - return v + aggregation=aggregation, + shape=tensor_shape.TensorShape(shape) if shape else None) def get_default_graph_uid_map():