diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 590b935d408..9874efe2bcc 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -19,12 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.utils import tf_utils
-from tensorflow.python.ops import array_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -94,19 +92,19 @@ class InputLayer(base_layer.Layer):
       else:
         batch_input_shape = None
       graph = backend.get_graph()
-      with context.graph_mode():
-        with graph.as_default():
-          # In graph mode, create a graph placeholder to call the layer on.
-          if sparse:
-            input_tensor = array_ops.sparse_placeholder(
-                shape=batch_input_shape,
-                dtype=dtype,
-                name=self.name)
-          else:
-            input_tensor = array_ops.placeholder(
-                shape=batch_input_shape,
-                dtype=dtype,
-                name=self.name)
+      with graph.as_default():
+        # In graph mode, create a graph placeholder to call the layer on.
+        if sparse:
+          input_tensor = backend.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name,
+              sparse=True)
+        else:
+          input_tensor = backend.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name)
 
       self.is_placeholder = True
       self._batch_input_shape = batch_input_shape
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 3a0c51b4970..4813b8061e3 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -100,17 +100,19 @@ def _clone_functional_model(model, input_tensors=None):
       input_tensors = list(input_tensors)
     input_tensors = generic_utils.to_list(input_tensors)
     input_tensors_ = []
-    for i, x in enumerate(input_tensors):
-      if not K.is_keras_tensor(x):
-        name = model._input_layers[i].name
-        input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
+    for i in range(len(input_tensors)):
+      input_tensor = input_tensors[i]
+      if not K.is_keras_tensor(input_tensor):
+        original_input_layer = model._input_layers[i]
+        name = original_input_layer.name
+        input_tensor = Input(tensor=input_tensor,
+                             name='input_wrapper_for_' + name)
         input_tensors_.append(input_tensor)
         # Cache newly created input layer.
-        original_input_layer = x._keras_history[0]
         newly_created_input_layer = input_tensor._keras_history[0]
         layer_map[original_input_layer] = newly_created_input_layer
       else:
-        input_tensors_.append(x)
+        input_tensors_.append(input_tensor)
     input_tensors = input_tensors_
 
   for x, y in zip(model.inputs, input_tensors):
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 4b6bb74ef96..23321a2d16b 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -26,10 +26,12 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import metrics
 from tensorflow.python.keras import models
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
@@ -219,6 +221,33 @@ class TestModelCloning(test.TestCase):
     with self.assertRaises(ValueError):
       keras.models._clone_sequential_model(seq_model, input_tensors=y)
 
+  def test_functional_cloning_does_not_create_unnecessary_placeholders(self):
+    with ops.Graph().as_default():
+      x = keras.Input((4,))
+      y = keras.layers.Dense(4)(x)
+      model = keras.models.Model(x, y)
+    graph = ops.Graph()
+    with graph.as_default():
+      x = array_ops.ones((10, 4))
+      _ = keras.models.clone_model(model, input_tensors=[x])
+      has_placeholder = _has_placeholder(graph)
+      self.assertFalse(has_placeholder)
+
+  def test_sequential_cloning_does_not_create_unnecessary_placeholders(self):
+    with ops.Graph().as_default():
+      model = keras.models.Sequential([keras.layers.Dense(4)])
+    graph = ops.Graph()
+    with graph.as_default():
+      x = array_ops.ones((10, 4))
+      _ = keras.models.clone_model(model, input_tensors=[x])
+      has_placeholder = _has_placeholder(graph)
+      self.assertFalse(has_placeholder)
+
+
+def _has_placeholder(graph):
+  ops_types = [op.type for op in graph.get_operations()]
+  return any('Placeholder' in s for s in ops_types)
+
 
 class CheckpointingTests(test.TestCase):
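For context, a minimal sketch (not part of the patch) of the behavior the new tests assert, written against the public tf.keras API and assuming TF 1.x graph mode; the names graph_a, graph_b, ones, and clone are illustrative only:

# Sketch only: after this change, cloning a model onto concrete input tensors
# should not add any Placeholder ops to the graph the clone is built in.
import tensorflow as tf  # assumes a TF 1.x, graph-mode installation

graph_a = tf.Graph()
with graph_a.as_default():
  x = tf.keras.Input((4,))
  y = tf.keras.layers.Dense(4)(x)
  model = tf.keras.Model(x, y)

graph_b = tf.Graph()
with graph_b.as_default():
  ones = tf.ones((10, 4))
  # The cloned model is called directly on `ones` rather than on a
  # freshly created placeholder.
  clone = tf.keras.models.clone_model(model, input_tensors=[ones])

placeholder_ops = [op for op in graph_b.get_operations()
                   if 'Placeholder' in op.type]
assert not placeholder_ops, placeholder_ops

Prior to the patch, the cloning path in _clone_functional_model (and the InputLayer constructor it calls) would create an extra placeholder in graph_b even though a concrete tensor was supplied, which is what the two new tests in models_test.py guard against.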