Fix 1.X Keras HDF5 model conversion with 2.0 converter.
PiperOrigin-RevId: 244374477
This commit is contained in:
parent
d9cdb3f60d
commit
c767cbf14c
@ -242,8 +242,9 @@ def compare_models_v2(tflite_model, concrete_func, input_data=None,
|
|||||||
|
|
||||||
# Gets the TensorFlow results as a map from the output names to outputs.
|
# Gets the TensorFlow results as a map from the output names to outputs.
|
||||||
# Converts the map into a list that is equivalent to the TFLite list.
|
# Converts the map into a list that is equivalent to the TFLite list.
|
||||||
tf_results_map = concrete_func(input_data_func)
|
tf_results = concrete_func(input_data_func)
|
||||||
tf_results = [tf_results_map[tf_results_map.keys()[0]]]
|
if isinstance(tf_results, dict):
|
||||||
|
tf_results = [tf_results[tf_results.keys()[0]]]
|
||||||
tflite_results = _evaluate_tflite_model(tflite_model, input_data)
|
tflite_results = _evaluate_tflite_model(tflite_model, input_data)
|
||||||
for tf_result, tflite_result in zip(tf_results, tflite_results):
|
for tf_result, tflite_result in zip(tf_results, tflite_results):
|
||||||
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
|
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
|
||||||
|
@ -109,14 +109,31 @@ def convert_variables_to_constants_v2(func):
|
|||||||
input_tensors = func.inputs[-len(func.captured_inputs):]
|
input_tensors = func.inputs[-len(func.captured_inputs):]
|
||||||
for var in func.graph.variables:
|
for var in func.graph.variables:
|
||||||
index = func.captured_inputs.index(var.handle)
|
index = func.captured_inputs.index(var.handle)
|
||||||
tensor = input_tensors[index]
|
tensor_name = get_name(input_tensors[index].name)
|
||||||
node_name = get_name(tensor.name)
|
tensor_data[tensor_name] = var.numpy()
|
||||||
tensor_data[node_name] = var.numpy()
|
map_name_to_handle[tensor_name] = var.handle
|
||||||
map_name_to_handle[node_name] = var.handle
|
|
||||||
|
# Get mapping from input name to value for non-variable placeholders.
|
||||||
|
map_name_to_value = {}
|
||||||
|
for name_tensor, value_tensor in zip(input_tensors, func.captured_inputs):
|
||||||
|
tensor_name = get_name(name_tensor.name)
|
||||||
|
if tensor_name not in map_name_to_handle:
|
||||||
|
map_name_to_value[tensor_name] = value_tensor
|
||||||
|
|
||||||
resource_identities = {}
|
resource_identities = {}
|
||||||
resource_placeholders = {}
|
placeholders = {}
|
||||||
|
converted_input_indices = set()
|
||||||
for node in graph_def.node:
|
for node in graph_def.node:
|
||||||
|
if node.name in map_name_to_value:
|
||||||
|
# Get the dtype and data for the Placeholders whose values are stored as
|
||||||
|
# Tensors. This is the case for values that were originally Const ops.
|
||||||
|
tensor = map_name_to_value[node.name]
|
||||||
|
placeholders[node.name] = {
|
||||||
|
"dtype": node.attr["dtype"],
|
||||||
|
"data": tensor.numpy(),
|
||||||
|
}
|
||||||
|
converted_input_indices.add(
|
||||||
|
func.captured_inputs.index(map_name_to_value[node.name]))
|
||||||
if node.op == "ReadVariableOp":
|
if node.op == "ReadVariableOp":
|
||||||
# Get name of Placeholder op associated with ReadVariableOp. There can be
|
# Get name of Placeholder op associated with ReadVariableOp. There can be
|
||||||
# an Identity in between the ReadVariableOp and Placeholder. Store the
|
# an Identity in between the ReadVariableOp and Placeholder. Store the
|
||||||
@ -130,22 +147,23 @@ def convert_variables_to_constants_v2(func):
|
|||||||
"to the ReadVariableOp.")
|
"to the ReadVariableOp.")
|
||||||
# Build a map of Placeholder ops that are inputs to ReadVariableOps to the
|
# Build a map of Placeholder ops that are inputs to ReadVariableOps to the
|
||||||
# variable's dtype and data.
|
# variable's dtype and data.
|
||||||
resource_placeholders[input_name] = {
|
placeholders[input_name] = {
|
||||||
"dtype": node.attr["dtype"],
|
"dtype": node.attr["dtype"],
|
||||||
"data": tensor_data[input_name],
|
"data": tensor_data[input_name],
|
||||||
}
|
}
|
||||||
|
converted_input_indices.add(
|
||||||
|
func.captured_inputs.index(map_name_to_handle[input_name]))
|
||||||
|
|
||||||
# Reconstruct the graph with constants in place of variables.
|
# Reconstruct the graph with constants in place of variables.
|
||||||
output_graph_def = graph_pb2.GraphDef()
|
output_graph_def = graph_pb2.GraphDef()
|
||||||
how_many_converted = 0
|
how_many_converted = 0
|
||||||
|
|
||||||
converted_input_indices = set([])
|
|
||||||
for input_node in graph_def.node:
|
for input_node in graph_def.node:
|
||||||
output_node = output_graph_def.node.add()
|
output_node = output_graph_def.node.add()
|
||||||
# Convert Placeholder ops that are inputs to ReadVariableOps into Const ops.
|
# Convert Placeholder ops to Const ops.
|
||||||
if input_node.name in resource_placeholders:
|
if input_node.name in placeholders:
|
||||||
dtype = resource_placeholders[input_node.name]["dtype"]
|
dtype = placeholders[input_node.name]["dtype"]
|
||||||
data = resource_placeholders[input_node.name]["data"]
|
data = placeholders[input_node.name]["data"]
|
||||||
|
|
||||||
output_node.op = "Const"
|
output_node.op = "Const"
|
||||||
output_node.name = input_node.name
|
output_node.name = input_node.name
|
||||||
@ -154,8 +172,6 @@ def convert_variables_to_constants_v2(func):
|
|||||||
tensor_util.make_tensor_proto(
|
tensor_util.make_tensor_proto(
|
||||||
data, dtype=dtype.type, shape=data.shape))
|
data, dtype=dtype.type, shape=data.shape))
|
||||||
how_many_converted += 1
|
how_many_converted += 1
|
||||||
converted_input_indices.add(
|
|
||||||
func.captured_inputs.index(map_name_to_handle[input_node.name]))
|
|
||||||
# Change the dtype for Identity ops that are inputs to ReadVariableOps.
|
# Change the dtype for Identity ops that are inputs to ReadVariableOps.
|
||||||
elif input_node.name in resource_identities:
|
elif input_node.name in resource_identities:
|
||||||
output_node.CopyFrom(input_node)
|
output_node.CopyFrom(input_node)
|
||||||
|
Loading…
Reference in New Issue
Block a user