try to remove keras deps
PiperOrigin-RevId: 303880926 Change-Id: Ic5110435f311129ba5030cf2d0b2be9e4f2edd65
This commit is contained in:
parent
4cb2faa19a
commit
1b3fa440fc
@ -23,8 +23,6 @@ from __future__ import print_function
|
|||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from tensorflow.lite.python.op_hint import OpHint
|
from tensorflow.lite.python.op_hint import OpHint
|
||||||
from tensorflow.python.keras import activations
|
|
||||||
from tensorflow.python.keras import initializers
|
|
||||||
from tensorflow.python.layers import base as base_layer
|
from tensorflow.python.layers import base as base_layer
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import clip_ops
|
from tensorflow.python.ops import clip_ops
|
||||||
@ -80,7 +78,9 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
|
|||||||
self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
|
self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
|
||||||
self._num_units = num_units
|
self._num_units = num_units
|
||||||
if activation:
|
if activation:
|
||||||
self._activation = activations.get(activation)
|
if activation != "tanh":
|
||||||
|
raise ValueError("activation other than tanh is not supported")
|
||||||
|
self._activation = math_ops.tanh
|
||||||
else:
|
else:
|
||||||
self._activation = math_ops.tanh
|
self._activation = math_ops.tanh
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
|
|||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = {
|
config = {
|
||||||
"num_units": self._num_units,
|
"num_units": self._num_units,
|
||||||
"activation": activations.serialize(self._activation),
|
"activation": "tanh",
|
||||||
"reuse": self._reuse,
|
"reuse": self._reuse,
|
||||||
}
|
}
|
||||||
base_config = super(TfLiteRNNCell, self).get_config()
|
base_config = super(TfLiteRNNCell, self).get_config()
|
||||||
@ -268,7 +268,12 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
|||||||
self._num_proj_shards = num_proj_shards
|
self._num_proj_shards = num_proj_shards
|
||||||
self._forget_bias = forget_bias
|
self._forget_bias = forget_bias
|
||||||
self._state_is_tuple = state_is_tuple
|
self._state_is_tuple = state_is_tuple
|
||||||
self._activation = activation or math_ops.tanh
|
if activation:
|
||||||
|
if activation != "tanh":
|
||||||
|
raise ValueError("activation other than tanh is not supported")
|
||||||
|
self._activation = math_ops.tanh
|
||||||
|
else:
|
||||||
|
self._activation = math_ops.tanh
|
||||||
|
|
||||||
self._output_size = num_proj if num_proj else num_units
|
self._output_size = num_proj if num_proj else num_units
|
||||||
self._state_size = (
|
self._state_size = (
|
||||||
@ -516,14 +521,13 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
|||||||
"num_units": self._num_units,
|
"num_units": self._num_units,
|
||||||
"use_peepholes": self._use_peepholes,
|
"use_peepholes": self._use_peepholes,
|
||||||
"cell_clip": self._cell_clip,
|
"cell_clip": self._cell_clip,
|
||||||
"initializer": initializers.serialize(self._initializer),
|
|
||||||
"num_proj": self._num_proj,
|
"num_proj": self._num_proj,
|
||||||
"proj_clip": self._proj_clip,
|
"proj_clip": self._proj_clip,
|
||||||
"num_unit_shards": self._num_unit_shards,
|
"num_unit_shards": self._num_unit_shards,
|
||||||
"num_proj_shards": self._num_proj_shards,
|
"num_proj_shards": self._num_proj_shards,
|
||||||
"forget_bias": self._forget_bias,
|
"forget_bias": self._forget_bias,
|
||||||
"state_is_tuple": self._state_is_tuple,
|
"state_is_tuple": self._state_is_tuple,
|
||||||
"activation": activations.serialize(self._activation),
|
"activation": "tanh",
|
||||||
"reuse": self._reuse,
|
"reuse": self._reuse,
|
||||||
}
|
}
|
||||||
base_config = super(TFLiteLSTMCell, self).get_config()
|
base_config = super(TFLiteLSTMCell, self).get_config()
|
||||||
|
Loading…
Reference in New Issue
Block a user