Adds dict support of eval metrics.

PiperOrigin-RevId: 168310444
This commit is contained in:
Jianwei Xie 2017-09-11 17:24:01 -07:00 committed by TensorFlower Gardener
parent ab7f22de6a
commit 4b4e10f9c8

View File

@ -22,6 +22,7 @@ from __future__ import print_function
import collections import collections
import copy import copy
import threading import threading
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.ops import tpu_ops
@ -144,14 +145,16 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
TPU evaluation expects a slightly different signature from the TPU evaluation expects a slightly different signature from the
${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a ${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
list. The tensor list specifies the list of tensors, usually model logits, The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
which are transferred back from TPU system to CPU host. All tensors must have `tensors` usually specify the model logits, which are transferred back from
be batch-major, i.e., the batch size is the first dimension. Once all tensors TPU system to CPU host. All tensors must be batch-major, i.e., the batch
are available at CPU host, they are joined and passed as positional arguments size is the first dimension. Once all tensors are available at CPU host from
to the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU all shards, they are concatenated (on CPU) and passed as positional arguments
from all shards) and returns a dict from metric string name to the result of to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
calling a metric function, namely a `(metric_tensor, update_op)` tuple. dict. `metric_fn` takes the `tensors` and returns a dict from metric string
name to the result of calling a metric function, namely a `(metric_tensor,
update_op)` tuple.
See `TPUEstimator` for an MNIST example of how to specify the `eval_metrics`. See `TPUEstimator` for an MNIST example of how to specify the `eval_metrics`.
""" """
@ -832,6 +835,8 @@ class _EvalMetrics(object):
def __init__(self): def __init__(self):
self._metric_fn = None self._metric_fn = None
self._is_dict = False
self._tensor_keys = []
self._tensors = [] self._tensors = []
self._tensor_dtypes = [] self._tensor_dtypes = []
self._tensor_shapes = [] self._tensor_shapes = []
@ -847,30 +852,60 @@ class _EvalMetrics(object):
raise ValueError('eval_metrics should have two elements.') raise ValueError('eval_metrics should have two elements.')
if not callable(eval_metrics[0]): if not callable(eval_metrics[0]):
raise TypeError('eval_metrics[0] should be callable.') raise TypeError('eval_metrics[0] should be callable.')
if not isinstance(eval_metrics[1], (tuple, list)): if not isinstance(eval_metrics[1], (tuple, list, dict)):
raise ValueError('eval_metrics[1] should be tuple or list.') raise ValueError('eval_metrics[1] should be tuple or list, or dict.')
fn_args = util.fn_args(eval_metrics[0]) if isinstance(eval_metrics[1], (tuple, list)):
if len(eval_metrics[1]) != len(fn_args): fn_args = util.fn_args(eval_metrics[0])
raise RuntimeError( if len(eval_metrics[1]) != len(fn_args):
'In TPUEstimatorSpec.eval_metrics, length of tensors does not ' raise RuntimeError(
'match method args of metric_fn.') 'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
'match method args of metric_fn.')
@staticmethod @staticmethod
def to_metric_metric_ops_for_cpu(eval_metrics): def to_metric_metric_ops_for_cpu(eval_metrics):
"""Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU.""" """Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU."""
return (eval_metrics[0](*eval_metrics[1]) if eval_metrics is not None if not eval_metrics:
else None) return None
_EvalMetrics.validate(eval_metrics)
metric_fn, tensors = eval_metrics
if isinstance(tensors, (tuple, list)):
return metric_fn(*tensors)
else:
# Must be dict.
try:
return metric_fn(**tensors)
except TypeError as e:
logging.warning(
'Exception while calling metric_fn for evaluation: %s. '
'It is likely the tensors (eval_metrics[1]) do not match the '
'metric_fn arguments', e)
raise e
def record(self, spec): def record(self, spec):
"""Records the eval_metrics structure in `spec`."""
if self._recorded: if self._recorded:
raise RuntimeError('Eval metrics have been recorded already.') raise RuntimeError('Eval metrics have been recorded already.')
self._metric_fn, self._tensors = spec.eval_metrics self._metric_fn, tensor_list_or_dict = spec.eval_metrics
for tensor in self._tensors: if isinstance(tensor_list_or_dict, dict):
self._tensor_dtypes.append(tensor.dtype) self._is_dict = True
self._tensor_shapes.append(tensor.shape) for (key, tensor) in six.iteritems(tensor_list_or_dict):
self._tensor_keys.append(key)
self._tensors.append(tensor)
self._tensor_dtypes.append(tensor.dtype)
self._tensor_shapes.append(tensor.shape)
else:
# List or tuple.
self._is_dict = False
self._tensors = tensor_list_or_dict
for tensor in tensor_list_or_dict:
self._tensor_dtypes.append(tensor.dtype)
self._tensor_shapes.append(tensor.shape)
self._recorded = True self._recorded = True
@property @property
@ -928,7 +963,19 @@ class _EvalMetrics(object):
'dimension, but got scalar {}'.format(dequeue_ops[i][0])) 'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
# TODO(xiejw): Allow users to specify the axis for batch size dimension. # TODO(xiejw): Allow users to specify the axis for batch size dimension.
dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
eval_metric_ops = self._metric_fn(*dequeue_ops)
if self._is_dict:
dequeue_ops = dict(zip(self._tensor_keys, dequeue_ops))
try:
eval_metric_ops = self._metric_fn(**dequeue_ops)
except TypeError as e:
logging.warning(
'Exception while calling metric_fn for evaluation: %s. '
'It is likely the tensors (eval_metrics[1]) do not match the '
'metric_fn arguments', e)
raise e
else:
eval_metric_ops = self._metric_fn(*dequeue_ops)
eval_update_ops = [] eval_update_ops = []
for k, v in eval_metric_ops.items(): for k, v in eval_metric_ops.items():
@ -963,14 +1010,11 @@ class TPUEstimator(estimator_lib.Estimator):
`TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the `TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the
`eval_metrics` for TPU evaluation. `eval_metrics` for TPU evaluation.
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor list. `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
The tensor list specifies the list of tensors, usually model logits, which are `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
transferred back from TPU system to CPU host. All tensors must have be `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
batch-major, i.e., the batch size is the first dimension. Once all tensors are a dict from metric string name to the result of calling a metric function,
available at CPU host, they are joined and passed as positional arguments to namely a `(metric_tensor, update_op)` tuple.
the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU from
all shards) and returns a dict from metric string name to the result of
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
Current limitations: Current limitations:
@ -988,7 +1032,7 @@ class TPUEstimator(estimator_lib.Estimator):
labels=labels, predictions=predictions), labels=labels, predictions=predictions),
} }
# Your model Fn which runs on TPU. # Your model Fn which runs on TPU (eval_metrics is list in this example)
def model_fn(features, labels, mode, config, params): def model_fn(features, labels, mode, config, params):
... ...
logits = ... logits = ...
@ -998,6 +1042,20 @@ class TPUEstimator(estimator_lib.Estimator):
mode=mode, mode=mode,
loss=loss, loss=loss,
eval_metrics=(metric_fn, [labels, logits])) eval_metrics=(metric_fn, [labels, logits]))
# or specify the eval_metrics tensors as dict.
def model_fn(features, labels, mode, config, params):
...
final_layer_output = ...
if mode == tf.estimator.ModeKeys.EVAL:
return tpu_estimator.TPUEstimatorSpec(
mode=mode,
loss=loss,
eval_metrics=(metric_fn, {
'labels': labels,
'logits': final_layer_output,
}))
``` ```
Predict support on TPU is not yet implemented. So, `predict` and Predict support on TPU is not yet implemented. So, `predict` and