Add support for freezing the Exit and Enter op when it is used with resource variables.
PiperOrigin-RevId: 260570736
This commit is contained in:
parent
01fea718fe
commit
7cb9426454
@ -45,6 +45,13 @@ _VARIABLE_OPS = {
|
||||
"VariableV2",
|
||||
}
|
||||
|
||||
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
|
||||
"Switch",
|
||||
"Enter",
|
||||
"Exit",
|
||||
"Identity",
|
||||
]
|
||||
|
||||
|
||||
def _is_variable_op(op):
|
||||
"""Returns true if 'op' refers to a Variable node."""
|
||||
@ -290,11 +297,12 @@ def convert_variables_to_constants(sess,
|
||||
else:
|
||||
variable_names.append(variable_name + ":0")
|
||||
elif node.op in ["ReadVariableOp", "ResourceGather"]:
|
||||
# There can be one or more Identity or Switch ops in between the
|
||||
# There can be one or more Identity or control flow ops in between the
|
||||
# ReadVariableOp and VarHandleOp. Store the ops with the associated
|
||||
# dtypes.
|
||||
source_op_name = get_input_name(node)
|
||||
while map_name_to_node[source_op_name].op in ["Identity", "Switch"]:
|
||||
while (map_name_to_node[source_op_name].op in
|
||||
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
|
||||
resource_op_types[source_op_name] = node.attr["dtype"]
|
||||
source_op_name = get_input_name(map_name_to_node[source_op_name])
|
||||
if map_name_to_node[source_op_name].op != "VarHandleOp":
|
||||
|
@ -327,6 +327,10 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
|
||||
]
|
||||
|
||||
def _get_tensor_names(self, tensors):
|
||||
"""Returns a list of string names for the tensors specified."""
|
||||
return [tensor.name.split(":")[0] for tensor in tensors]
|
||||
|
||||
def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
|
||||
"""Evaluates the GraphDef using Sessions."""
|
||||
with ops.Graph().as_default() as graph:
|
||||
@ -338,6 +342,19 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
return sess.run(
|
||||
output_tensors, feed_dict=dict(zip(input_tensors, input_data)))
|
||||
|
||||
def _ensure_no_variables_in_graph(self, graph_def):
|
||||
"""Ensures there are no variables in the graph."""
|
||||
for node in graph_def.node:
|
||||
self.assertNotIn(
|
||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
|
||||
def _test_converted_keras_model(self, model, constant_graph_def, input_data):
|
||||
"""Compares the converted Keras model."""
|
||||
expected_value = model.predict(input_data)
|
||||
actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
|
||||
model.outputs, [input_data])
|
||||
np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
|
||||
|
||||
def _test_variable_to_const_conversion(self, use_resource):
|
||||
with ops.Graph().as_default():
|
||||
with variable_scope.variable_scope("", use_resource=use_resource):
|
||||
@ -395,10 +412,7 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
with ops.Graph().as_default():
|
||||
_ = importer.import_graph_def(constant_graph_def, name="")
|
||||
self.assertEqual(4, len(constant_graph_def.node))
|
||||
for node in constant_graph_def.node:
|
||||
self.assertNotIn(
|
||||
node.op,
|
||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||
with session.Session() as sess:
|
||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||
output = self.evaluate(output_node)
|
||||
@ -440,10 +454,7 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, ["output_node"])
|
||||
|
||||
# Ensure there are no variables after freezing.
|
||||
for node in constant_graph_def.node:
|
||||
self.assertNotIn(
|
||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||
|
||||
def testReferenceVariables(self):
|
||||
"""Freezes a graph with reference variables."""
|
||||
@ -464,9 +475,6 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
@test_util.run_v1_only("Incompatible with TF 2.0")
|
||||
def testWithEmbeddings(self):
|
||||
"""Freezes a graph with embeddings."""
|
||||
input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
|
||||
|
||||
# Make model.
|
||||
state_input = keras.layers.Input(
|
||||
shape=(1,), name="state_input", dtype="int32")
|
||||
output = keras.layers.Embedding(
|
||||
@ -476,25 +484,19 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
model.compile(
|
||||
loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")
|
||||
|
||||
# Get associated session.
|
||||
# Freeze the graph.
|
||||
sess = keras.backend.get_session()
|
||||
variable_graph_def = sess.graph_def
|
||||
output_tensor = [tensor.name.split(":")[0] for tensor in model.outputs]
|
||||
output_tensor = self._get_tensor_names(model.outputs)
|
||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, output_tensor)
|
||||
|
||||
# Ensure graph has no variables.
|
||||
for node in constant_graph_def.node:
|
||||
self.assertNotIn(
|
||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
# Validate converted graph.
|
||||
input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
|
||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||
self._test_converted_keras_model(model, constant_graph_def, input_data)
|
||||
|
||||
# Compare the value of the graphs.
|
||||
expected_value = model.predict(input_data)
|
||||
actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
|
||||
model.outputs, [input_data])
|
||||
np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
|
||||
|
||||
def testWithSwitch(self):
|
||||
def testGraphWithSwitch(self):
|
||||
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
|
||||
with ops.Graph().as_default():
|
||||
with variable_scope.variable_scope("", use_resource=True):
|
||||
@ -513,10 +515,25 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, ["output_node"])
|
||||
|
||||
# Ensure there are no variables after freezing.
|
||||
for node in constant_graph_def.node:
|
||||
self.assertNotIn(
|
||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||
|
||||
@test_util.run_v1_only("Incompatible with TF 2.0")
|
||||
def testLSTM(self):
|
||||
"""Freezes a Keras LSTM."""
|
||||
model = keras.models.Sequential(
|
||||
[keras.layers.LSTM(units=10, input_shape=(10, 10))])
|
||||
|
||||
# Freeze the model.
|
||||
sess = keras.backend.get_session()
|
||||
variable_graph_def = sess.graph_def
|
||||
output_tensor = self._get_tensor_names(model.outputs)
|
||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, output_tensor)
|
||||
|
||||
# Validate converted graph.
|
||||
input_data = np.array(np.random.random_sample([10, 10, 10]), dtype=np.int32)
|
||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||
self._test_converted_keras_model(model, constant_graph_def, input_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user