Fix 1.X Keras HDF5 model conversion with 2.0 converter.

PiperOrigin-RevId: 244374477
This commit is contained in:
Nupur Garg 2019-04-19 09:58:04 -07:00 committed by TensorFlower Gardener
parent d9cdb3f60d
commit c767cbf14c
2 changed files with 32 additions and 15 deletions

View File

@ -242,8 +242,9 @@ def compare_models_v2(tflite_model, concrete_func, input_data=None,
# Gets the TensorFlow results as a map from the output names to outputs.
# Converts the map into a list that is equivalent to the TFLite list.
tf_results_map = concrete_func(input_data_func)
tf_results = [tf_results_map[tf_results_map.keys()[0]]]
tf_results = concrete_func(input_data_func)
if isinstance(tf_results, dict):
tf_results = [tf_results[tf_results.keys()[0]]]
tflite_results = _evaluate_tflite_model(tflite_model, input_data)
for tf_result, tflite_result in zip(tf_results, tflite_results):
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)

View File

@ -109,14 +109,31 @@ def convert_variables_to_constants_v2(func):
input_tensors = func.inputs[-len(func.captured_inputs):]
for var in func.graph.variables:
index = func.captured_inputs.index(var.handle)
tensor = input_tensors[index]
node_name = get_name(tensor.name)
tensor_data[node_name] = var.numpy()
map_name_to_handle[node_name] = var.handle
tensor_name = get_name(input_tensors[index].name)
tensor_data[tensor_name] = var.numpy()
map_name_to_handle[tensor_name] = var.handle
# Get mapping from input name to value for non-variable placeholders.
map_name_to_value = {}
for name_tensor, value_tensor in zip(input_tensors, func.captured_inputs):
tensor_name = get_name(name_tensor.name)
if tensor_name not in map_name_to_handle:
map_name_to_value[tensor_name] = value_tensor
resource_identities = {}
resource_placeholders = {}
placeholders = {}
converted_input_indices = set()
for node in graph_def.node:
if node.name in map_name_to_value:
# Get the dtype and data for the Placeholders whose values are stored as
# Tensors. This is the case for values that were originally Const ops.
tensor = map_name_to_value[node.name]
placeholders[node.name] = {
"dtype": node.attr["dtype"],
"data": tensor.numpy(),
}
converted_input_indices.add(
func.captured_inputs.index(map_name_to_value[node.name]))
if node.op == "ReadVariableOp":
# Get name of Placeholder op associated with ReadVariableOp. There can be
# an Identity in between the ReadVariableOp and Placeholder. Store the
@ -130,22 +147,23 @@ def convert_variables_to_constants_v2(func):
"to the ReadVariableOp.")
# Build a map of Placeholder ops that are inputs to ReadVariableOps to the
# variable's dtype and data.
resource_placeholders[input_name] = {
placeholders[input_name] = {
"dtype": node.attr["dtype"],
"data": tensor_data[input_name],
}
converted_input_indices.add(
func.captured_inputs.index(map_name_to_handle[input_name]))
# Reconstruct the graph with constants in place of variables.
output_graph_def = graph_pb2.GraphDef()
how_many_converted = 0
converted_input_indices = set([])
for input_node in graph_def.node:
output_node = output_graph_def.node.add()
# Convert Placeholder ops that are inputs to ReadVariableOps into Const ops.
if input_node.name in resource_placeholders:
dtype = resource_placeholders[input_node.name]["dtype"]
data = resource_placeholders[input_node.name]["data"]
# Convert Placeholder ops to Const ops.
if input_node.name in placeholders:
dtype = placeholders[input_node.name]["dtype"]
data = placeholders[input_node.name]["data"]
output_node.op = "Const"
output_node.name = input_node.name
@ -154,8 +172,6 @@ def convert_variables_to_constants_v2(func):
tensor_util.make_tensor_proto(
data, dtype=dtype.type, shape=data.shape))
how_many_converted += 1
converted_input_indices.add(
func.captured_inputs.index(map_name_to_handle[input_node.name]))
# Change the dtype for Identity ops that are inputs to ReadVariableOps.
elif input_node.name in resource_identities:
output_node.CopyFrom(input_node)