Fix issue using python flatbuffers library.
PiperOrigin-RevId: 331663655 Change-Id: I6b4f617177717faf6d0c216f3ca6bb5aba46ddbe
This commit is contained in:
parent
c11debf86c
commit
7d94d03f7a
@ -217,6 +217,7 @@ py_library(
|
||||
"//tensorflow/python:tf_optimizer",
|
||||
"//tensorflow/python/eager:wrap_function",
|
||||
"@absl_py//absl/logging",
|
||||
"@flatbuffers//:runtime_py",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
@ -27,6 +27,7 @@ from absl import logging
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
import flatbuffers
|
||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||
from tensorflow.core.protobuf import graph_debug_info_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
@ -577,7 +578,7 @@ def _convert_model_from_bytearray_to_object(model_bytearray):
|
||||
def _convert_model_from_object_to_bytearray(model_object):
|
||||
"""Converts a tflite model from a parsable object into a bytearray."""
|
||||
# Initial size of the buffer, which will grow automatically if needed
|
||||
builder = schema_fb.flatbuffers.Builder(1024)
|
||||
builder = flatbuffers.Builder(1024)
|
||||
model_offset = model_object.Pack(builder)
|
||||
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
||||
return bytes(builder.Output())
|
||||
|
@ -112,6 +112,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/lite/python:schema_py",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
@ -137,6 +138,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/lite/python:schema_py",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ import os
|
||||
import random
|
||||
import re
|
||||
|
||||
import flatbuffers
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
|
||||
_TFLITE_FILE_IDENTIFIER = b'TFL3'
|
||||
@ -83,7 +84,7 @@ def read_model_with_mutable_tensors(input_tflite_file):
|
||||
def convert_object_to_bytearray(model_object):
|
||||
"""Converts a tflite model from an object to a immutable bytearray."""
|
||||
# Initial size of the buffer, which will grow automatically if needed
|
||||
builder = schema_fb.flatbuffers.Builder(1024)
|
||||
builder = flatbuffers.Builder(1024)
|
||||
model_offset = model_object.Pack(builder)
|
||||
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
||||
model_bytearray = bytes(builder.Output())
|
||||
@ -156,7 +157,7 @@ def randomize_weights(model, random_seed=0):
|
||||
|
||||
|
||||
def xxd_output_to_bytes(input_cc_file):
|
||||
"""Converts xxd output C++ source file to bytes (immutable).
|
||||
"""Converts xxd output C++ source file to bytes (immutable)
|
||||
|
||||
Args:
|
||||
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||||
@ -195,7 +196,7 @@ def xxd_output_to_bytes(input_cc_file):
|
||||
|
||||
|
||||
def xxd_output_to_object(input_cc_file):
|
||||
"""Converts xxd output C++ source file to object.
|
||||
"""Converts xxd output C++ source file to object
|
||||
|
||||
Args:
|
||||
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||||
|
@ -21,6 +21,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import flatbuffers
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
|
||||
TFLITE_SCHEMA_VERSION = 3
|
||||
@ -28,7 +29,7 @@ TFLITE_SCHEMA_VERSION = 3
|
||||
|
||||
def build_mock_flatbuffer_model():
|
||||
"""Creates a flatbuffer containing an example model."""
|
||||
builder = schema_fb.flatbuffers.Builder(1024)
|
||||
builder = flatbuffers.Builder(1024)
|
||||
|
||||
schema_fb.BufferStart(builder)
|
||||
buffer0_offset = schema_fb.BufferEnd(builder)
|
||||
|
4
third_party/flatbuffers/build_defs.bzl
vendored
4
third_party/flatbuffers/build_defs.bzl
vendored
@ -370,8 +370,8 @@ def _concat_flatbuffer_py_srcs_impl(ctx):
|
||||
outputs = [ctx.outputs.out],
|
||||
command = (
|
||||
"find '%s' -name '*.py' -exec cat {} + |" +
|
||||
"sed 's/from flatbuffers./from flatbuffers.python.flatbuffers./g' |" +
|
||||
"sed 's/import flatbuffers/from flatbuffers.python import flatbuffers/g' > %s"
|
||||
"sed 's/from flatbuffers.compat import import_numpy/import numpy as np' |" +
|
||||
"sed '/np = import_numpy()/d' > %s"
|
||||
) % (
|
||||
ctx.attr.deps[0].files.to_list()[0].path,
|
||||
ctx.outputs.out.path,
|
||||
|
Loading…
Reference in New Issue
Block a user