Merge pull request #46372 from TVLIgnacy:patch-1
PiperOrigin-RevId: 351944491 Change-Id: I7526b4e15fc9ace0b0f3082b0c548dce90adc8be
This commit is contained in:
commit
37e4d3abea
@ -261,7 +261,7 @@ class KerasClassifier(BaseWrapper):
|
|||||||
(instead of `(n_sample, 1)` as in Keras).
|
(instead of `(n_sample, 1)` as in Keras).
|
||||||
"""
|
"""
|
||||||
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
|
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
|
||||||
probs = self.model.predict_proba(x, **kwargs)
|
probs = self.model.predict(x, **kwargs)
|
||||||
|
|
||||||
# check if binary classification
|
# check if binary classification
|
||||||
if probs.shape[1] == 1:
|
if probs.shape[1] == 1:
|
||||||
|
Loading…
Reference in New Issue
Block a user