Fix JSON serialization error in TensorFlowOpLayer in Python 3.
PiperOrigin-RevId: 259397921
This commit is contained in:
parent
3c5fb53765
commit
7cc180f107
@ -2387,6 +2387,8 @@ class TensorFlowOpLayer(Layer):
|
|||||||
dtype=None):
|
dtype=None):
|
||||||
super(TensorFlowOpLayer, self).__init__(
|
super(TensorFlowOpLayer, self).__init__(
|
||||||
name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype)
|
name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype)
|
||||||
|
if not isinstance(node_def, bytes):
|
||||||
|
node_def = node_def.encode('utf-8')
|
||||||
self.node_def = node_def_pb2.NodeDef.FromString(node_def)
|
self.node_def = node_def_pb2.NodeDef.FromString(node_def)
|
||||||
self.constants = constants or {}
|
self.constants = constants or {}
|
||||||
# Layer uses original op unless it is called on new inputs.
|
# Layer uses original op unless it is called on new inputs.
|
||||||
@ -2446,7 +2448,7 @@ class TensorFlowOpLayer(Layer):
|
|||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = super(TensorFlowOpLayer, self).get_config()
|
config = super(TensorFlowOpLayer, self).get_config()
|
||||||
config.update({
|
config.update({
|
||||||
'node_def': self.node_def.SerializeToString(),
|
'node_def': self.node_def.SerializeToString().decode('utf-8'),
|
||||||
'constants': {
|
'constants': {
|
||||||
i: backend.get_value(c) for i, c in self.constants.items()
|
i: backend.get_value(c) for i, c in self.constants.items()
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user