Adds dict support of eval metrics.
PiperOrigin-RevId: 168310444
This commit is contained in:
parent
ab7f22de6a
commit
4b4e10f9c8
@ -22,6 +22,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
import copy
|
||||
import threading
|
||||
import six
|
||||
from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
@ -144,14 +145,16 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
|
||||
|
||||
TPU evaluation expects a slightly different signature from the
|
||||
${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
|
||||
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor
|
||||
list. The tensor list specifies the list of tensors, usually model logits,
|
||||
which are transferred back from TPU system to CPU host. All tensors must have
|
||||
be batch-major, i.e., the batch size is the first dimension. Once all tensors
|
||||
are available at CPU host, they are joined and passed as positional arguments
|
||||
to the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU
|
||||
from all shards) and returns a dict from metric string name to the result of
|
||||
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
|
||||
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
|
||||
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
|
||||
`tensors` usually specify the model logits, which are transferred back from
|
||||
TPU system to CPU host. All tensors must be batch-major, i.e., the batch
|
||||
size is the first dimension. Once all tensors are available at CPU host from
|
||||
all shards, they are concatenated (on CPU) and passed as positional arguments
|
||||
to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
|
||||
dict. `metric_fn` takes the `tensors` and returns a dict from metric string
|
||||
name to the result of calling a metric function, namely a `(metric_tensor,
|
||||
update_op)` tuple.
|
||||
|
||||
See `TPUEstimator` for an MNIST example of how to specify the `eval_metrics`.
|
||||
"""
|
||||
@ -832,6 +835,8 @@ class _EvalMetrics(object):
|
||||
|
||||
def __init__(self):
|
||||
self._metric_fn = None
|
||||
self._is_dict = False
|
||||
self._tensor_keys = []
|
||||
self._tensors = []
|
||||
self._tensor_dtypes = []
|
||||
self._tensor_shapes = []
|
||||
@ -847,30 +852,60 @@ class _EvalMetrics(object):
|
||||
raise ValueError('eval_metrics should have two elements.')
|
||||
if not callable(eval_metrics[0]):
|
||||
raise TypeError('eval_metrics[0] should be callable.')
|
||||
if not isinstance(eval_metrics[1], (tuple, list)):
|
||||
raise ValueError('eval_metrics[1] should be tuple or list.')
|
||||
if not isinstance(eval_metrics[1], (tuple, list, dict)):
|
||||
raise ValueError('eval_metrics[1] should be tuple or list, or dict.')
|
||||
|
||||
fn_args = util.fn_args(eval_metrics[0])
|
||||
if len(eval_metrics[1]) != len(fn_args):
|
||||
raise RuntimeError(
|
||||
'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
|
||||
'match method args of metric_fn.')
|
||||
if isinstance(eval_metrics[1], (tuple, list)):
|
||||
fn_args = util.fn_args(eval_metrics[0])
|
||||
if len(eval_metrics[1]) != len(fn_args):
|
||||
raise RuntimeError(
|
||||
'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
|
||||
'match method args of metric_fn.')
|
||||
|
||||
@staticmethod
|
||||
def to_metric_metric_ops_for_cpu(eval_metrics):
|
||||
"""Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU."""
|
||||
return (eval_metrics[0](*eval_metrics[1]) if eval_metrics is not None
|
||||
else None)
|
||||
if not eval_metrics:
|
||||
return None
|
||||
|
||||
_EvalMetrics.validate(eval_metrics)
|
||||
|
||||
metric_fn, tensors = eval_metrics
|
||||
|
||||
if isinstance(tensors, (tuple, list)):
|
||||
return metric_fn(*tensors)
|
||||
else:
|
||||
# Must be dict.
|
||||
try:
|
||||
return metric_fn(**tensors)
|
||||
except TypeError as e:
|
||||
logging.warning(
|
||||
'Exception while calling metric_fn for evalution: %s. '
|
||||
'It is likely the tensors (eval_metrics[1]) do not match the '
|
||||
'metric_fn arguments', e)
|
||||
raise e
|
||||
|
||||
def record(self, spec):
|
||||
"""Records the eval_metrics structure in `spec`."""
|
||||
if self._recorded:
|
||||
raise RuntimeError('Eval metrics have been recorded already.')
|
||||
|
||||
self._metric_fn, self._tensors = spec.eval_metrics
|
||||
self._metric_fn, tensor_list_or_dict = spec.eval_metrics
|
||||
|
||||
for tensor in self._tensors:
|
||||
self._tensor_dtypes.append(tensor.dtype)
|
||||
self._tensor_shapes.append(tensor.shape)
|
||||
if isinstance(tensor_list_or_dict, dict):
|
||||
self._is_dict = True
|
||||
for (key, tensor) in six.iteritems(tensor_list_or_dict):
|
||||
self._tensor_keys.append(key)
|
||||
self._tensors.append(tensor)
|
||||
self._tensor_dtypes.append(tensor.dtype)
|
||||
self._tensor_shapes.append(tensor.shape)
|
||||
else:
|
||||
# List or tuple.
|
||||
self._is_dict = False
|
||||
self._tensors = tensor_list_or_dict
|
||||
for tensor in tensor_list_or_dict:
|
||||
self._tensor_dtypes.append(tensor.dtype)
|
||||
self._tensor_shapes.append(tensor.shape)
|
||||
self._recorded = True
|
||||
|
||||
@property
|
||||
@ -928,7 +963,19 @@ class _EvalMetrics(object):
|
||||
'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
|
||||
# TODO(xiejw): Allow users to specify the axis for batch size dimension.
|
||||
dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
|
||||
eval_metric_ops = self._metric_fn(*dequeue_ops)
|
||||
|
||||
if self._is_dict:
|
||||
dequeue_ops = dict(zip(self._tensor_keys, dequeue_ops))
|
||||
try:
|
||||
eval_metric_ops = self._metric_fn(**dequeue_ops)
|
||||
except TypeError as e:
|
||||
logging.warning(
|
||||
'Exception while calling metric_fn for evalution: %s. '
|
||||
'It is likely the tensors (eval_metrics[1]) do not match the '
|
||||
'metric_fn arguments', e)
|
||||
raise e
|
||||
else:
|
||||
eval_metric_ops = self._metric_fn(*dequeue_ops)
|
||||
|
||||
eval_update_ops = []
|
||||
for k, v in eval_metric_ops.items():
|
||||
@ -963,14 +1010,11 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
`TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the
|
||||
`eval_metrics` for TPU evaluation.
|
||||
|
||||
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor list.
|
||||
The tensor list specifies the list of tensors, usually model logits, which are
|
||||
transferred back from TPU system to CPU host. All tensors must have be
|
||||
batch-major, i.e., the batch size is the first dimension. Once all tensors are
|
||||
available at CPU host, they are joined and passed as positional arguments to
|
||||
the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU from
|
||||
all shards) and returns a dict from metric string name to the result of
|
||||
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
|
||||
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
|
||||
`tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
|
||||
`TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
|
||||
a dict from metric string name to the result of calling a metric function,
|
||||
namely a `(metric_tensor, update_op)` tuple.
|
||||
|
||||
Current limitations:
|
||||
|
||||
@ -988,7 +1032,7 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
labels=labels, predictions=predictions),
|
||||
}
|
||||
|
||||
# Your model Fn which runs on TPU.
|
||||
# Your model Fn which runs on TPU (eval_metrics is list in this example)
|
||||
def model_fn(features, labels, mode, config, params):
|
||||
...
|
||||
logits = ...
|
||||
@ -998,6 +1042,20 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
mode=mode,
|
||||
loss=loss,
|
||||
eval_metrics=(metric_fn, [labels, logits]))
|
||||
|
||||
# or specify the eval_metrics tensors as dict.
|
||||
def model_fn(features, labels, mode, config, params):
|
||||
...
|
||||
final_layer_output = ...
|
||||
|
||||
if mode == tf.estimator.ModeKeys.EVAL:
|
||||
return tpu_estimator.TPUEstimatorSpec(
|
||||
mode=mode,
|
||||
loss=loss,
|
||||
eval_metrics=(metric_fn, {
|
||||
'labels': labels,
|
||||
'logits': final_layer_output,
|
||||
}))
|
||||
```
|
||||
|
||||
Predict support on TPU is not yet implemented. So, `predict` and
|
||||
|
Loading…
Reference in New Issue
Block a user