Fix bug that resulted in unnecessary placeholders being created during functional model cloning when input tensors were specified.

PiperOrigin-RevId: 222424883
Francois Chollet 2018-11-21 10:07:11 -08:00 committed by TensorFlower Gardener
parent c59719cf1f
commit 42e2bb488a
3 changed files with 50 additions and 21 deletions
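
For context, a minimal sketch (not part of the commit) of the scenario the fix addresses: cloning a functional model onto a pre-existing, non-Keras graph tensor previously left stray Placeholder ops in the target graph. This mirrors the new tests below using the public tf.keras API in TF1-style graph mode; names and shapes are illustrative.

    import tensorflow as tf

    graph_a = tf.Graph()
    with graph_a.as_default():
      x = tf.keras.Input((4,))
      model = tf.keras.Model(x, tf.keras.layers.Dense(4)(x))

    graph_b = tf.Graph()
    with graph_b.as_default():
      t = tf.ones((10, 4))  # a plain graph tensor, not a Keras tensor
      clone = tf.keras.models.clone_model(model, input_tensors=[t])
      # Before this fix, cloning re-instantiated the model's InputLayer and
      # left an unused Placeholder op in graph_b; after it, the clone feeds
      # directly from `t`.
      stray = [op for op in graph_b.get_operations() if 'Placeholder' in op.type]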

tensorflow/python/keras/engine/input_layer.py

@@ -19,12 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.utils import tf_utils
-from tensorflow.python.ops import array_ops
 from tensorflow.python.util.tf_export import tf_export
@@ -94,19 +92,19 @@ class InputLayer(base_layer.Layer):
       else:
         batch_input_shape = None
       graph = backend.get_graph()
-      with context.graph_mode():
-        with graph.as_default():
-          # In graph mode, create a graph placeholder to call the layer on.
-          if sparse:
-            input_tensor = array_ops.sparse_placeholder(
-                shape=batch_input_shape,
-                dtype=dtype,
-                name=self.name)
-          else:
-            input_tensor = array_ops.placeholder(
-                shape=batch_input_shape,
-                dtype=dtype,
-                name=self.name)
+      with graph.as_default():
+        # In graph mode, create a graph placeholder to call the layer on.
+        if sparse:
+          input_tensor = backend.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name,
+              sparse=True)
+        else:
+          input_tensor = backend.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name)
       self.is_placeholder = True
       self._batch_input_shape = batch_input_shape

tensorflow/python/keras/models.py

@@ -100,17 +100,19 @@ def _clone_functional_model(model, input_tensors=None):
-    input_tensors = list(input_tensors)
+    input_tensors = generic_utils.to_list(input_tensors)
     input_tensors_ = []
-    for i, x in enumerate(input_tensors):
-      if not K.is_keras_tensor(x):
-        name = model._input_layers[i].name
-        input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
+    for i in range(len(input_tensors)):
+      input_tensor = input_tensors[i]
+      if not K.is_keras_tensor(input_tensor):
+        original_input_layer = model._input_layers[i]
+        name = original_input_layer.name
+        input_tensor = Input(tensor=input_tensor,
+                             name='input_wrapper_for_' + name)
         input_tensors_.append(input_tensor)
         # Cache newly created input layer.
-        original_input_layer = x._keras_history[0]
         newly_created_input_layer = input_tensor._keras_history[0]
         layer_map[original_input_layer] = newly_created_input_layer
       else:
-        input_tensors_.append(x)
+        input_tensors_.append(input_tensor)
     input_tensors = input_tensors_
 
   for x, y in zip(model.inputs, input_tensors):
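
The substantive fix is the layer_map key: the old code read x._keras_history[0] after wrapping, but Input(tensor=x, ...) attaches a fresh _keras_history pointing at the newly created wrapper layer, so the model's own input layer was never mapped and cloning re-instantiated it, creating the unnecessary placeholder. Keying on model._input_layers[i] fixes that. A standalone restatement of the corrected loop (wrap_input_tensors is a hypothetical helper, not part of the commit):

    from tensorflow.python.keras import backend as K
    from tensorflow.python.keras.engine.input_layer import Input

    def wrap_input_tensors(model, input_tensors, layer_map):
      # Hypothetical helper restating the corrected hunk above.
      wrapped = []
      for i, t in enumerate(input_tensors):
        if not K.is_keras_tensor(t):
          original_input_layer = model._input_layers[i]  # the *original* layer
          t = Input(tensor=t,
                    name='input_wrapper_for_' + original_input_layer.name)
          # Map original layer -> wrapper layer so the clone reuses `t`
          # instead of re-instantiating the InputLayer (which would add a
          # fresh placeholder to the graph).
          layer_map[original_input_layer] = t._keras_history[0]
        wrapped.append(t)
      return wrapped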

tensorflow/python/keras/models_test.py

@@ -26,10 +26,12 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import metrics
 from tensorflow.python.keras import models
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
@@ -219,6 +221,33 @@ class TestModelCloning(test.TestCase):
     with self.assertRaises(ValueError):
       keras.models._clone_sequential_model(seq_model, input_tensors=y)
 
+  def test_functional_cloning_does_not_create_unnecessary_placeholders(self):
+    with ops.Graph().as_default():
+      x = keras.Input((4,))
+      y = keras.layers.Dense(4)(x)
+      model = keras.models.Model(x, y)
+    graph = ops.Graph()
+    with graph.as_default():
+      x = array_ops.ones((10, 4))
+      _ = keras.models.clone_model(model, input_tensors=[x])
+      has_placeholder = _has_placeholder(graph)
+      self.assertFalse(has_placeholder)
+
+  def test_sequential_cloning_does_not_create_unnecessary_placeholders(self):
+    with ops.Graph().as_default():
+      model = keras.models.Sequential([keras.layers.Dense(4)])
+    graph = ops.Graph()
+    with graph.as_default():
+      x = array_ops.ones((10, 4))
+      _ = keras.models.clone_model(model, input_tensors=[x])
+      has_placeholder = _has_placeholder(graph)
+      self.assertFalse(has_placeholder)
+
+
+def _has_placeholder(graph):
+  ops_types = [op.type for op in graph.get_operations()]
+  return any('Placeholder' in s for s in ops_types)
+
 
 class CheckpointingTests(test.TestCase):
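
The new tests build the model in one graph, clone into a second, and assert the target graph gained no Placeholder ops. The _has_placeholder helper simply scans op types; a quick illustration of its behavior on toy graphs (not part of the tests), using the same modules the test file imports:

    g = ops.Graph()
    with g.as_default():
      array_ops.placeholder(dtypes.float32, shape=(None, 4))
    assert _has_placeholder(g)                # contains a Placeholder op
    assert not _has_placeholder(ops.Graph())  # a fresh graph does not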