Add support for adding custom metric names in model_to_estimator API.

PiperOrigin-RevId: 302504968
Change-Id: I7260f0d495ec2b97b1a759ee0ff5c212c55c4fa2
This commit is contained in:
Pavithra Vijay 2020-03-23 13:29:35 -07:00 committed by TensorFlower Gardener
parent 4286b658c9
commit 36b108d783
2 changed files with 56 additions and 9 deletions

View File

@ -130,13 +130,13 @@ def model_to_estimator(
@keras_export('keras.estimator.model_to_estimator', v1=[])
def model_to_estimator_v2(
keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None,
checkpoint_format='checkpoint'):
def model_to_estimator_v2(keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None,
checkpoint_format='checkpoint',
metric_names_map=None):
"""Constructs an `Estimator` instance from given keras model.
If you use infrastructure or other tooling that relies on Estimators, you can
@ -169,6 +169,41 @@ def model_to_estimator_v2(
estimator.train(input_fn, steps=1)
```
To customize the estimator `eval_metric_ops` names, you can pass in the
`metric_names_map` dictionary mapping the keras model output metric names
to the custom names as follows:
```python
input_a = tf.keras.layers.Input(shape=(16,), name='input_a')
input_b = tf.keras.layers.Input(shape=(16,), name='input_b')
dense = tf.keras.layers.Dense(8, name='dense_1')
interm_a = dense(input_a)
interm_b = dense(input_b)
merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')
output_a = tf.keras.layers.Dense(3, activation='softmax', name='dense_2')(
merged)
output_b = tf.keras.layers.Dense(2, activation='softmax', name='dense_3')(
merged)
keras_model = tf.keras.models.Model(
inputs=[input_a, input_b], outputs=[output_a, output_b])
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
metrics={
'dense_2': 'categorical_accuracy',
'dense_3': 'categorical_accuracy'
})
metric_names_map = {
'dense_2_categorical_accuracy': 'acc_1',
'dense_3_categorical_accuracy': 'acc_2',
}
keras_est = tf.keras.estimator.model_to_estimator(
keras_model=keras_model,
config=config,
metric_names_map=metric_names_map)
```
Args:
keras_model: A compiled Keras model object. This argument is mutually
exclusive with `keras_model_path`. Estimator's `model_fn` uses the
@ -197,6 +232,17 @@ def model_to_estimator_v2(
`tf.train.Checkpoint`. Currently, saving object-based checkpoints from
`model_to_estimator` is only supported by Functional and Sequential
models. Defaults to 'checkpoint'.
metric_names_map: Optional dictionary mapping Keras model output metric
names to custom names. This can be used to override the default Keras
model output metrics names in a multi IO model use case and provide custom
names for the `eval_metric_ops` in Estimator.
The Keras model metric names can be obtained using `model.metrics_names`
excluding any loss metrics such as total loss and output losses.
For example, if your Keras model has two outputs `out_1` and `out_2`,
with `mse` loss and `acc` metric, then `model.metrics_names` will be
`['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']`.
The model metric names excluding the loss metrics will be
`['out_1_acc', 'out_2_acc']`.
Returns:
An Estimator from given keras model.
@ -223,5 +269,6 @@ def model_to_estimator_v2(
model_dir=model_dir,
config=config,
checkpoint_format=checkpoint_format,
use_v2_estimator=True)
use_v2_estimator=True,
metric_names_map=metric_names_map)
# LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)

View File

@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator"
tf_module {
member_method {
name: "model_to_estimator"
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\'], "
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\', \'metric_names_map\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\', \'None\'], "
}
}