Fix JSON serialization error in TensorFlowOpLayer in Python 3.
PiperOrigin-RevId: 259397921
This commit is contained in:
parent
3c5fb53765
commit
7cc180f107
@ -2387,6 +2387,8 @@ class TensorFlowOpLayer(Layer):
|
||||
dtype=None):
|
||||
super(TensorFlowOpLayer, self).__init__(
|
||||
name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype)
|
||||
if not isinstance(node_def, bytes):
|
||||
node_def = node_def.encode('utf-8')
|
||||
self.node_def = node_def_pb2.NodeDef.FromString(node_def)
|
||||
self.constants = constants or {}
|
||||
# Layer uses original op unless it is called on new inputs.
|
||||
@ -2446,7 +2448,7 @@ class TensorFlowOpLayer(Layer):
|
||||
def get_config(self):
|
||||
config = super(TensorFlowOpLayer, self).get_config()
|
||||
config.update({
|
||||
'node_def': self.node_def.SerializeToString(),
|
||||
'node_def': self.node_def.SerializeToString().decode('utf-8'),
|
||||
'constants': {
|
||||
i: backend.get_value(c) for i, c in self.constants.items()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user