Change from model.predict_proba to using model.predict

model.predict_proba is deprecated in the Keras code base, and should be replaced by model.predict
This replaces the usage of the deprecated call in the scikit_learn wrapper
This commit is contained in:
TVLIgnacy 2021-01-12 18:13:58 +00:00 committed by GitHub
parent 242e0a5815
commit 27998a5612
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -261,7 +261,7 @@ class KerasClassifier(BaseWrapper):
(instead of `(n_sample, 1)` as in Keras).
"""
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
probs = self.model.predict_proba(x, **kwargs)
probs = self.model.predict(x, **kwargs)
# check if binary classification
if probs.shape[1] == 1: