From 4ab6a520c94441622442747aef620939cc1d8130 Mon Sep 17 00:00:00 2001 From: George Sterpu Date: Thu, 21 Nov 2019 13:59:30 +0000 Subject: [PATCH 1/3] Relax the check for state_size The behaviour of `hasattr` is to evaluate the state_size member. In the case of `tfa.seq2seq.AttentionWrapper`, that is a @property member that is built at graph runtime after calling `setup_memory`, thus `hasattr` returns an error when using AttentionWrapper with dynamic memories. More details: https://github.com/tensorflow/addons/issues/680 --- tensorflow/python/keras/layers/recurrent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 87a99f49164..6c7610b6795 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -82,7 +82,7 @@ class StackedRNNCells(Layer): if not hasattr(cell, 'call'): raise ValueError('All cells must have a `call` method. ' 'received cells:', cells) - if not hasattr(cell, 'state_size'): + if not ('state_size' in dir(cell) or hasattr(cell, 'state_size')): raise ValueError('All cells must have a ' '`state_size` attribute. ' 'received cells:', cells) @@ -391,7 +391,7 @@ class RNN(Layer): if not hasattr(cell, 'call'): raise ValueError('`cell` should have a `call` method. ' 'The RNN was passed:', cell) - if not hasattr(cell, 'state_size'): + if not ('state_size' in dir(cell) or hasattr(cell, 'state_size')): raise ValueError('The RNN cell should have ' 'an attribute `state_size` ' '(tuple of integers, ' From c678bdb3ae128974ddcc06bc02c5ae5f0de65e24 Mon Sep 17 00:00:00 2001 From: George Sterpu Date: Fri, 10 Jan 2020 12:34:40 +0000 Subject: [PATCH 2/3] Update recurrent.py trying to edit directly from the browser --- tensorflow/python/keras/layers/recurrent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 6c7610b6795..3a07fbc1694 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -79,10 +79,10 @@ class StackedRNNCells(Layer): def __init__(self, cells, **kwargs): for cell in cells: - if not hasattr(cell, 'call'): + if not 'call' in dir(cell): raise ValueError('All cells must have a `call` method. ' 'received cells:', cells) - if not ('state_size' in dir(cell) or hasattr(cell, 'state_size')): + if not 'state_size' in dir(cell): raise ValueError('All cells must have a ' '`state_size` attribute. ' 'received cells:', cells) From 409db98338a62808209ab3837f6ca3b796c81dc5 Mon Sep 17 00:00:00 2001 From: George Sterpu Date: Fri, 10 Jan 2020 12:37:51 +0000 Subject: [PATCH 3/3] Update recurrent.py --- tensorflow/python/keras/layers/recurrent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 3a07fbc1694..22d7cd4dcf9 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -388,10 +388,10 @@ class RNN(Layer): **kwargs): if isinstance(cell, (list, tuple)): cell = StackedRNNCells(cell) - if not hasattr(cell, 'call'): + if not 'call' in dir(cell): raise ValueError('`cell` should have a `call` method. ' 'The RNN was passed:', cell) - if not ('state_size' in dir(cell) or hasattr(cell, 'state_size')): + if not 'state_size' in dir(cell): raise ValueError('The RNN cell should have ' 'an attribute `state_size` ' '(tuple of integers, '