Adds dict support of eval metrics.
PiperOrigin-RevId: 168310444
This commit is contained in:
parent
ab7f22de6a
commit
4b4e10f9c8
@ -22,6 +22,7 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
import threading
|
import threading
|
||||||
|
import six
|
||||||
from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||||
@ -144,14 +145,16 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
|
|||||||
|
|
||||||
TPU evaluation expects a slightly different signature from the
|
TPU evaluation expects a slightly different signature from the
|
||||||
${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
|
${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
|
||||||
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor
|
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
|
||||||
list. The tensor list specifies the list of tensors, usually model logits,
|
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
|
||||||
which are transferred back from TPU system to CPU host. All tensors must have
|
`tensors` usually specify the model logits, which are transferred back from
|
||||||
be batch-major, i.e., the batch size is the first dimension. Once all tensors
|
TPU system to CPU host. All tensors must have be batch-major, i.e., the batch
|
||||||
are available at CPU host, they are joined and passed as positional arguments
|
size is the first dimension. Once all tensors are available at CPU host from
|
||||||
to the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU
|
all shards, they are concatenated (on CPU) and passed as positional arguments
|
||||||
from all shards) and returns a dict from metric string name to the result of
|
to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
|
||||||
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
|
dict. `metric_fn` takes the `tensors` and returns a dict from metric string
|
||||||
|
name to the result of calling a metric function, namely a `(metric_tensor,
|
||||||
|
update_op)` tuple.
|
||||||
|
|
||||||
See `TPUEstimator` for MNIST example how to specify the `eval_metrics`.
|
See `TPUEstimator` for MNIST example how to specify the `eval_metrics`.
|
||||||
"""
|
"""
|
||||||
@ -832,6 +835,8 @@ class _EvalMetrics(object):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._metric_fn = None
|
self._metric_fn = None
|
||||||
|
self._is_dict = False
|
||||||
|
self._tensor_keys = []
|
||||||
self._tensors = []
|
self._tensors = []
|
||||||
self._tensor_dtypes = []
|
self._tensor_dtypes = []
|
||||||
self._tensor_shapes = []
|
self._tensor_shapes = []
|
||||||
@ -847,30 +852,60 @@ class _EvalMetrics(object):
|
|||||||
raise ValueError('eval_metrics should have two elements.')
|
raise ValueError('eval_metrics should have two elements.')
|
||||||
if not callable(eval_metrics[0]):
|
if not callable(eval_metrics[0]):
|
||||||
raise TypeError('eval_metrics[0] should be callable.')
|
raise TypeError('eval_metrics[0] should be callable.')
|
||||||
if not isinstance(eval_metrics[1], (tuple, list)):
|
if not isinstance(eval_metrics[1], (tuple, list, dict)):
|
||||||
raise ValueError('eval_metrics[1] should be tuple or list.')
|
raise ValueError('eval_metrics[1] should be tuple or list, or dict.')
|
||||||
|
|
||||||
fn_args = util.fn_args(eval_metrics[0])
|
if isinstance(eval_metrics[1], (tuple, list)):
|
||||||
if len(eval_metrics[1]) != len(fn_args):
|
fn_args = util.fn_args(eval_metrics[0])
|
||||||
raise RuntimeError(
|
if len(eval_metrics[1]) != len(fn_args):
|
||||||
'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
|
raise RuntimeError(
|
||||||
'match method args of metric_fn.')
|
'In TPUEstimatorSpec.eval_metrics, length of tensors does not '
|
||||||
|
'match method args of metric_fn.')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_metric_metric_ops_for_cpu(eval_metrics):
|
def to_metric_metric_ops_for_cpu(eval_metrics):
|
||||||
"""Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU."""
|
"""Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU."""
|
||||||
return (eval_metrics[0](*eval_metrics[1]) if eval_metrics is not None
|
if not eval_metrics:
|
||||||
else None)
|
return None
|
||||||
|
|
||||||
|
_EvalMetrics.validate(eval_metrics)
|
||||||
|
|
||||||
|
metric_fn, tensors = eval_metrics
|
||||||
|
|
||||||
|
if isinstance(tensors, (tuple, list)):
|
||||||
|
return metric_fn(*tensors)
|
||||||
|
else:
|
||||||
|
# Must be dict.
|
||||||
|
try:
|
||||||
|
return metric_fn(**tensors)
|
||||||
|
except TypeError as e:
|
||||||
|
logging.warning(
|
||||||
|
'Exception while calling metric_fn for evalution: %s. '
|
||||||
|
'It is likely the tensors (eval_metrics[1]) do not match the '
|
||||||
|
'metric_fn arguments', e)
|
||||||
|
raise e
|
||||||
|
|
||||||
def record(self, spec):
|
def record(self, spec):
|
||||||
|
"""Records the eval_metrics structure in `spec`."""
|
||||||
if self._recorded:
|
if self._recorded:
|
||||||
raise RuntimeError('Eval metrics have been recorded already.')
|
raise RuntimeError('Eval metrics have been recorded already.')
|
||||||
|
|
||||||
self._metric_fn, self._tensors = spec.eval_metrics
|
self._metric_fn, tensor_list_or_dict = spec.eval_metrics
|
||||||
|
|
||||||
for tensor in self._tensors:
|
if isinstance(tensor_list_or_dict, dict):
|
||||||
self._tensor_dtypes.append(tensor.dtype)
|
self._is_dict = True
|
||||||
self._tensor_shapes.append(tensor.shape)
|
for (key, tensor) in six.iteritems(tensor_list_or_dict):
|
||||||
|
self._tensor_keys.append(key)
|
||||||
|
self._tensors.append(tensor)
|
||||||
|
self._tensor_dtypes.append(tensor.dtype)
|
||||||
|
self._tensor_shapes.append(tensor.shape)
|
||||||
|
else:
|
||||||
|
# List or tuple.
|
||||||
|
self._is_dict = False
|
||||||
|
self._tensors = tensor_list_or_dict
|
||||||
|
for tensor in tensor_list_or_dict:
|
||||||
|
self._tensor_dtypes.append(tensor.dtype)
|
||||||
|
self._tensor_shapes.append(tensor.shape)
|
||||||
self._recorded = True
|
self._recorded = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -928,7 +963,19 @@ class _EvalMetrics(object):
|
|||||||
'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
|
'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
|
||||||
# TODO(xiejw): Allow users to specify the axis for batch size dimension.
|
# TODO(xiejw): Allow users to specify the axis for batch size dimension.
|
||||||
dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
|
dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
|
||||||
eval_metric_ops = self._metric_fn(*dequeue_ops)
|
|
||||||
|
if self._is_dict:
|
||||||
|
dequeue_ops = dict(zip(self._tensor_keys, dequeue_ops))
|
||||||
|
try:
|
||||||
|
eval_metric_ops = self._metric_fn(**dequeue_ops)
|
||||||
|
except TypeError as e:
|
||||||
|
logging.warning(
|
||||||
|
'Exception while calling metric_fn for evalution: %s. '
|
||||||
|
'It is likely the tensors (eval_metrics[1]) do not match the '
|
||||||
|
'metric_fn arguments', e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
eval_metric_ops = self._metric_fn(*dequeue_ops)
|
||||||
|
|
||||||
eval_update_ops = []
|
eval_update_ops = []
|
||||||
for k, v in eval_metric_ops.items():
|
for k, v in eval_metric_ops.items():
|
||||||
@ -963,14 +1010,11 @@ class TPUEstimator(estimator_lib.Estimator):
|
|||||||
`TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the
|
`TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the
|
||||||
`eval_metrics` for TPU evaluation.
|
`eval_metrics` for TPU evaluation.
|
||||||
|
|
||||||
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and a tensor list.
|
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
|
||||||
The tensor list specifies the list of tensors, usually model logits, which are
|
`tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
|
||||||
transferred back from TPU system to CPU host. All tensors must have be
|
`TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
|
||||||
batch-major, i.e., the batch size is the first dimension. Once all tensors are
|
a dict from metric string name to the result of calling a metric function,
|
||||||
available at CPU host, they are joined and passed as positional arguments to
|
namely a `(metric_tensor, update_op)` tuple.
|
||||||
the `metric_fn`. `metric_fn` takes the tensor list (concatenated on CPU from
|
|
||||||
all shards) and returns a dict from metric string name to the result of
|
|
||||||
calling a metric function, namely a `(metric_tensor, update_op)` tuple.
|
|
||||||
|
|
||||||
Current limitations:
|
Current limitations:
|
||||||
|
|
||||||
@ -988,7 +1032,7 @@ class TPUEstimator(estimator_lib.Estimator):
|
|||||||
labels=labels, predictions=predictions),
|
labels=labels, predictions=predictions),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Your model Fn which runs on TPU.
|
# Your model Fn which runs on TPU (eval_metrics is list in this example)
|
||||||
def model_fn(features, labels, mode, config, params):
|
def model_fn(features, labels, mode, config, params):
|
||||||
...
|
...
|
||||||
logits = ...
|
logits = ...
|
||||||
@ -998,6 +1042,20 @@ class TPUEstimator(estimator_lib.Estimator):
|
|||||||
mode=mode,
|
mode=mode,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
eval_metrics=(metric_fn, [labels, logits]))
|
eval_metrics=(metric_fn, [labels, logits]))
|
||||||
|
|
||||||
|
# or specify the eval_metrics tensors as dict.
|
||||||
|
def model_fn(features, labels, mode, config, params):
|
||||||
|
...
|
||||||
|
final_layer_output = ...
|
||||||
|
|
||||||
|
if mode = tf.estimator.ModeKeys.EVAL:
|
||||||
|
return tpu_estimator.TPUEstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
loss=loss,
|
||||||
|
eval_metrics=(metric_fn, {
|
||||||
|
'labels': labels,
|
||||||
|
'logits': final_layer_output,
|
||||||
|
}))
|
||||||
```
|
```
|
||||||
|
|
||||||
Predict support on TPU is not yet implemented. So, `predict` and
|
Predict support on TPU is not yet implemented. So, `predict` and
|
||||||
|
Loading…
Reference in New Issue
Block a user