Fix for Keras logging
PiperOrigin-RevId: 221099213
This commit is contained in:
parent
fc44600e5c
commit
e7d988404a
@ -274,6 +274,7 @@ def model_iteration(model,
|
||||
# TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
|
||||
progbar = _get_progbar(model, count_mode)
|
||||
progbar.params = callbacks.params
|
||||
progbar.params['verbose'] = verbose
|
||||
|
||||
# Find beforehand arrays that need sparse-to-dense conversion.
|
||||
if issparse is not None:
|
||||
|
@ -18,9 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -2222,6 +2225,19 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
scores = model.train_on_batch(x, y, sample_weight=w)
|
||||
self.assertArrayNear(scores, [0.2, 0.8], 0.1)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_logging(self):
|
||||
mock_stdout = io.BytesIO() if six.PY2 else io.StringIO()
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(10, activation='relu'))
|
||||
model.add(keras.layers.Dense(1, activation='sigmoid'))
|
||||
model.compile(
|
||||
RMSPropOptimizer(learning_rate=0.001), loss='binary_crossentropy')
|
||||
with test.mock.patch.object(sys, 'stdout', mock_stdout):
|
||||
model.fit(
|
||||
np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'), epochs=10)
|
||||
self.assertTrue('Epoch 5/10' in mock_stdout.getvalue())
|
||||
|
||||
def test_losses_in_defun(self):
|
||||
with context.eager_mode():
|
||||
layer = keras.layers.Dense(1, kernel_regularizer='l1')
|
||||
|
Loading…
Reference in New Issue
Block a user