Add support for freezing the Exit and Enter op when it is used with resource variables.

PiperOrigin-RevId: 260570736
This commit is contained in:
Nupur Garg 2019-07-29 13:42:22 -07:00 committed by TensorFlower Gardener
parent 01fea718fe
commit 7cb9426454
2 changed files with 55 additions and 30 deletions

View File

@ -45,6 +45,13 @@ _VARIABLE_OPS = {
"VariableV2",
}
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
"Switch",
"Enter",
"Exit",
"Identity",
]
def _is_variable_op(op):
"""Returns true if 'op' refers to a Variable node."""
@ -290,11 +297,12 @@ def convert_variables_to_constants(sess,
else:
variable_names.append(variable_name + ":0")
elif node.op in ["ReadVariableOp", "ResourceGather"]:
# There can be one or more Identity or Switch ops in between the
# There can be one or more Identity or control flow ops in between the
# ReadVariableOp and VarHandleOp. Store the ops with the associated
# dtypes.
source_op_name = get_input_name(node)
while map_name_to_node[source_op_name].op in ["Identity", "Switch"]:
while (map_name_to_node[source_op_name].op in
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
resource_op_types[source_op_name] = node.attr["dtype"]
source_op_name = get_input_name(map_name_to_node[source_op_name])
if map_name_to_node[source_op_name].op != "VarHandleOp":

View File

@ -327,6 +327,10 @@ class ConvertVariablesToConstantsTest(test.TestCase):
sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
]
def _get_tensor_names(self, tensors):
"""Returns a list of string names for the tensors specified."""
return [tensor.name.split(":")[0] for tensor in tensors]
def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
"""Evaluates the GraphDef using Sessions."""
with ops.Graph().as_default() as graph:
@ -338,6 +342,19 @@ class ConvertVariablesToConstantsTest(test.TestCase):
return sess.run(
output_tensors, feed_dict=dict(zip(input_tensors, input_data)))
def _ensure_no_variables_in_graph(self, graph_def):
"""Ensures there are no variables in the graph."""
for node in graph_def.node:
self.assertNotIn(
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
def _test_converted_keras_model(self, model, constant_graph_def, input_data):
"""Compares the converted Keras model."""
expected_value = model.predict(input_data)
actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
model.outputs, [input_data])
np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
def _test_variable_to_const_conversion(self, use_resource):
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=use_resource):
@ -395,10 +412,7 @@ class ConvertVariablesToConstantsTest(test.TestCase):
with ops.Graph().as_default():
_ = importer.import_graph_def(constant_graph_def, name="")
self.assertEqual(4, len(constant_graph_def.node))
for node in constant_graph_def.node:
self.assertNotIn(
node.op,
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
self._ensure_no_variables_in_graph(constant_graph_def)
with session.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
output = self.evaluate(output_node)
@ -440,10 +454,7 @@ class ConvertVariablesToConstantsTest(test.TestCase):
constant_graph_def = graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"])
# Ensure there are no variables after freezing.
for node in constant_graph_def.node:
self.assertNotIn(
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
self._ensure_no_variables_in_graph(constant_graph_def)
def testReferenceVariables(self):
"""Freezes a graph with reference variables."""
@ -464,9 +475,6 @@ class ConvertVariablesToConstantsTest(test.TestCase):
@test_util.run_v1_only("Incompatible with TF 2.0")
def testWithEmbeddings(self):
"""Freezes a graph with embeddings."""
input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
# Make model.
state_input = keras.layers.Input(
shape=(1,), name="state_input", dtype="int32")
output = keras.layers.Embedding(
@ -476,25 +484,19 @@ class ConvertVariablesToConstantsTest(test.TestCase):
model.compile(
loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")
# Get associated session.
# Freeze the graph.
sess = keras.backend.get_session()
variable_graph_def = sess.graph_def
output_tensor = [tensor.name.split(":")[0] for tensor in model.outputs]
output_tensor = self._get_tensor_names(model.outputs)
constant_graph_def = graph_util.convert_variables_to_constants(
sess, variable_graph_def, output_tensor)
# Ensure graph has no variables.
for node in constant_graph_def.node:
self.assertNotIn(
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
# Validate converted graph.
input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
self._ensure_no_variables_in_graph(constant_graph_def)
self._test_converted_keras_model(model, constant_graph_def, input_data)
# Compare the value of the graphs.
expected_value = model.predict(input_data)
actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
model.outputs, [input_data])
np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
def testWithSwitch(self):
def testGraphWithSwitch(self):
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
@ -513,10 +515,25 @@ class ConvertVariablesToConstantsTest(test.TestCase):
constant_graph_def = graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"])
# Ensure there are no variables after freezing.
for node in constant_graph_def.node:
self.assertNotIn(
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
self._ensure_no_variables_in_graph(constant_graph_def)
@test_util.run_v1_only("Incompatible with TF 2.0")
def testLSTM(self):
"""Freezes a Keras LSTM."""
model = keras.models.Sequential(
[keras.layers.LSTM(units=10, input_shape=(10, 10))])
# Freeze the model.
sess = keras.backend.get_session()
variable_graph_def = sess.graph_def
output_tensor = self._get_tensor_names(model.outputs)
constant_graph_def = graph_util.convert_variables_to_constants(
sess, variable_graph_def, output_tensor)
# Validate converted graph.
input_data = np.array(np.random.random_sample([10, 10, 10]), dtype=np.int32)
self._ensure_no_variables_in_graph(constant_graph_def)
self._test_converted_keras_model(model, constant_graph_def, input_data)
if __name__ == "__main__":