Fix bug that resulted in unnecessary placeholders being created during functional model cloning when input tensors were specified.
PiperOrigin-RevId: 222424883
commit 42e2bb488a (parent c59719cf1f)
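Context: cloning a model with `clone_model(model, input_tensors=...)` used to leave stray `Placeholder` ops in the target graph even though concrete input tensors were supplied. Below is a minimal sketch of the intended post-fix behavior, mirroring the regression tests added in this commit (TF 1.x graph-mode API; variable names are illustrative):

```python
import tensorflow as tf

# Build a small functional model in its own graph.
with tf.Graph().as_default():
  x = tf.keras.Input((4,))
  y = tf.keras.layers.Dense(4)(x)
  model = tf.keras.models.Model(x, y)

# Clone it onto a concrete tensor in a fresh graph. Because the input is
# already a real tensor, cloning should wrap it instead of minting a new
# Placeholder op.
graph = tf.Graph()
with graph.as_default():
  ones = tf.ones((10, 4))
  clone = tf.keras.models.clone_model(model, input_tensors=[ones])
  placeholder_ops = [op for op in graph.get_operations()
                     if 'Placeholder' in op.type]
  assert not placeholder_ops  # Holds after this fix; failed before it.
```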
@@ -19,12 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.utils import tf_utils
-from tensorflow.python.ops import array_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -94,19 +92,19 @@ class InputLayer(base_layer.Layer):
       else:
         batch_input_shape = None
       graph = backend.get_graph()
-      with context.graph_mode():
-        with graph.as_default():
-          # In graph mode, create a graph placeholder to call the layer on.
-          if sparse:
-            input_tensor = array_ops.sparse_placeholder(
-                shape=batch_input_shape,
-                dtype=dtype,
-                name=self.name)
-          else:
-            input_tensor = array_ops.placeholder(
-                shape=batch_input_shape,
-                dtype=dtype,
-                name=self.name)
+      with graph.as_default():
+        # In graph mode, create a graph placeholder to call the layer on.
+        if sparse:
+          input_tensor = backend.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name,
+              sparse=True)
+        else:
+          input_tensor = backend.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name)
 
       self.is_placeholder = True
       self._batch_input_shape = batch_input_shape
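The `InputLayer` change above drops the explicit `context.graph_mode()` scope and the direct `array_ops.placeholder` / `array_ops.sparse_placeholder` calls, routing placeholder creation through `backend.placeholder`, which handles both the dense and sparse cases. A minimal sketch of the consolidated call (shapes and names are illustrative):

```python
import tensorflow as tf

K = tf.keras.backend

# One call now covers both cases that previously needed separate
# array_ops.placeholder and array_ops.sparse_placeholder calls.
dense = K.placeholder(shape=(None, 4), dtype='float32', name='dense_in')
sparse = K.placeholder(shape=(None, 4), dtype='float32', name='sparse_in',
                       sparse=True)
```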
@@ -100,17 +100,19 @@ def _clone_functional_model(model, input_tensors=None):
-    input_tensors = list(input_tensors)
+    input_tensors = generic_utils.to_list(input_tensors)
     input_tensors_ = []
-    for i, x in enumerate(input_tensors):
-      if not K.is_keras_tensor(x):
-        name = model._input_layers[i].name
-        input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
-        input_tensors_.append(input_tensor)
+    for i in range(len(input_tensors)):
+      input_tensor = input_tensors[i]
+      if not K.is_keras_tensor(input_tensor):
+        original_input_layer = model._input_layers[i]
+        name = original_input_layer.name
+        input_tensor = Input(tensor=input_tensor,
+                             name='input_wrapper_for_' + name)
         # Cache newly created input layer.
-        original_input_layer = x._keras_history[0]
         newly_created_input_layer = input_tensor._keras_history[0]
         layer_map[original_input_layer] = newly_created_input_layer
-      else:
-        input_tensors_.append(x)
+      input_tensors_.append(input_tensor)
     input_tensors = input_tensors_
 
     for x, y in zip(model.inputs, input_tensors):
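The heart of the fix is the caching above: the old code wrapped the tensor first and then read `x._keras_history[0]`, which at that point refers to the freshly created wrapper `InputLayer`, not to the model's original input layer. The original layer therefore never landed in `layer_map`, and the later cloning pass re-instantiated it, creating a superfluous placeholder. The new code captures `model._input_layers[i]` before wrapping. A sketch of the wrapping step in isolation (hypothetical names; TF 1.x graph mode):

```python
import tensorflow as tf

t = tf.ones((10, 4))  # A plain tensor, not a Keras tensor.

# Input(tensor=...) builds an InputLayer around the existing tensor `t`
# instead of minting a new Placeholder op.
wrapped = tf.keras.Input(tensor=t, name='input_wrapper_for_demo')

# _keras_history[0] recovers the layer that produced a Keras tensor; here
# that is the wrapper InputLayer just created, not any pre-existing layer.
wrapper_layer = wrapped._keras_history[0]
```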
@@ -26,10 +26,12 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import metrics
 from tensorflow.python.keras import models
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
@@ -219,6 +221,33 @@ class TestModelCloning(test.TestCase):
     with self.assertRaises(ValueError):
       keras.models._clone_sequential_model(seq_model, input_tensors=y)
 
+  def test_functional_cloning_does_not_create_unnecessary_placeholders(self):
+    with ops.Graph().as_default():
+      x = keras.Input((4,))
+      y = keras.layers.Dense(4)(x)
+      model = keras.models.Model(x, y)
+    graph = ops.Graph()
+    with graph.as_default():
+      x = array_ops.ones((10, 4))
+      _ = keras.models.clone_model(model, input_tensors=[x])
+      has_placeholder = _has_placeholder(graph)
+      self.assertFalse(has_placeholder)
+
+  def test_sequential_cloning_does_not_create_unnecessary_placeholders(self):
+    with ops.Graph().as_default():
+      model = keras.models.Sequential([keras.layers.Dense(4)])
+    graph = ops.Graph()
+    with graph.as_default():
+      x = array_ops.ones((10, 4))
+      _ = keras.models.clone_model(model, input_tensors=[x])
+      has_placeholder = _has_placeholder(graph)
+      self.assertFalse(has_placeholder)
+
+
+def _has_placeholder(graph):
+  ops_types = [op.type for op in graph.get_operations()]
+  return any('Placeholder' in s for s in ops_types)
+
+
 class CheckpointingTests(test.TestCase):
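The `_has_placeholder` helper added above simply scans a graph's op types. The same check is easy to run standalone when debugging cloning behavior (a sketch; graph and names are illustrative, TF 1.x API):

```python
import tensorflow as tf

def has_placeholder(graph):
  # True if any op in the graph is a Placeholder variant.
  return any('Placeholder' in op.type for op in graph.get_operations())

g = tf.Graph()
with g.as_default():
  tf.placeholder(tf.float32, shape=(None, 4))

print(has_placeholder(g))  # True
```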