Add support for adding custom metric names in model_to_estimator API.
PiperOrigin-RevId: 302504968 Change-Id: I7260f0d495ec2b97b1a759ee0ff5c212c55c4fa2
This commit is contained in:
parent
4286b658c9
commit
36b108d783
@ -130,13 +130,13 @@ def model_to_estimator(
|
||||
|
||||
|
||||
@keras_export('keras.estimator.model_to_estimator', v1=[])
|
||||
def model_to_estimator_v2(
|
||||
keras_model=None,
|
||||
keras_model_path=None,
|
||||
custom_objects=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
checkpoint_format='checkpoint'):
|
||||
def model_to_estimator_v2(keras_model=None,
|
||||
keras_model_path=None,
|
||||
custom_objects=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
checkpoint_format='checkpoint',
|
||||
metric_names_map=None):
|
||||
"""Constructs an `Estimator` instance from given keras model.
|
||||
|
||||
If you use infrastructure or other tooling that relies on Estimators, you can
|
||||
@ -169,6 +169,41 @@ def model_to_estimator_v2(
|
||||
estimator.train(input_fn, steps=1)
|
||||
```
|
||||
|
||||
To customize the estimator `eval_metric_ops` names, you can pass in the
|
||||
`metric_names_map` dictionary mapping the keras model output metric names
|
||||
to the custom names as follows:
|
||||
|
||||
```python
|
||||
input_a = tf.keras.layers.Input(shape=(16,), name='input_a')
|
||||
input_b = tf.keras.layers.Input(shape=(16,), name='input_b')
|
||||
dense = tf.keras.layers.Dense(8, name='dense_1')
|
||||
interm_a = dense(input_a)
|
||||
interm_b = dense(input_b)
|
||||
merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')
|
||||
output_a = tf.keras.layers.Dense(3, activation='softmax', name='dense_2')(
|
||||
merged)
|
||||
output_b = tf.keras.layers.Dense(2, activation='softmax', name='dense_3')(
|
||||
merged)
|
||||
keras_model = tf.keras.models.Model(
|
||||
inputs=[input_a, input_b], outputs=[output_a, output_b])
|
||||
keras_model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer='rmsprop',
|
||||
metrics={
|
||||
'dense_2': 'categorical_accuracy',
|
||||
'dense_3': 'categorical_accuracy'
|
||||
})
|
||||
|
||||
metric_names_map = {
|
||||
'dense_2_categorical_accuracy': 'acc_1',
|
||||
'dense_3_categorical_accuracy': 'acc_2',
|
||||
}
|
||||
keras_est = tf.keras.estimator.model_to_estimator(
|
||||
keras_model=keras_model,
|
||||
config=config,
|
||||
metric_names_map=metric_names_map)
|
||||
```
|
||||
|
||||
Args:
|
||||
keras_model: A compiled Keras model object. This argument is mutually
|
||||
exclusive with `keras_model_path`. Estimator's `model_fn` uses the
|
||||
@ -197,6 +232,17 @@ def model_to_estimator_v2(
|
||||
`tf.train.Checkpoint`. Currently, saving object-based checkpoints from
|
||||
`model_to_estimator` is only supported by Functional and Sequential
|
||||
models. Defaults to 'checkpoint'.
|
||||
metric_names_map: Optional dictionary mapping Keras model output metric
|
||||
names to custom names. This can be used to override the default Keras
|
||||
model output metrics names in a multi IO model use case and provide custom
|
||||
names for the `eval_metric_ops` in Estimator.
|
||||
The Keras model metric names can be obtained using `model.metrics_names`
|
||||
excluding any loss metrics such as total loss and output losses.
|
||||
For example, if your Keras model has two outputs `out_1` and `out_2`,
|
||||
with `mse` loss and `acc` metric, then `model.metrics_names` will be
|
||||
`['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']`.
|
||||
The model metric names excluding the loss metrics will be
|
||||
`['out_1_acc', 'out_2_acc']`.
|
||||
|
||||
Returns:
|
||||
An Estimator from given keras model.
|
||||
@ -223,5 +269,6 @@ def model_to_estimator_v2(
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
checkpoint_format=checkpoint_format,
|
||||
use_v2_estimator=True)
|
||||
use_v2_estimator=True,
|
||||
metric_names_map=metric_names_map)
|
||||
# LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)
|
||||
|
@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "model_to_estimator"
|
||||
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\'], "
|
||||
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\', \'metric_names_map\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user