diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py index ba1534c6152..eaf8918308e 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -242,8 +242,9 @@ def compare_models_v2(tflite_model, concrete_func, input_data=None, # Gets the TensorFlow results as a map from the output names to outputs. # Converts the map into a list that is equivalent to the TFLite list. - tf_results_map = concrete_func(input_data_func) - tf_results = [tf_results_map[tf_results_map.keys()[0]]] + tf_results = concrete_func(input_data_func) + if isinstance(tf_results, dict): + tf_results = [tf_results[tf_results.keys()[0]]] tflite_results = _evaluate_tflite_model(tflite_model, input_data) for tf_result, tflite_result in zip(tf_results, tflite_results): np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py index 6352ef16c4e..43b713d8ca9 100644 --- a/tensorflow/python/framework/convert_to_constants.py +++ b/tensorflow/python/framework/convert_to_constants.py @@ -109,14 +109,31 @@ def convert_variables_to_constants_v2(func): input_tensors = func.inputs[-len(func.captured_inputs):] for var in func.graph.variables: index = func.captured_inputs.index(var.handle) - tensor = input_tensors[index] - node_name = get_name(tensor.name) - tensor_data[node_name] = var.numpy() - map_name_to_handle[node_name] = var.handle + tensor_name = get_name(input_tensors[index].name) + tensor_data[tensor_name] = var.numpy() + map_name_to_handle[tensor_name] = var.handle + + # Get mapping from input name to value for non-variable placeholders. + map_name_to_value = {} + for name_tensor, value_tensor in zip(input_tensors, func.captured_inputs): + tensor_name = get_name(name_tensor.name) + if tensor_name not in map_name_to_handle: + map_name_to_value[tensor_name] = value_tensor resource_identities = {} - resource_placeholders = {} + placeholders = {} + converted_input_indices = set() for node in graph_def.node: + if node.name in map_name_to_value: + # Get the dtype and data for the Placeholders whose values are stored as + # Tensors. This is the case for values that were originally Const ops. + tensor = map_name_to_value[node.name] + placeholders[node.name] = { + "dtype": node.attr["dtype"], + "data": tensor.numpy(), + } + converted_input_indices.add( + func.captured_inputs.index(map_name_to_value[node.name])) if node.op == "ReadVariableOp": # Get name of Placeholder op associated with ReadVariableOp. There can be # an Identity in between the ReadVariableOp and Placeholder. Store the @@ -130,22 +147,23 @@ def convert_variables_to_constants_v2(func): "to the ReadVariableOp.") # Build a map of Placeholder ops that are inputs to ReadVariableOps to the # variable's dtype and data. - resource_placeholders[input_name] = { + placeholders[input_name] = { "dtype": node.attr["dtype"], "data": tensor_data[input_name], } + converted_input_indices.add( + func.captured_inputs.index(map_name_to_handle[input_name])) # Reconstruct the graph with constants in place of variables. output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 - converted_input_indices = set([]) for input_node in graph_def.node: output_node = output_graph_def.node.add() - # Convert Placeholder ops that are inputs to ReadVariableOps into Const ops. - if input_node.name in resource_placeholders: - dtype = resource_placeholders[input_node.name]["dtype"] - data = resource_placeholders[input_node.name]["data"] + # Convert Placeholder ops to Const ops. + if input_node.name in placeholders: + dtype = placeholders[input_node.name]["dtype"] + data = placeholders[input_node.name]["data"] output_node.op = "Const" output_node.name = input_node.name @@ -154,8 +172,6 @@ def convert_variables_to_constants_v2(func): tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data.shape)) how_many_converted += 1 - converted_input_indices.add( - func.captured_inputs.index(map_name_to_handle[input_node.name])) # Change the dtype for Identity ops that are inputs to ReadVariableOps. elif input_node.name in resource_identities: output_node.CopyFrom(input_node)