Merge pull request #1 from MatthewScholefield/tf2
Swap out keras for tf.keras
This commit is contained in:
commit
06ed742f00
@ -33,7 +33,7 @@ def weighted_log_loss(yt, yp) -> Any:
|
||||
yt: Target
|
||||
yp: Prediction
|
||||
"""
|
||||
from keras import backend as K
|
||||
import tensorflow.keras.backend as K
|
||||
|
||||
pos_loss = -(0 + yt) * K.log(0 + yp + K.epsilon())
|
||||
neg_loss = -(1 - yt) * K.log(1 - yp + K.epsilon())
|
||||
@ -42,7 +42,7 @@ def weighted_log_loss(yt, yp) -> Any:
|
||||
|
||||
|
||||
def weighted_mse_loss(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
import tensorflow.keras.backend as K
|
||||
|
||||
total = K.sum(K.ones_like(yt))
|
||||
neg_loss = total * K.sum(K.square(yp * (1 - yt))) / K.sum(1 - yt)
|
||||
@ -52,12 +52,12 @@ def weighted_mse_loss(yt, yp) -> Any:
|
||||
|
||||
|
||||
def false_pos(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
import tensorflow.keras.backend as K
|
||||
return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(1 - yt))
|
||||
|
||||
|
||||
def false_neg(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
import tensorflow.keras.backend as K
|
||||
return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(0 + yt))
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ from precise.functions import load_keras, false_pos, false_neg, weighted_log_los
|
||||
from precise.params import inject_params, pr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from keras.models import Sequential
|
||||
from tensorflow.keras.models import Sequential
|
||||
|
||||
|
||||
@attr.s()
|
||||
@ -45,7 +45,8 @@ def load_precise_model(model_name: str) -> Any:
|
||||
print('Warning: Unknown model type, ', model_name)
|
||||
|
||||
inject_params(model_name)
|
||||
return load_keras().models.load_model(model_name)
|
||||
from tensorflow.keras.models import load_model
|
||||
return load_model(model_name, custom_objects=globals())
|
||||
|
||||
|
||||
def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential':
|
||||
@ -63,9 +64,8 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
|
||||
print('Loading from ' + model_name + '...')
|
||||
model = load_precise_model(model_name)
|
||||
else:
|
||||
from keras.layers.core import Dense
|
||||
from keras.layers.recurrent import GRU
|
||||
from keras.models import Sequential
|
||||
from tensorflow.keras.layers import Dense, GRU
|
||||
from tensorflow.keras.models import Sequential
|
||||
|
||||
model = Sequential()
|
||||
model.add(GRU(
|
||||
@ -74,7 +74,6 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
|
||||
))
|
||||
model.add(Dense(1, activation='sigmoid'))
|
||||
|
||||
load_keras()
|
||||
metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg]
|
||||
set_loss_bias(params.loss_bias)
|
||||
for i in model.layers[:params.freeze_till]:
|
||||
|
@ -68,21 +68,15 @@ class TensorFlowRunner(Runner):
|
||||
|
||||
class KerasRunner(Runner):
|
||||
def __init__(self, model_name: str):
|
||||
# Load model using Keras (not tf.keras)
|
||||
self.model = load_precise_model(model_name)
|
||||
|
||||
# TF 2.0 doesn't work well with sessions and graphs
|
||||
# Only in tf.v1.compat, but that restricts usage of v2 features
|
||||
self.graph = None
|
||||
|
||||
def predict(self, inputs: np.ndarray):
|
||||
import keras as K
|
||||
K.backend.tensorflow_backend._SYMBOLIC_SCOPE.value = True
|
||||
return self.model.predict(inputs)
|
||||
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
return self.predict(inp[np.newaxis])[0][0]
|
||||
|
||||
|
||||
class TFLiteRunner(Runner):
|
||||
def __init__(self, model_name: str):
|
||||
import tensorflow as tf
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from fitipy import Fitipy
|
||||
from keras.callbacks import LambdaCallback
|
||||
from tensorflow.keras.callbacks import LambdaCallback
|
||||
from os.path import splitext, isfile
|
||||
from prettyparse import Usage
|
||||
from typing import Any, Tuple
|
||||
@ -85,7 +85,7 @@ class TrainScript(BaseScript):
|
||||
self.model = create_model(args.model, params)
|
||||
self.train, self.test = self.load_data(self.args)
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
|
||||
|
@ -18,7 +18,7 @@ from math import sqrt
|
||||
import numpy as np
|
||||
from contextlib import suppress
|
||||
from fitipy import Fitipy
|
||||
from keras.callbacks import LambdaCallback
|
||||
from tensorflow.keras.callbacks import LambdaCallback
|
||||
from os.path import splitext, join, basename
|
||||
from prettyparse import Usage
|
||||
from random import random, shuffle
|
||||
@ -90,7 +90,7 @@ class TrainGeneratedScript(BaseScript):
|
||||
self.model = create_model(args.model, params)
|
||||
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
|
||||
|
Loading…
x
Reference in New Issue
Block a user