Adds dict support of eval metrics.

PiperOrigin-RevId: 168310444
Jianwei Xie 2017-09-11 17:24:01 -07:00 committed by TensorFlower Gardener
parent ab7f22de6a
commit 4b4e10f9c8


@@ -22,6 +22,7 @@ from __future__ import print_function
import collections
import copy
import threading
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -144,14 +145,16 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
TPU evaluation expects a slightly different signature from the
${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor
list. The tensor list specifies the list of tensors, usually model logits,
which are transferred back from the TPU system to the CPU host. All tensors must
be batch-major, i.e., the batch size is the first dimension. Once all tensors
are available at CPU host, they are joined and passed as positional arguments
to the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU
from all shards) and returns a dict from metric string name to the result of
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or a dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from the
TPU system to the CPU host. All tensors must be batch-major, i.e., the batch
size is the first dimension. Once all tensors are available at the CPU host from
all shards, they are concatenated (on CPU) and passed to the `metric_fn` as
positional arguments if `tensors` is a list, or as keyword arguments if
`tensors` is a dict. `metric_fn` takes the `tensors` and returns a dict from
metric string name to the result of calling a metric function, namely a
`(metric_tensor, update_op)` tuple.
See `TPUEstimator` for an MNIST example of how to specify the `eval_metrics`.
"""
@@ -832,6 +835,8 @@ class _EvalMetrics(object):
def __init__(self):
self._metric_fn = None
self._is_dict = False
self._tensor_keys = []
self._tensors = []
self._tensor_dtypes = []
self._tensor_shapes = []
@@ -847,30 +852,60 @@
raise ValueError('eval_metrics should have two elements.')
if not callable(eval_metrics[0]):
raise TypeError('eval_metrics[0] should be callable.')
if not isinstance(eval_metrics[1], (tuple, list)):
raise ValueError('eval_metrics[1] should be tuple or list.')
if not isinstance(eval_metrics[1], (tuple, list, dict)):
raise ValueError('eval_metrics[1] should be a tuple, list, or dict.')
fn_args = util.fn_args(eval_metrics[0])
if len(eval_metrics[1]) != len(fn_args):
raise RuntimeError(
'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
'match method args of metric_fn.')
if isinstance(eval_metrics[1], (tuple, list)):
fn_args = util.fn_args(eval_metrics[0])
if len(eval_metrics[1]) != len(fn_args):
raise RuntimeError(
'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
'match method args of metric_fn.')
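A hedged illustration of the checks above with toy tensors (not part of the diff): the length rule only constrains the list/tuple form; a dict is matched against `metric_fn`'s keyword arguments later, when the function is actually called.

```python
import tensorflow as tf

def metric_fn(labels, logits):  # two positional args
  return {}

labels = tf.constant([1, 0])
logits = tf.constant([[0.1, 0.9], [0.8, 0.2]])

ok_list  = (metric_fn, [labels, logits])                      # 2 tensors, 2 args: passes
ok_dict  = (metric_fn, {'labels': labels, 'logits': logits})  # dict: no length check here
bad_len  = (metric_fn, [labels])                              # 1 tensor vs 2 args: RuntimeError
bad_type = (metric_fn, labels)                                # not tuple/list/dict: ValueError
```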
@staticmethod
def to_metric_metric_ops_for_cpu(eval_metrics):
"""Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU."""
return (eval_metrics[0](*eval_metrics[1]) if eval_metrics is not None
else None)
if not eval_metrics:
return None
_EvalMetrics.validate(eval_metrics)
metric_fn, tensors = eval_metrics
if isinstance(tensors, (tuple, list)):
return metric_fn(*tensors)
else:
# Must be dict.
try:
return metric_fn(**tensors)
except TypeError as e:
logging.warning(
'Exception while calling metric_fn for evaluation: %s. '
'It is likely the tensors (eval_metrics[1]) do not match the '
'metric_fn arguments', e)
raise e
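On CPU the same `eval_metrics` tuple is simply unpacked into a direct call of `metric_fn`: positionally for the list form, by keyword for the dict form, with a key mismatch surfacing as the warned-then-re-raised `TypeError`. A minimal sketch with toy values (illustrative names, not the real helper):

```python
def my_metric_fn(labels, logits):
  return {'dummy_metric': (labels, logits)}

my_metric_fn(*[1, 2])                        # list form: positional unpacking
my_metric_fn(**{'labels': 1, 'logits': 2})   # dict form: keyword unpacking
# my_metric_fn(**{'label': 1})               # wrong key -> TypeError, warned then re-raised
```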
def record(self, spec):
"""Records the eval_metrics structure in `spec`."""
if self._recorded:
raise RuntimeError('Eval metrics have been recorded already.')
self._metric_fn, self._tensors = spec.eval_metrics
self._metric_fn, tensor_list_or_dict = spec.eval_metrics
for tensor in self._tensors:
self._tensor_dtypes.append(tensor.dtype)
self._tensor_shapes.append(tensor.shape)
if isinstance(tensor_list_or_dict, dict):
self._is_dict = True
for (key, tensor) in six.iteritems(tensor_list_or_dict):
self._tensor_keys.append(key)
self._tensors.append(tensor)
self._tensor_dtypes.append(tensor.dtype)
self._tensor_shapes.append(tensor.shape)
else:
# List or tuple.
self._is_dict = False
self._tensors = tensor_list_or_dict
for tensor in tensor_list_or_dict:
self._tensor_dtypes.append(tensor.dtype)
self._tensor_shapes.append(tensor.shape)
self._recorded = True
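A hedged sketch of the bookkeeping `record` performs for the dict form (a simplified stand-in, not the real class): keys, tensors, dtypes, and static shapes are captured side by side so the tensors can be reassembled into a dict on the host later.

```python
import tensorflow as tf

labels = tf.constant([1, 0])
logits = tf.constant([[0.1, 0.9], [0.8, 0.2]])
tensors = {'labels': labels, 'logits': logits}

tensor_keys, tensor_list, tensor_dtypes, tensor_shapes = [], [], [], []
for key, tensor in tensors.items():
  tensor_keys.append(key)              # e.g. 'labels', 'logits'
  tensor_list.append(tensor)
  tensor_dtypes.append(tensor.dtype)   # e.g. tf.int32, tf.float32
  tensor_shapes.append(tensor.shape)   # e.g. (2,), (2, 2)
```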
@property
@@ -928,7 +963,19 @@ class _EvalMetrics(object):
'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
# TODO(xiejw): Allow users to specify the axis for batch size dimension.
dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
eval_metric_ops = self._metric_fn(*dequeue_ops)
if self._is_dict:
dequeue_ops = dict(zip(self._tensor_keys, dequeue_ops))
try:
eval_metric_ops = self._metric_fn(**dequeue_ops)
except TypeError as e:
logging.warning(
'Exception while calling metric_fn for evaluation: %s. '
'It is likely the tensors (eval_metrics[1]) do not match the '
'metric_fn arguments', e)
raise e
else:
eval_metric_ops = self._metric_fn(*dequeue_ops)
eval_update_ops = []
for k, v in eval_metric_ops.items():
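A hedged sketch of the host-side step above (toy tensors, illustrative names): each metric tensor is concatenated across shards along the batch dimension, then dispatched to `metric_fn` by keyword when the dict form was recorded, or positionally otherwise.

```python
import tensorflow as tf

def metric_fn(labels, logits):
  predictions = tf.argmax(logits, axis=1)
  return {'accuracy': tf.metrics.accuracy(labels, predictions)}

# Pretend these tensors were dequeued from two TPU shards.
shard_labels = [tf.constant([1, 0]), tf.constant([1, 1])]
shard_logits = [tf.constant([[0.1, 0.9], [0.8, 0.2]]),
                tf.constant([[0.2, 0.8], [0.3, 0.7]])]

# Concatenate each metric tensor across shards along the batch dimension.
dequeue_ops = [tf.concat(shard_labels, axis=0),
               tf.concat(shard_logits, axis=0)]

# Dict form: pair the recorded keys with the concatenated tensors, call by keyword.
tensor_keys = ['labels', 'logits']
eval_metric_ops = metric_fn(**dict(zip(tensor_keys, dequeue_ops)))
```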
@@ -963,14 +1010,11 @@ class TPUEstimator(estimator_lib.Estimator):
`TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the
`eval_metrics` for TPU evaluation.
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor list.
The tensor list specifies the list of tensors, usually model logits, which are
transferred back from the TPU system to the CPU host. All tensors must be
batch-major, i.e., the batch size is the first dimension. Once all tensors are
available at CPU host, they are joined and passed as positional arguments to
the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU from
all shards) and returns a dict from metric string name to the result of
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
`tensors` could be a list of `Tensor`s or a dict of names to `Tensor`s (see
`TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
a dict from metric string name to the result of calling a metric function,
namely a `(metric_tensor, update_op)` tuple.
Current limitations:
@@ -988,7 +1032,7 @@ class TPUEstimator(estimator_lib.Estimator):
labels=labels, predictions=predictions),
}
# Your model Fn which runs on TPU.
# Your model Fn which runs on TPU (eval_metrics is a list in this example).
def model_fn(features, labels, mode, config, params):
...
logits = ...
@@ -998,6 +1042,20 @@ class TPUEstimator(estimator_lib.Estimator):
mode=mode,
loss=loss,
eval_metrics=(metric_fn, [labels, logits]))
# Or specify the eval_metrics tensors as a dict.
def model_fn(features, labels, mode, config, params):
...
final_layer_output = ...
if mode == tf.estimator.ModeKeys.EVAL:
return tpu_estimator.TPUEstimatorSpec(
mode=mode,
loss=loss,
eval_metrics=(metric_fn, {
'labels': labels,
'logits': final_layer_output,
}))
```
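For the dict example above, a compatible `metric_fn` might look like the following sketch (assuming the keys `'labels'` and `'logits'` shown there; the argument names must match those keys):

```python
def metric_fn(labels, logits):
  predictions = tf.argmax(logits, axis=1)
  return {
      'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions),
  }
```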
Predict support on TPU is not yet implemented. So, `predict` and