From e3ac98f25b9ff412453d2b7311e1bae58becf1da Mon Sep 17 00:00:00 2001 From: David Soergel Date: Wed, 1 Feb 2017 11:16:28 -0800 Subject: [PATCH] Fix another bug in computing shape for SDCA fake bias column Change: 146266321 --- .../learn/python/learn/estimators/linear.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 3505723ebed..1113f436478 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -78,12 +78,21 @@ def _add_bias_column(feature_columns, columns_to_tensors, bias_variable, if not feature_columns: raise ValueError("feature_columns can't be empty.") - # Using an arbitrary input tensor to figure out batch_size. - some_input = next(iter(columns_to_tensors.values())) - if isinstance(some_input, sparse_tensor.SparseTensor): - batch_size = tensor_util.constant_value(some_input.dense_shape)[0] - else: - batch_size = array_ops.shape(some_input)[0] + # Loop through input tensors until we can figure out batch_size. + batch_size = None + for column in columns_to_tensors.values(): + if isinstance(column, tuple): + column = column[0] + if isinstance(column, sparse_tensor.SparseTensor): + shape = tensor_util.constant_value(column.dense_shape) + if shape is not None: + batch_size = shape[0] + break + else: + batch_size = array_ops.shape(column)[0] + break + if batch_size is None: + raise ValueError("Could not infer batch size from input features.") bias_column = layers.real_valued_column(bias_column_name) columns_to_tensors[bias_column] = array_ops.ones([batch_size, 1],