From 62cf561f1d32abf4f5b7fbdee6d106389994ff05 Mon Sep 17 00:00:00 2001
From: Jianwei Xie
Date: Mon, 29 May 2017 10:29:49 -0700
Subject: [PATCH] Add numpy_input_fn integration for LinearRegressor and fix the expand_dim for label and weight.

PiperOrigin-RevId: 157405237
---
 tensorflow/python/estimator/canned/head.py | 19 +++++-
 .../python/estimator/canned/head_test.py   | 60 +++++++++++++++++++
 .../python/estimator/canned/linear_test.py | 40 +++++++++++++
 3 files changed, 117 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 2296e75d90d..22382877865 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -137,6 +137,20 @@ class _Head(object):
     raise NotImplementedError('Calling an abstract method.')
 
 
+def _maybe_expand_dim(tensor):
+  """Expand the dim of `tensor` with static rank 1."""
+  with ops.name_scope(None, 'maybe_expand_dim', (tensor,)):
+    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
+    if isinstance(tensor, sparse_tensor.SparseTensor):
+      raise ValueError('SparseTensor labels are not supported.')
+    static_shape = tensor.shape
+    if static_shape is None:
+      return tensor
+
+    return (array_ops.expand_dims(tensor, -1) if static_shape.ndims == 1
+            else tensor)
+
+
 def _check_labels(labels, expected_labels_dimension):
   """Check labels type and shape."""
   with ops.name_scope(None, 'labels', (labels,)) as scope:
@@ -669,13 +683,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
           export_outputs={'': export_output.RegressionOutput(value=logits)})
 
     # Eval.
-    labels = _check_labels(math_ops.to_float(labels), self._logits_dimension)
+    labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
+                           self._logits_dimension)
     unweighted_loss = losses.mean_squared_error(
         labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
     weights = (
         1. if (self._weight_feature_key is None) else
         features[self._weight_feature_key])
-    weights = math_ops.to_float(weights, name='weights')
+    weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     if mode == model_fn.ModeKeys.EVAL:
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index f076a9c44f9..c1f8b844e92 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -1366,6 +1366,66 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN: 39.0769231,
       }, summary_str)
 
+  def test_with_one_dim_label_and_weight(self):
+    """1d label, 3 examples, 1 batch."""
+    head = head_lib._regression_head_with_mean_squared_error_loss(
+        weight_feature_key='label_weights')
+    self.assertEqual(1, head.logits_dimension)
+
+    # Create estimator spec.
+    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)
+    expected_train_result = b'my_train_op'
+    # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6
+    expected_loss = 101.6
+    def _train_op_fn(loss):
+      with ops.control_dependencies((check_ops.assert_equal(
+          math_ops.to_float(expected_loss), math_ops.to_float(loss),
+          name='assert_loss'),)):
+        return constant_op.constant(expected_train_result)
+
+    x_feature_rank_1 = np.array((42., 43., 44.,), dtype=np.float32)
+    weight_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64)
+    labels_rank_1 = np.array((35., 42., 45.,))
+    self.assertEqual((3,), x_feature_rank_1.shape)
+    self.assertEqual((3,), weight_rank_1.shape)
+    self.assertEqual((3,), labels_rank_1.shape)
+
+    spec = head.create_estimator_spec(
+        features={
+            'x': x_feature_rank_1,
+            'label_weights': weight_rank_1,
+        },
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels_rank_1,
+        train_op_fn=_train_op_fn)
+
+    # Assert spec contains expected tensors.
+    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS
+    self.assertItemsEqual((prediction_key,), spec.predictions.keys())
+    self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
+    self.assertEqual(dtypes.float32, spec.loss.dtype)
+    self.assertEqual({}, spec.eval_metric_ops)
+    self.assertIsNotNone(spec.train_op)
+    self.assertIsNone(spec.export_outputs)
+    _assert_no_hooks(self, spec)
+
+    # Assert predictions, loss, train_op, and summaries.
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNotNone(spec.scaffold.summary_op)
+      predictions, loss, train_result, summary_str = sess.run((
+          spec.predictions[prediction_key], spec.loss, spec.train_op,
+          spec.scaffold.summary_op))
+      self.assertAllClose(logits, predictions)
+      self.assertAllClose(expected_loss, loss)
+      self.assertEqual(expected_train_result, train_result)
+      _assert_simple_summaries(self, {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.0769231
+          metric_keys.MetricKeys.LOSS_MEAN: 39.0769231,
+      }, summary_str)
+
   def test_weighted_multi_value_eval(self):
     """3d label, 1 example, 1 batch."""
     head = head_lib._regression_head_with_mean_squared_error_loss(
diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py
index f356ce9e73b..a5abf367431 100644
--- a/tensorflow/python/estimator/canned/linear_test.py
+++ b/tensorflow/python/estimator/canned/linear_test.py
@@ -574,6 +574,46 @@ class LinearRegressorTrainingTest(test.TestCase):
         input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
     self._assertCheckpoint(num_steps)
 
+  def testTrainWithOneDimLabel(self):
+    label_dimension = 1
+    batch_size = 20
+    feature_columns = [
+        feature_column_lib.numeric_column('age', shape=(1,))
+    ]
+    est = linear.LinearRegressor(
+        feature_columns=feature_columns, label_dimension=label_dimension,
+        model_dir=self._model_dir)
+    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)
+    self.assertEqual((batch_size,), data_rank_1.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1}, y=data_rank_1,
+        batch_size=batch_size, num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assertCheckpoint(200)
+
+  def testTrainWithOneDimWeight(self):
+    label_dimension = 1
+    batch_size = 20
+    feature_columns = [
+        feature_column_lib.numeric_column('age', shape=(1,))
+    ]
+    est = linear.LinearRegressor(
+        feature_columns=feature_columns, label_dimension=label_dimension,
+        weight_feature_key='w',
+        model_dir=self._model_dir)
+
+    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)
+    self.assertEqual((batch_size,), data_rank_1.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1,
+        batch_size=batch_size, num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assertCheckpoint(200)
+
   def testFromScratch(self):
     # Create LinearRegressor.
     label = 5.