Merge pull request #46372 from TVLIgnacy:patch-1

PiperOrigin-RevId: 351944491
Change-Id: I7526b4e15fc9ace0b0f3082b0c548dce90adc8be
This commit is contained in:
TensorFlower Gardener 2021-01-14 22:32:52 -08:00
commit 37e4d3abea

View File

@ -261,7 +261,7 @@ class KerasClassifier(BaseWrapper):
(instead of `(n_sample, 1)` as in Keras).
"""
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
probs = self.model.predict_proba(x, **kwargs)
probs = self.model.predict(x, **kwargs)
# check if binary classification
if probs.shape[1] == 1: