From 1b3fa440fcf44b3e61fdea4319e2ed490f0789af Mon Sep 17 00:00:00 2001
From: Renjie Liu
Date: Mon, 30 Mar 2020 20:32:36 -0700
Subject: [PATCH] try to remove keras deps

PiperOrigin-RevId: 303880926
Change-Id: Ic5110435f311129ba5030cf2d0b2be9e4f2edd65
---
 .../experimental/examples/lstm/rnn_cell.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py
index 3d5ebf4946f..9736719c997 100644
--- a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py
+++ b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py
@@ -23,8 +23,6 @@ from __future__ import print_function
 import itertools
 
 from tensorflow.lite.python.op_hint import OpHint
-from tensorflow.python.keras import activations
-from tensorflow.python.keras import initializers
 from tensorflow.python.layers import base as base_layer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -80,7 +78,9 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
     self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
     self._num_units = num_units
     if activation:
-      self._activation = activations.get(activation)
+      if activation != "tanh":
+        raise ValueError("activation other than tanh is not supported")
+      self._activation = math_ops.tanh
     else:
       self._activation = math_ops.tanh
 
@@ -150,7 +150,7 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
   def get_config(self):
     config = {
         "num_units": self._num_units,
-        "activation": activations.serialize(self._activation),
+        "activation": "tanh",
         "reuse": self._reuse,
     }
     base_config = super(TfLiteRNNCell, self).get_config()
@@ -268,7 +268,12 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
     self._num_proj_shards = num_proj_shards
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
-    self._activation = activation or math_ops.tanh
+    if activation:
+      if activation != "tanh":
+        raise ValueError("activation other than tanh is not supported")
+      self._activation = math_ops.tanh
+    else:
+      self._activation = math_ops.tanh
 
     self._output_size = num_proj if num_proj else num_units
     self._state_size = (
@@ -516,14 +521,13 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
         "num_units": self._num_units,
         "use_peepholes": self._use_peepholes,
         "cell_clip": self._cell_clip,
-        "initializer": initializers.serialize(self._initializer),
         "num_proj": self._num_proj,
         "proj_clip": self._proj_clip,
         "num_unit_shards": self._num_unit_shards,
         "num_proj_shards": self._num_proj_shards,
         "forget_bias": self._forget_bias,
         "state_is_tuple": self._state_is_tuple,
-        "activation": activations.serialize(self._activation),
+        "activation": "tanh",
         "reuse": self._reuse,
     }
     base_config = super(TFLiteLSTMCell, self).get_config()