Add support for freezing the Exit and Enter op when it is used with resource variables.
PiperOrigin-RevId: 260570736
This commit is contained in:
parent
01fea718fe
commit
7cb9426454
@ -45,6 +45,13 @@ _VARIABLE_OPS = {
|
|||||||
"VariableV2",
|
"VariableV2",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
|
||||||
|
"Switch",
|
||||||
|
"Enter",
|
||||||
|
"Exit",
|
||||||
|
"Identity",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _is_variable_op(op):
|
def _is_variable_op(op):
|
||||||
"""Returns true if 'op' refers to a Variable node."""
|
"""Returns true if 'op' refers to a Variable node."""
|
||||||
@ -290,11 +297,12 @@ def convert_variables_to_constants(sess,
|
|||||||
else:
|
else:
|
||||||
variable_names.append(variable_name + ":0")
|
variable_names.append(variable_name + ":0")
|
||||||
elif node.op in ["ReadVariableOp", "ResourceGather"]:
|
elif node.op in ["ReadVariableOp", "ResourceGather"]:
|
||||||
# There can be one or more Identity or Switch ops in between the
|
# There can be one or more Identity or control flow ops in between the
|
||||||
# ReadVariableOp and VarHandleOp. Store the ops with the associated
|
# ReadVariableOp and VarHandleOp. Store the ops with the associated
|
||||||
# dtypes.
|
# dtypes.
|
||||||
source_op_name = get_input_name(node)
|
source_op_name = get_input_name(node)
|
||||||
while map_name_to_node[source_op_name].op in ["Identity", "Switch"]:
|
while (map_name_to_node[source_op_name].op in
|
||||||
|
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
|
||||||
resource_op_types[source_op_name] = node.attr["dtype"]
|
resource_op_types[source_op_name] = node.attr["dtype"]
|
||||||
source_op_name = get_input_name(map_name_to_node[source_op_name])
|
source_op_name = get_input_name(map_name_to_node[source_op_name])
|
||||||
if map_name_to_node[source_op_name].op != "VarHandleOp":
|
if map_name_to_node[source_op_name].op != "VarHandleOp":
|
||||||
|
@ -327,6 +327,10 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
|
sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _get_tensor_names(self, tensors):
|
||||||
|
"""Returns a list of string names for the tensors specified."""
|
||||||
|
return [tensor.name.split(":")[0] for tensor in tensors]
|
||||||
|
|
||||||
def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
|
def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
|
||||||
"""Evaluates the GraphDef using Sessions."""
|
"""Evaluates the GraphDef using Sessions."""
|
||||||
with ops.Graph().as_default() as graph:
|
with ops.Graph().as_default() as graph:
|
||||||
@ -338,6 +342,19 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
return sess.run(
|
return sess.run(
|
||||||
output_tensors, feed_dict=dict(zip(input_tensors, input_data)))
|
output_tensors, feed_dict=dict(zip(input_tensors, input_data)))
|
||||||
|
|
||||||
|
def _ensure_no_variables_in_graph(self, graph_def):
|
||||||
|
"""Ensures there are no variables in the graph."""
|
||||||
|
for node in graph_def.node:
|
||||||
|
self.assertNotIn(
|
||||||
|
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||||
|
|
||||||
|
def _test_converted_keras_model(self, model, constant_graph_def, input_data):
|
||||||
|
"""Compares the converted Keras model."""
|
||||||
|
expected_value = model.predict(input_data)
|
||||||
|
actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
|
||||||
|
model.outputs, [input_data])
|
||||||
|
np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
|
||||||
|
|
||||||
def _test_variable_to_const_conversion(self, use_resource):
|
def _test_variable_to_const_conversion(self, use_resource):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
with variable_scope.variable_scope("", use_resource=use_resource):
|
with variable_scope.variable_scope("", use_resource=use_resource):
|
||||||
@ -395,10 +412,7 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
_ = importer.import_graph_def(constant_graph_def, name="")
|
_ = importer.import_graph_def(constant_graph_def, name="")
|
||||||
self.assertEqual(4, len(constant_graph_def.node))
|
self.assertEqual(4, len(constant_graph_def.node))
|
||||||
for node in constant_graph_def.node:
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
self.assertNotIn(
|
|
||||||
node.op,
|
|
||||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||||
output = self.evaluate(output_node)
|
output = self.evaluate(output_node)
|
||||||
@ -440,10 +454,7 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||||
sess, variable_graph_def, ["output_node"])
|
sess, variable_graph_def, ["output_node"])
|
||||||
|
|
||||||
# Ensure there are no variables after freezing.
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
for node in constant_graph_def.node:
|
|
||||||
self.assertNotIn(
|
|
||||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
|
||||||
|
|
||||||
def testReferenceVariables(self):
|
def testReferenceVariables(self):
|
||||||
"""Freezes a graph with reference variables."""
|
"""Freezes a graph with reference variables."""
|
||||||
@ -464,9 +475,6 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
@test_util.run_v1_only("Incompatible with TF 2.0")
|
@test_util.run_v1_only("Incompatible with TF 2.0")
|
||||||
def testWithEmbeddings(self):
|
def testWithEmbeddings(self):
|
||||||
"""Freezes a graph with embeddings."""
|
"""Freezes a graph with embeddings."""
|
||||||
input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
|
|
||||||
|
|
||||||
# Make model.
|
|
||||||
state_input = keras.layers.Input(
|
state_input = keras.layers.Input(
|
||||||
shape=(1,), name="state_input", dtype="int32")
|
shape=(1,), name="state_input", dtype="int32")
|
||||||
output = keras.layers.Embedding(
|
output = keras.layers.Embedding(
|
||||||
@ -476,25 +484,19 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
model.compile(
|
model.compile(
|
||||||
loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")
|
loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")
|
||||||
|
|
||||||
# Get associated session.
|
# Freeze the graph.
|
||||||
sess = keras.backend.get_session()
|
sess = keras.backend.get_session()
|
||||||
variable_graph_def = sess.graph_def
|
variable_graph_def = sess.graph_def
|
||||||
output_tensor = [tensor.name.split(":")[0] for tensor in model.outputs]
|
output_tensor = self._get_tensor_names(model.outputs)
|
||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||||
sess, variable_graph_def, output_tensor)
|
sess, variable_graph_def, output_tensor)
|
||||||
|
|
||||||
# Ensure graph has no variables.
|
# Validate converted graph.
|
||||||
for node in constant_graph_def.node:
|
input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
|
||||||
self.assertNotIn(
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
self._test_converted_keras_model(model, constant_graph_def, input_data)
|
||||||
|
|
||||||
# Compare the value of the graphs.
|
def testGraphWithSwitch(self):
|
||||||
expected_value = model.predict(input_data)
|
|
||||||
actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
|
|
||||||
model.outputs, [input_data])
|
|
||||||
np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
|
|
||||||
|
|
||||||
def testWithSwitch(self):
|
|
||||||
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
|
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
with variable_scope.variable_scope("", use_resource=True):
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
@ -513,10 +515,25 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
|||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||||
sess, variable_graph_def, ["output_node"])
|
sess, variable_graph_def, ["output_node"])
|
||||||
|
|
||||||
# Ensure there are no variables after freezing.
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
for node in constant_graph_def.node:
|
|
||||||
self.assertNotIn(
|
@test_util.run_v1_only("Incompatible with TF 2.0")
|
||||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
def testLSTM(self):
|
||||||
|
"""Freezes a Keras LSTM."""
|
||||||
|
model = keras.models.Sequential(
|
||||||
|
[keras.layers.LSTM(units=10, input_shape=(10, 10))])
|
||||||
|
|
||||||
|
# Freeze the model.
|
||||||
|
sess = keras.backend.get_session()
|
||||||
|
variable_graph_def = sess.graph_def
|
||||||
|
output_tensor = self._get_tensor_names(model.outputs)
|
||||||
|
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||||
|
sess, variable_graph_def, output_tensor)
|
||||||
|
|
||||||
|
# Validate converted graph.
|
||||||
|
input_data = np.array(np.random.random_sample([10, 10, 10]), dtype=np.int32)
|
||||||
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
|
self._test_converted_keras_model(model, constant_graph_def, input_data)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user