Fix bug that resulted in unnecessary placeholders created during functional model cloning when input tensors were specified.

PiperOrigin-RevId: 222424883
This commit is contained in:
Francois Chollet 2018-11-21 10:07:11 -08:00 committed by TensorFlower Gardener
parent c59719cf1f
commit 42e2bb488a
3 changed files with 50 additions and 21 deletions

View File

@@ -19,12 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.utils import tf_utils
-from tensorflow.python.ops import array_ops
 from tensorflow.python.util.tf_export import tf_export
@@ -94,16 +92,16 @@ class InputLayer(base_layer.Layer):
     else:
       batch_input_shape = None
     graph = backend.get_graph()
-    with context.graph_mode():
-      with graph.as_default():
-        # In graph mode, create a graph placeholder to call the layer on.
-        if sparse:
-          input_tensor = array_ops.sparse_placeholder(
-              shape=batch_input_shape,
-              dtype=dtype,
-              name=self.name)
-        else:
-          input_tensor = array_ops.placeholder(
-              shape=batch_input_shape,
-              dtype=dtype,
-              name=self.name)
+    with graph.as_default():
+      # In graph mode, create a graph placeholder to call the layer on.
+      if sparse:
+        input_tensor = backend.placeholder(
+            shape=batch_input_shape,
+            dtype=dtype,
+            name=self.name,
+            sparse=True)
+      else:
+        input_tensor = backend.placeholder(
+            shape=batch_input_shape,
+            dtype=dtype,
+            name=self.name)

View File

@@ -100,17 +100,19 @@ def _clone_functional_model(model, input_tensors=None):
-    input_tensors = list(input_tensors)
+    input_tensors = generic_utils.to_list(input_tensors)
     input_tensors_ = []
-    for i, x in enumerate(input_tensors):
-      if not K.is_keras_tensor(x):
-        name = model._input_layers[i].name
-        input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
+    for i in range(len(input_tensors)):
+      input_tensor = input_tensors[i]
+      if not K.is_keras_tensor(input_tensor):
+        original_input_layer = model._input_layers[i]
+        name = original_input_layer.name
+        input_tensor = Input(tensor=input_tensor,
+                             name='input_wrapper_for_' + name)
         input_tensors_.append(input_tensor)
         # Cache newly created input layer.
-        original_input_layer = x._keras_history[0]
         newly_created_input_layer = input_tensor._keras_history[0]
         layer_map[original_input_layer] = newly_created_input_layer
       else:
-        input_tensors_.append(x)
+        input_tensors_.append(input_tensor)
     input_tensors = input_tensors_
     for x, y in zip(model.inputs, input_tensors):

View File

@@ -26,10 +26,12 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import metrics
 from tensorflow.python.keras import models
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
@@ -219,6 +221,33 @@ class TestModelCloning(test.TestCase):
     with self.assertRaises(ValueError):
       keras.models._clone_sequential_model(seq_model, input_tensors=y)
def test_functional_cloning_does_not_create_unnecessary_placeholders(self):
  """Cloning a functional model onto concrete input tensors must not
  introduce extra placeholder ops in the target graph."""
  # Build the source functional model in its own graph.
  with ops.Graph().as_default():
    inp = keras.Input((4,))
    out = keras.layers.Dense(4)(inp)
    model = keras.models.Model(inp, out)
  # Clone into a fresh graph, supplying a concrete tensor as the input;
  # the clone should reuse it instead of creating a placeholder.
  target_graph = ops.Graph()
  with target_graph.as_default():
    tensor = array_ops.ones((10, 4))
    _ = keras.models.clone_model(model, input_tensors=[tensor])
  self.assertFalse(_has_placeholder(target_graph))
def test_sequential_cloning_does_not_create_unnecessary_placeholders(self):
  """Cloning a Sequential model onto concrete input tensors must not
  introduce extra placeholder ops in the target graph."""
  # Build the source Sequential model in its own graph.
  with ops.Graph().as_default():
    model = keras.models.Sequential([keras.layers.Dense(4)])
  # Clone into a fresh graph, supplying a concrete tensor as the input;
  # the clone should reuse it instead of creating a placeholder.
  target_graph = ops.Graph()
  with target_graph.as_default():
    tensor = array_ops.ones((10, 4))
    _ = keras.models.clone_model(model, input_tensors=[tensor])
  self.assertFalse(_has_placeholder(target_graph))
def _has_placeholder(graph):
ops_types = [op.type for op in graph.get_operations()]
return any('Placeholder' in s for s in ops_types)
 class CheckpointingTests(test.TestCase):