try to remove keras deps
PiperOrigin-RevId: 303880926 Change-Id: Ic5110435f311129ba5030cf2d0b2be9e4f2edd65
This commit is contained in:
parent
4cb2faa19a
commit
1b3fa440fc
@ -23,8 +23,6 @@ from __future__ import print_function
|
||||
import itertools
|
||||
|
||||
from tensorflow.lite.python.op_hint import OpHint
|
||||
from tensorflow.python.keras import activations
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.layers import base as base_layer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
@ -80,7 +78,9 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
|
||||
self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
|
||||
self._num_units = num_units
|
||||
if activation:
|
||||
self._activation = activations.get(activation)
|
||||
if activation != "tanh":
|
||||
raise ValueError("activation other than tanh is not supported")
|
||||
self._activation = math_ops.tanh
|
||||
else:
|
||||
self._activation = math_ops.tanh
|
||||
|
||||
@ -150,7 +150,7 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_units": self._num_units,
|
||||
"activation": activations.serialize(self._activation),
|
||||
"activation": "tanh",
|
||||
"reuse": self._reuse,
|
||||
}
|
||||
base_config = super(TfLiteRNNCell, self).get_config()
|
||||
@ -268,7 +268,12 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
self._num_proj_shards = num_proj_shards
|
||||
self._forget_bias = forget_bias
|
||||
self._state_is_tuple = state_is_tuple
|
||||
self._activation = activation or math_ops.tanh
|
||||
if activation:
|
||||
if activation != "tanh":
|
||||
raise ValueError("activation other than tanh is not supported")
|
||||
self._activation = math_ops.tanh
|
||||
else:
|
||||
self._activation = math_ops.tanh
|
||||
|
||||
self._output_size = num_proj if num_proj else num_units
|
||||
self._state_size = (
|
||||
@ -516,14 +521,13 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
"num_units": self._num_units,
|
||||
"use_peepholes": self._use_peepholes,
|
||||
"cell_clip": self._cell_clip,
|
||||
"initializer": initializers.serialize(self._initializer),
|
||||
"num_proj": self._num_proj,
|
||||
"proj_clip": self._proj_clip,
|
||||
"num_unit_shards": self._num_unit_shards,
|
||||
"num_proj_shards": self._num_proj_shards,
|
||||
"forget_bias": self._forget_bias,
|
||||
"state_is_tuple": self._state_is_tuple,
|
||||
"activation": activations.serialize(self._activation),
|
||||
"activation": "tanh",
|
||||
"reuse": self._reuse,
|
||||
}
|
||||
base_config = super(TFLiteLSTMCell, self).get_config()
|
||||
|
Loading…
Reference in New Issue
Block a user