Merge pull request #1 from MatthewScholefield/tf2

Swap out keras for tf.keras
This commit is contained in:
andreselizondo-adestech 2020-08-19 14:09:11 -05:00 committed by GitHub
commit 06ed742f00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 15 additions and 23 deletions

View File

@ -33,7 +33,7 @@ def weighted_log_loss(yt, yp) -> Any:
yt: Target
yp: Prediction
"""
from keras import backend as K
import tensorflow.keras.backend as K
pos_loss = -(0 + yt) * K.log(0 + yp + K.epsilon())
neg_loss = -(1 - yt) * K.log(1 - yp + K.epsilon())
@ -42,7 +42,7 @@ def weighted_log_loss(yt, yp) -> Any:
def weighted_mse_loss(yt, yp) -> Any:
from keras import backend as K
import tensorflow.keras.backend as K
total = K.sum(K.ones_like(yt))
neg_loss = total * K.sum(K.square(yp * (1 - yt))) / K.sum(1 - yt)
@ -52,12 +52,12 @@ def weighted_mse_loss(yt, yp) -> Any:
def false_pos(yt, yp) -> Any:
from keras import backend as K
import tensorflow.keras.backend as K
return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(1 - yt))
def false_neg(yt, yp) -> Any:
from keras import backend as K
import tensorflow.keras.backend as K
return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(0 + yt))

View File

@ -19,7 +19,7 @@ from precise.functions import load_keras, false_pos, false_neg, weighted_log_los
from precise.params import inject_params, pr
if TYPE_CHECKING:
from keras.models import Sequential
from tensorflow.keras.models import Sequential
@attr.s()
@ -45,7 +45,8 @@ def load_precise_model(model_name: str) -> Any:
print('Warning: Unknown model type, ', model_name)
inject_params(model_name)
return load_keras().models.load_model(model_name)
from tensorflow.keras.models import load_model
return load_model(model_name, custom_objects=globals())
def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential':
@ -63,9 +64,8 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
print('Loading from ' + model_name + '...')
model = load_precise_model(model_name)
else:
from keras.layers.core import Dense
from keras.layers.recurrent import GRU
from keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU
from tensorflow.keras.models import Sequential
model = Sequential()
model.add(GRU(
@ -74,7 +74,6 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
))
model.add(Dense(1, activation='sigmoid'))
load_keras()
metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg]
set_loss_bias(params.loss_bias)
for i in model.layers[:params.freeze_till]:

View File

@ -68,21 +68,15 @@ class TensorFlowRunner(Runner):
class KerasRunner(Runner):
def __init__(self, model_name: str):
# Load model using Keras (not tf.keras)
self.model = load_precise_model(model_name)
# TF 2.0 doesn't work well with sessions and graphs
# Only in tf.v1.compat, but that restricts usage of v2 features
self.graph = None
def predict(self, inputs: np.ndarray):
import keras as K
K.backend.tensorflow_backend._SYMBOLIC_SCOPE.value = True
return self.model.predict(inputs)
def run(self, inp: np.ndarray) -> float:
return self.predict(inp[np.newaxis])[0][0]
class TFLiteRunner(Runner):
def __init__(self, model_name: str):
import tensorflow as tf

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from tensorflow.keras.callbacks import LambdaCallback
from os.path import splitext, isfile
from prettyparse import Usage
from typing import Any, Tuple
@ -85,7 +85,7 @@ class TrainScript(BaseScript):
self.model = create_model(args.model, params)
self.train, self.test = self.load_data(self.args)
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')

View File

@ -18,7 +18,7 @@ from math import sqrt
import numpy as np
from contextlib import suppress
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from tensorflow.keras.callbacks import LambdaCallback
from os.path import splitext, join, basename
from prettyparse import Usage
from random import random, shuffle
@ -90,7 +90,7 @@ class TrainGeneratedScript(BaseScript):
self.model = create_model(args.model, params)
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')

View File

@ -69,10 +69,9 @@ setup(
},
install_requires=[
'numpy',
'tensorflow-gpu==2.2.0rc', # This should be changed for 2.2.0 when it's released.
'tensorflow-gpu==2.2.0',
'sonopy',
'pyaudio',
'keras>2.1.5',
'h5py',
'wavio',
'typing',