Fix issue using python flatbuffers library.
PiperOrigin-RevId: 331663655 Change-Id: I6b4f617177717faf6d0c216f3ca6bb5aba46ddbe
This commit is contained in:
parent
c11debf86c
commit
7d94d03f7a
@ -217,6 +217,7 @@ py_library(
|
|||||||
"//tensorflow/python:tf_optimizer",
|
"//tensorflow/python:tf_optimizer",
|
||||||
"//tensorflow/python/eager:wrap_function",
|
"//tensorflow/python/eager:wrap_function",
|
||||||
"@absl_py//absl/logging",
|
"@absl_py//absl/logging",
|
||||||
|
"@flatbuffers//:runtime_py",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from absl import logging
|
|||||||
import six
|
import six
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
|
import flatbuffers
|
||||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||||
from tensorflow.core.protobuf import graph_debug_info_pb2
|
from tensorflow.core.protobuf import graph_debug_info_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||||
@ -577,7 +578,7 @@ def _convert_model_from_bytearray_to_object(model_bytearray):
|
|||||||
def _convert_model_from_object_to_bytearray(model_object):
|
def _convert_model_from_object_to_bytearray(model_object):
|
||||||
"""Converts a tflite model from a parsable object into a bytearray."""
|
"""Converts a tflite model from a parsable object into a bytearray."""
|
||||||
# Initial size of the buffer, which will grow automatically if needed
|
# Initial size of the buffer, which will grow automatically if needed
|
||||||
builder = schema_fb.flatbuffers.Builder(1024)
|
builder = flatbuffers.Builder(1024)
|
||||||
model_offset = model_object.Pack(builder)
|
model_offset = model_object.Pack(builder)
|
||||||
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
||||||
return bytes(builder.Output())
|
return bytes(builder.Output())
|
||||||
|
|||||||
@ -112,6 +112,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/python:schema_py",
|
"//tensorflow/lite/python:schema_py",
|
||||||
|
"@flatbuffers//:runtime_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -137,6 +138,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/python:schema_py",
|
"//tensorflow/lite/python:schema_py",
|
||||||
|
"@flatbuffers//:runtime_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -30,6 +30,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import flatbuffers
|
||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
|
||||||
_TFLITE_FILE_IDENTIFIER = b'TFL3'
|
_TFLITE_FILE_IDENTIFIER = b'TFL3'
|
||||||
@ -83,7 +84,7 @@ def read_model_with_mutable_tensors(input_tflite_file):
|
|||||||
def convert_object_to_bytearray(model_object):
|
def convert_object_to_bytearray(model_object):
|
||||||
"""Converts a tflite model from an object to a immutable bytearray."""
|
"""Converts a tflite model from an object to a immutable bytearray."""
|
||||||
# Initial size of the buffer, which will grow automatically if needed
|
# Initial size of the buffer, which will grow automatically if needed
|
||||||
builder = schema_fb.flatbuffers.Builder(1024)
|
builder = flatbuffers.Builder(1024)
|
||||||
model_offset = model_object.Pack(builder)
|
model_offset = model_object.Pack(builder)
|
||||||
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
||||||
model_bytearray = bytes(builder.Output())
|
model_bytearray = bytes(builder.Output())
|
||||||
@ -156,7 +157,7 @@ def randomize_weights(model, random_seed=0):
|
|||||||
|
|
||||||
|
|
||||||
def xxd_output_to_bytes(input_cc_file):
|
def xxd_output_to_bytes(input_cc_file):
|
||||||
"""Converts xxd output C++ source file to bytes (immutable).
|
"""Converts xxd output C++ source file to bytes (immutable)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_cc_file: Full path name to th C++ source file dumped by xxd
|
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||||||
@ -195,7 +196,7 @@ def xxd_output_to_bytes(input_cc_file):
|
|||||||
|
|
||||||
|
|
||||||
def xxd_output_to_object(input_cc_file):
|
def xxd_output_to_object(input_cc_file):
|
||||||
"""Converts xxd output C++ source file to object.
|
"""Converts xxd output C++ source file to object
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_cc_file: Full path name to th C++ source file dumped by xxd
|
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import flatbuffers
|
||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
|
||||||
TFLITE_SCHEMA_VERSION = 3
|
TFLITE_SCHEMA_VERSION = 3
|
||||||
@ -28,7 +29,7 @@ TFLITE_SCHEMA_VERSION = 3
|
|||||||
|
|
||||||
def build_mock_flatbuffer_model():
|
def build_mock_flatbuffer_model():
|
||||||
"""Creates a flatbuffer containing an example model."""
|
"""Creates a flatbuffer containing an example model."""
|
||||||
builder = schema_fb.flatbuffers.Builder(1024)
|
builder = flatbuffers.Builder(1024)
|
||||||
|
|
||||||
schema_fb.BufferStart(builder)
|
schema_fb.BufferStart(builder)
|
||||||
buffer0_offset = schema_fb.BufferEnd(builder)
|
buffer0_offset = schema_fb.BufferEnd(builder)
|
||||||
|
|||||||
4
third_party/flatbuffers/build_defs.bzl
vendored
4
third_party/flatbuffers/build_defs.bzl
vendored
@ -370,8 +370,8 @@ def _concat_flatbuffer_py_srcs_impl(ctx):
|
|||||||
outputs = [ctx.outputs.out],
|
outputs = [ctx.outputs.out],
|
||||||
command = (
|
command = (
|
||||||
"find '%s' -name '*.py' -exec cat {} + |" +
|
"find '%s' -name '*.py' -exec cat {} + |" +
|
||||||
"sed 's/from flatbuffers./from flatbuffers.python.flatbuffers./g' |" +
|
"sed 's/from flatbuffers.compat import import_numpy/import numpy as np' |" +
|
||||||
"sed 's/import flatbuffers/from flatbuffers.python import flatbuffers/g' > %s"
|
"sed '/np = import_numpy()/d' > %s"
|
||||||
) % (
|
) % (
|
||||||
ctx.attr.deps[0].files.to_list()[0].path,
|
ctx.attr.deps[0].files.to_list()[0].path,
|
||||||
ctx.outputs.out.path,
|
ctx.outputs.out.path,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user