Try to remove Keras dependencies.

PiperOrigin-RevId: 303880926
Change-Id: Ic5110435f311129ba5030cf2d0b2be9e4f2edd65
This commit is contained in:
Renjie Liu 2020-03-30 20:32:36 -07:00 committed by TensorFlower Gardener
parent 4cb2faa19a
commit 1b3fa440fc

View File

@@ -23,8 +23,6 @@ from __future__ import print_function
import itertools
from tensorflow.lite.python.op_hint import OpHint
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -80,7 +78,9 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
self._num_units = num_units
if activation:
self._activation = activations.get(activation)
if activation != "tanh":
raise ValueError("activation other than tanh is not supported")
self._activation = math_ops.tanh
else:
self._activation = math_ops.tanh
@@ -150,7 +150,7 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
def get_config(self):
config = {
"num_units": self._num_units,
"activation": activations.serialize(self._activation),
"activation": "tanh",
"reuse": self._reuse,
}
base_config = super(TfLiteRNNCell, self).get_config()
@@ -268,7 +268,12 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
self._num_proj_shards = num_proj_shards
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation or math_ops.tanh
if activation:
if activation != "tanh":
raise ValueError("activation other than tanh is not supported")
self._activation = math_ops.tanh
else:
self._activation = math_ops.tanh
self._output_size = num_proj if num_proj else num_units
self._state_size = (
@@ -516,14 +521,13 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
"num_units": self._num_units,
"use_peepholes": self._use_peepholes,
"cell_clip": self._cell_clip,
"initializer": initializers.serialize(self._initializer),
"num_proj": self._num_proj,
"proj_clip": self._proj_clip,
"num_unit_shards": self._num_unit_shards,
"num_proj_shards": self._num_proj_shards,
"forget_bias": self._forget_bias,
"state_is_tuple": self._state_is_tuple,
"activation": activations.serialize(self._activation),
"activation": "tanh",
"reuse": self._reuse,
}
base_config = super(TFLiteLSTMCell, self).get_config()