Merge branch 'master' into interface_16x8
This commit is contained in:
commit
5b9a467e3d
9
.bazelrc
9
.bazelrc
@ -46,7 +46,6 @@
|
||||
# sycl_asan:
|
||||
# sycl_trisycl:
|
||||
# mkl: Enable full mkl support.
|
||||
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
|
||||
# tensorrt: Enable Tensorrt support.
|
||||
# ngraph: Enable ngraph support.
|
||||
# numa: Enable numa using hwloc.
|
||||
@ -140,13 +139,6 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# This config option is used to enable MKL-DNN open source library only,
|
||||
# without depending on MKL binary version.
|
||||
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
|
||||
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -248,6 +240,7 @@ build:windows --copt=/w
|
||||
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
||||
# _USE_MATH_DEFINES is defined.
|
||||
build:windows --copt=/D_USE_MATH_DEFINES
|
||||
build:windows --host_copt=/D_USE_MATH_DEFINES
|
||||
|
||||
# Default paths for TF_SYSTEM_LIBS
|
||||
build:linux --define=PREFIX=/usr
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -38,7 +38,9 @@ gradleBuild
|
||||
*.pbxproj
|
||||
*.xcworkspace
|
||||
/*.podspec
|
||||
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
|
||||
/tensorflow/lite/**/ios/BUILD
|
||||
/tensorflow/lite/**/objc/BUILD
|
||||
/tensorflow/lite/**/swift/BUILD
|
||||
/tensorflow/lite/examples/ios/simple/data/*.tflite
|
||||
/tensorflow/lite/examples/ios/simple/data/*.txt
|
||||
Podfile.lock
|
||||
|
@ -154,7 +154,10 @@ tf_cuda_library(
|
||||
"c_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//third_party/llvm/llvm-project:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
":tf_attrtype",
|
||||
@ -698,4 +701,5 @@ tf_cuda_library(
|
||||
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
|
||||
"//tensorflow/python:cpp_shape_inference_proto_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -774,7 +774,7 @@ extern "C" {
|
||||
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
|
||||
const char* op_type,
|
||||
const char* oper_name)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
return new TF_OperationDescription(graph, op_type, oper_name);
|
||||
}
|
||||
|
||||
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
||||
|
||||
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
|
||||
TF_Status* status)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
|
||||
Node* ret = nullptr;
|
||||
|
||||
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
|
||||
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
|
||||
const TF_ImportGraphDefOptions* opts,
|
||||
TF_ImportGraphDefResults* tf_results,
|
||||
TF_Status* status)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
const int last_node_id = graph->graph.num_node_ids();
|
||||
tensorflow::ImportGraphDefResults results;
|
||||
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
|
||||
|
@ -51,7 +51,7 @@ Status ProcessInputs(
|
||||
const TF_Graph* fn_body, const char* fn_name, int ninputs,
|
||||
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
|
||||
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
input_tensors->reserve(ninputs);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
Node* node = &inputs[i].oper->node;
|
||||
@ -87,7 +87,7 @@ Status ProcessInputs(
|
||||
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
|
||||
int noutputs, const TF_Output* outputs,
|
||||
std::vector<OutputTensor>* output_tensors)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
output_tensors->reserve(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
Node* node = &outputs[i].oper->node;
|
||||
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
|
||||
const TF_Operation* const* opers,
|
||||
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
|
||||
std::vector<const Node*>* body_nodes)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
if (num_opers == -1) {
|
||||
for (const Node* node : fn_body->graph.op_nodes()) {
|
||||
const auto& iter = input_nodes.find(node);
|
||||
|
@ -71,14 +71,14 @@ struct TF_Graph {
|
||||
TF_Graph();
|
||||
|
||||
tensorflow::mutex mu;
|
||||
tensorflow::Graph graph GUARDED_BY(mu);
|
||||
tensorflow::Graph graph TF_GUARDED_BY(mu);
|
||||
|
||||
// Runs shape inference.
|
||||
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
|
||||
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
|
||||
|
||||
// Maps from name of an operation to the Node* in 'graph'.
|
||||
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
|
||||
GUARDED_BY(mu);
|
||||
TF_GUARDED_BY(mu);
|
||||
|
||||
// The keys of this map are all the active sessions using this graph. Each
|
||||
// value records whether the graph has been mutated since the corresponding
|
||||
@ -94,8 +94,8 @@ struct TF_Graph {
|
||||
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
|
||||
// status, this should be reverted when possible.
|
||||
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
|
||||
GUARDED_BY(mu);
|
||||
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||
TF_GUARDED_BY(mu);
|
||||
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||
|
||||
// Used to link graphs contained in TF_WhileParams to the parent graph that
|
||||
// will eventually contain the full while loop.
|
||||
@ -123,7 +123,7 @@ struct TF_Session {
|
||||
tensorflow::Session* session;
|
||||
TF_Graph* const graph;
|
||||
|
||||
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
|
||||
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
|
||||
int last_num_graph_nodes;
|
||||
|
||||
// If true, TF_SessionRun and similar methods will call
|
||||
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
|
||||
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
bool update_docs_called GUARDED_BY(lock);
|
||||
bool update_docs_called TF_GUARDED_BY(lock);
|
||||
tensorflow::mutex lock;
|
||||
};
|
||||
|
||||
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
|
||||
|
||||
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
|
||||
const char* mutation_type)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
|
||||
|
||||
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
|
||||
LOCKS_EXCLUDED(session->graph->mu, session->mu);
|
||||
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
|
||||
|
||||
std::string getTF_OutputDebugString(TF_Output node);
|
||||
|
||||
|
@ -354,6 +354,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"@dlpack",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||
|
@ -1710,8 +1710,9 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
namespace {
|
||||
class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
public:
|
||||
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
|
||||
: device_(device), info_(info), name_(name) {}
|
||||
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
|
||||
string name)
|
||||
: context_(context), device_(device), info_(info), name_(name) {}
|
||||
|
||||
~CustomDeviceAPI() override { device_.delete_device(info_); }
|
||||
|
||||
@ -1725,7 +1726,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
|
||||
TF_Status status;
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
|
||||
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
@ -1744,7 +1745,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
TFE_TensorHandle tensor_handle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
&tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
@ -1768,7 +1769,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
|
||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
@ -1787,6 +1788,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
}
|
||||
|
||||
private:
|
||||
TFE_Context* context_;
|
||||
TFE_CustomDevice device_;
|
||||
void* info_;
|
||||
string name_;
|
||||
@ -1794,8 +1796,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
} // namespace
|
||||
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info) {
|
||||
const char* device_name, void* device_info,
|
||||
TF_Status* status) {
|
||||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
|
||||
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
||||
status->status =
|
||||
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
@ -458,27 +458,29 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 1
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 2
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
int version = TFE_CUSTOM_DEVICE_VERSION;
|
||||
// Method to copy a tensor to the custom device.
|
||||
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status,
|
||||
void* device_info) = nullptr;
|
||||
|
||||
// Method to copy a tensor from the custom device to a target device.
|
||||
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info);
|
||||
void (*execute)(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
void (*delete_device)(void* device_info);
|
||||
@ -503,11 +505,21 @@ typedef struct TFE_CustomDevice {
|
||||
// devices, so executing tf.functions which contain operations placed on custom
|
||||
// devices will fail.
|
||||
//
|
||||
// `device_name` must not name an existing physical or custom device. It must
|
||||
// follow the format:
|
||||
//
|
||||
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
|
||||
//
|
||||
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
|
||||
// the device is not usable. In case of a bad status, `device.delete_device` is
|
||||
// still called on `device_info` (i.e. the caller does not retain ownership).
|
||||
//
|
||||
// This API is highly experimental, and in particular is expected to change when
|
||||
// it starts supporting operations with attributes and when tf.function support
|
||||
// is added.
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info);
|
||||
const char* device_name, void* device_info,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||
const char* function_name,
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
TFE_Context* ctx;
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
@ -48,7 +47,7 @@ void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* ctx, const tensorflow::string& logging_device_name,
|
||||
TFE_Context* context, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
@ -58,23 +57,25 @@ TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, dev->ctx, dev->underlying_device.c_str(), status);
|
||||
tensor, context, dev->underlying_device.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
auto dst = std::make_unique<LoggedTensor>(t);
|
||||
*(dev->arrived_flag) = true;
|
||||
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
|
||||
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
|
||||
status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
@ -83,13 +84,13 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name,
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
@ -117,7 +118,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
}
|
||||
for (int i = 0; i < *num_outputs; ++i) {
|
||||
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
|
||||
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
*(dev->executed_flag) = true;
|
||||
@ -128,19 +129,19 @@ void DeleteLoggingDevice(void* device_info) {
|
||||
}
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag) {
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device.delete_device = &DeleteLoggingDevice;
|
||||
custom_device.execute = &LoggingDeviceExecute;
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->ctx = context;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device);
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
@ -153,7 +154,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed);
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
ASSERT_FALSE(arrived);
|
||||
TFE_TensorHandle* hdevice =
|
||||
@ -189,7 +191,9 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||
bool executed = false;
|
||||
const char* custom_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);
|
||||
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
|
||||
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
|
||||
@ -217,7 +221,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
@ -291,4 +296,103 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||
<< "Execution should fail because the variable is being used on the "
|
||||
"wrong device.";
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
RegisterLoggingDevice(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -27,14 +27,10 @@ namespace {
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
|
||||
bool RequiresRecordingAccessedTensors() const override { return save_; }
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
||||
return cpu_allocator();
|
||||
}
|
||||
|
||||
private:
|
||||
bool save_;
|
||||
};
|
||||
|
||||
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
||||
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
|
||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||
|
||||
OpKernelContext::Params params;
|
||||
DummyDevice dummy_device(nullptr, false);
|
||||
DummyDevice dummy_device(nullptr);
|
||||
params.device = &dummy_device;
|
||||
params.op_kernel = kernel.get();
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
|
@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
|
||||
bool RequiresRecordingAccessedTensors() const override { return save_; }
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
|
||||
return cpu_allocator();
|
||||
}
|
||||
|
||||
private:
|
||||
bool save_;
|
||||
};
|
||||
|
||||
TEST(TestKernel, TestInputAndOutputCount) {
|
||||
@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
|
||||
|
||||
{
|
||||
OpKernelContext::Params p;
|
||||
DummyDevice dummy_device(nullptr, false);
|
||||
DummyDevice dummy_device(nullptr);
|
||||
p.device = &dummy_device;
|
||||
p.step_id = 43;
|
||||
|
||||
|
@ -58,9 +58,9 @@ extern "C" {
|
||||
// start_offset: array[uint64]
|
||||
// data: byte[...]
|
||||
//
|
||||
// The string length (as a varint), followed by the contents of the string
|
||||
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
|
||||
// facilitate this encoding.
|
||||
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
|
||||
// followed by the contents of the string is encoded at data[start_offset[i]].
|
||||
// TF_StringEncode and TF_StringDecode facilitate this encoding.
|
||||
|
||||
typedef struct TF_Tensor TF_Tensor;
|
||||
|
||||
|
@ -41,7 +41,7 @@ class ClientSession::Impl {
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
||||
mutable mutex mu_;
|
||||
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
|
||||
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
|
||||
};
|
||||
|
||||
ClientSession::ClientSession(const Scope& scope, const string& target)
|
||||
|
@ -114,14 +114,14 @@ class Coordinator {
|
||||
condition_variable wait_for_stop_;
|
||||
|
||||
mutex mu_;
|
||||
bool should_stop_ GUARDED_BY(mu_);
|
||||
bool should_stop_ TF_GUARDED_BY(mu_);
|
||||
|
||||
mutex status_lock_;
|
||||
Status status_ GUARDED_BY(status_lock_);
|
||||
Status status_ TF_GUARDED_BY(status_lock_);
|
||||
|
||||
mutable mutex runners_lock_;
|
||||
std::vector<std::unique_ptr<RunnerInterface>> runners_
|
||||
GUARDED_BY(runners_lock_);
|
||||
TF_GUARDED_BY(runners_lock_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
|
||||
};
|
||||
|
@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface {
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
mutex mu_;
|
||||
int runs_ = 0;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
Status enqueue_status_ GUARDED_BY(mu_);
|
||||
Status status_ TF_GUARDED_BY(mu_);
|
||||
Status enqueue_status_ TF_GUARDED_BY(mu_);
|
||||
std::unique_ptr<BlockingCounter> counter_;
|
||||
|
||||
Coordinator* coord_;
|
||||
@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface {
|
||||
std::vector<std::function<void(Status)>> callbacks_;
|
||||
|
||||
mutable std::unique_ptr<mutex> cg_mu_;
|
||||
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
|
||||
std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
|
||||
RunOptions run_options_;
|
||||
};
|
||||
|
||||
|
@ -37,6 +37,7 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
@ -64,6 +65,7 @@ cc_library(
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
}
|
||||
|
||||
// Generates code implementing {Arg,Result}Names(), where T is one of
|
||||
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
|
||||
// literal in the array, with nullptr terminating the array.
|
||||
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
|
||||
// string literal in the array, with nullptr terminating the array.
|
||||
template <typename T>
|
||||
string GenNameToIndexCode(const T& entries, bool generate) {
|
||||
// No need for a static array if we're not supposed to generate the data.
|
||||
@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
// Generate metadata.
|
||||
const string arg_names_code =
|
||||
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
|
||||
|
||||
auto variable_copy = config.variable();
|
||||
for (auto& var : variable_copy) {
|
||||
if (var.name().empty()) {
|
||||
var.set_name(var.node_name());
|
||||
}
|
||||
}
|
||||
const string variable_names_code =
|
||||
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
|
||||
|
||||
const string result_names_code =
|
||||
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
||||
const string include_xla_data_proto =
|
||||
@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||
|
||||
// Number of variables for the compiled computation.
|
||||
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
||||
@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
set_static_data_num_buffers(data, kNumBuffers);
|
||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||
set_static_data_num_args(data, kNumArgs);
|
||||
set_static_data_num_variables(data, kNumVariables);
|
||||
set_static_data_result_index(data, kResultIndex);
|
||||
set_static_data_arg_names(data, StaticArgNames());
|
||||
set_static_data_variable_names(data, StaticVariableNames());
|
||||
set_static_data_result_names(data, StaticResultNames());
|
||||
set_static_data_program_shape(data, StaticProgramShape());
|
||||
set_static_data_hlo_profile_printer_data(
|
||||
@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
||||
|
||||
// Array of names of each positional variable, terminated by nullptr.
|
||||
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
|
||||
|
||||
// Array of names of each positional result, terminated by nullptr.
|
||||
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
||||
|
||||
@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
||||
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
|
||||
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||
{"{{CLASS}}", opts.class_name},
|
||||
@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
||||
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
||||
metadata_result.program_shape_access_shim},
|
||||
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
|
||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
||||
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
||||
|
@ -156,17 +156,14 @@ static void CompareWithGoldenFile(
|
||||
// bazel test --test_strategy=local \
|
||||
// third_party/tensorflow/compiler/aot:codegen_test
|
||||
const bool update_golden = false;
|
||||
string golden_file_name;
|
||||
string golden_file_name =
|
||||
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
||||
|
||||
if (update_golden) {
|
||||
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
tensorflow_relative_golden_file_name);
|
||||
TF_EXPECT_OK(
|
||||
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
|
||||
}
|
||||
|
||||
golden_file_name =
|
||||
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
||||
string golden_file_contents;
|
||||
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
||||
&golden_file_contents));
|
||||
@ -220,10 +217,16 @@ TEST(CodegenTest, Golden) {
|
||||
{},
|
||||
{BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
||||
BufferInfo::MakeTempBuffer(2),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
||||
5, {}));
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
|
||||
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
|
||||
11, {}));
|
||||
compile_result.program_shape =
|
||||
xla::ShapeUtil::MakeProgramShape(
|
||||
{
|
||||
|
@ -55,14 +55,17 @@ namespace bar {
|
||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
||||
//
|
||||
// Memory stats:
|
||||
// arg bytes total: 104
|
||||
// arg bytes aligned: 192
|
||||
// arg bytes total: 392
|
||||
// arg bytes aligned: 576
|
||||
// temp bytes total: 126
|
||||
// temp bytes aligned: 320
|
||||
// temp bytes aligned: 512
|
||||
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
public:
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = 2;
|
||||
static constexpr size_t kNumArgs = 5;
|
||||
|
||||
// Number of variables for the compiled computation.
|
||||
static constexpr size_t kNumVariables = 3;
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
set_static_data_num_buffers(data, kNumBuffers);
|
||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||
set_static_data_num_args(data, kNumArgs);
|
||||
set_static_data_num_variables(data, kNumVariables);
|
||||
set_static_data_result_index(data, kResultIndex);
|
||||
set_static_data_arg_names(data, StaticArgNames());
|
||||
set_static_data_variable_names(data, StaticVariableNames());
|
||||
set_static_data_result_names(data, StaticResultNames());
|
||||
set_static_data_program_shape(data, StaticProgramShape());
|
||||
set_static_data_hlo_profile_printer_data(
|
||||
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
|
||||
private:
|
||||
// Number of buffers for the compiled computation.
|
||||
static constexpr size_t kNumBuffers = 6;
|
||||
static constexpr size_t kNumBuffers = 12;
|
||||
|
||||
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||
static const ::xla::cpu_function_runtime::BufferInfo
|
||||
kBufferInfos[kNumBuffers] = {
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
||||
};
|
||||
return kBufferInfos;
|
||||
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
|
||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||
1, 3
|
||||
1, 3, 5, 7, 9
|
||||
};
|
||||
return kArgIndexToBufferIndex;
|
||||
}
|
||||
|
||||
// The 0-based index of the result tuple in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = 5;
|
||||
static constexpr size_t kResultIndex = 11;
|
||||
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {
|
||||
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return kNames;
|
||||
}
|
||||
|
||||
// Array of names of each positional variable, terminated by nullptr.
|
||||
static const char** StaticVariableNames() {
|
||||
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
|
||||
return kNames;
|
||||
}
|
||||
|
||||
// Array of names of each positional result, terminated by nullptr.
|
||||
static const char** StaticResultNames() {
|
||||
static const char* kNames[] = {"myfetch", nullptr};
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "llvm-c/Target.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
@ -39,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -105,14 +107,18 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
.ValueOrDie();
|
||||
xla::XlaComputation computation;
|
||||
if (flags.mlir_components == "Bridge") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
|
||||
graph_def, config, &computation, flags.debug_info,
|
||||
flags.debug_info_path_begin_marker));
|
||||
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||
client, &computation));
|
||||
} else {
|
||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||
}
|
||||
if (flags.quantize) {
|
||||
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
|
||||
}
|
||||
if (!flags.out_session_module.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||
computation.Snapshot());
|
||||
@ -166,6 +172,23 @@ static void InitializeTargets() {
|
||||
LLVMInitializeX86AsmPrinter();
|
||||
}
|
||||
|
||||
// Replaces {{tag.type tag.name}} in the error message with tag_name.
|
||||
// TODO(bixia): We currently only handlge tag.type == "node".
|
||||
//
|
||||
// In the error message, a graph node is represented as {{tag.type, tag.name}},
|
||||
// to allow a Python debugger to insert source information about the graph node.
|
||||
// For example, a Python add expression may be represented as
|
||||
// {{node, x_y_sum}} = Add(x, y) in the error message. See routine interpolate
|
||||
// in tensorflow/python/framework/error_interpolation.py for more detail.
|
||||
static std::string InterpolateErrorMessage(std::string message) {
|
||||
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
|
||||
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
|
||||
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
|
||||
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
Status Main(const MainFlags& flags) {
|
||||
absl::call_once(targets_init, &InitializeTargets);
|
||||
|
||||
@ -192,8 +215,13 @@ Status Main(const MainFlags& flags) {
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||
CompileResult compile_result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
||||
|
||||
Status status =
|
||||
CompileGraph(std::move(graph_def), config, flags, &compile_result);
|
||||
if (!status.ok()) {
|
||||
return Status(status.code(),
|
||||
InterpolateErrorMessage(status.error_message()));
|
||||
}
|
||||
|
||||
// Write output files.
|
||||
Env* env = Env::Default();
|
||||
|
@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"debug_info", &flags->debug_info,
|
||||
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
|
||||
"If not none, only keep the file path in the debug information after the"
|
||||
" marker. The default value is empty"},
|
||||
{"config", &flags->config,
|
||||
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
||||
"is expected to be in the human-readable proto text format, otherwise "
|
||||
@ -70,6 +77,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||
"Output session module proto."},
|
||||
{"mlir_components", &flags->mlir_components,
|
||||
"The MLIR components to enable. Currently only Bridge is supported."},
|
||||
{"quantize", &flags->quantize,
|
||||
"If set, quantization will be applied before HLO code generation."},
|
||||
{"gen_name_to_index", &flags->gen_name_to_index,
|
||||
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
|
||||
{"gen_program_shape", &flags->gen_program_shape,
|
||||
|
@ -28,6 +28,8 @@ namespace tfcompile {
|
||||
|
||||
struct MainFlags {
|
||||
string graph;
|
||||
string debug_info;
|
||||
string debug_info_path_begin_marker;
|
||||
string config;
|
||||
bool dump_fetch_nodes = false;
|
||||
string target_triple;
|
||||
@ -40,6 +42,7 @@ struct MainFlags {
|
||||
string out_header;
|
||||
string out_session_module;
|
||||
string mlir_components;
|
||||
bool quantize = false;
|
||||
|
||||
// C++ codegen options
|
||||
bool gen_name_to_index = false;
|
||||
|
@ -1,11 +1,37 @@
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":filecheck_test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||
},
|
||||
test_file_exts = ["lit.pbtxt"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "filecheck_test_utilities",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"test_error_message.lit.pbtxt.config.pbtxt",
|
||||
"test_error_message.lit.pbtxt.debug.pbtxt",
|
||||
"test_error_message.lit.pbtxt.fake_py.debug",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/compiler/aot:tfcompile",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# We disable some tfcompile tests in the open source build with the
|
||||
# "manual" tag to avoid making our OSS users build LLVM twice
|
||||
# (once for host and once for target).
|
||||
@ -60,6 +86,7 @@ genrule(
|
||||
testonly = 1,
|
||||
outs = [
|
||||
"test_graph_tfadd.pb",
|
||||
"test_debuginfo_tfadd.pb",
|
||||
"test_graph_tfadd_with_ckpt.ckpt",
|
||||
"test_graph_tfadd_with_ckpt.pb",
|
||||
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||
@ -317,6 +344,7 @@ tf_library(
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd.config.pbtxt",
|
||||
cpp_class = "AddComp",
|
||||
debug_info = "test_debuginfo_tfadd.pb",
|
||||
graph = "test_graph_tfadd.pb",
|
||||
include_standard_runtime_deps = False,
|
||||
mlir_components = "Bridge",
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
|
||||
array_ops.identity(updates, name='result')
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir):
|
||||
def export_debug_info(exported_graph):
|
||||
"""Exports debug information from a graph.
|
||||
|
||||
Args:
|
||||
exported_graph: A Graph that has been created by tracing a saveable view.
|
||||
|
||||
Returns:
|
||||
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
|
||||
"""
|
||||
exported_operations = []
|
||||
for op in exported_graph.get_operations():
|
||||
exported_operations.append(('', op))
|
||||
return error_interpolation.create_graph_debug_info_def(exported_operations)
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir, debug_info=False):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
||||
|
||||
if debug_info:
|
||||
filename_debuginfo = os.path.join(
|
||||
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
|
||||
test_debuginfo = export_debug_info(g)
|
||||
with open(filename_debuginfo, 'wb') as f:
|
||||
f.write(
|
||||
six.ensure_binary(
|
||||
test_debuginfo.SerializeToString(deterministic=True)))
|
||||
|
||||
|
||||
def main(_):
|
||||
control_flow_util.enable_control_flow_v2()
|
||||
write_graph(tfadd, FLAGS.out_dir)
|
||||
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
|
||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||
|
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
@ -0,0 +1,69 @@
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
|
||||
|
||||
# Checks the error message produced by tfcompile with mlir_component
|
||||
# Checks that source debug information is used in the output error message and
|
||||
# the node x_y_sum = Add
|
||||
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
|
||||
# CHECK: math_ops.add(x, y, name='x_y_sum')
|
||||
# CHECK: build_graph(out_dir)
|
||||
|
||||
# Checks the error message produced by tfcompile without mlir_component
|
||||
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
|
||||
# OLD: x_y_sum
|
||||
|
||||
node: {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "shape"
|
||||
value: {
|
||||
shape: {
|
||||
dim: {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "shape"
|
||||
value: {
|
||||
shape: {
|
||||
dim: {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "x_y_sum"
|
||||
op: "Add"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr: {
|
||||
key: "T"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions: {
|
||||
producer: 321
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
# Text form of tensorflow.tf2xla.Config proto.
|
||||
feed {
|
||||
id { node_name: "x" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y" }
|
||||
shape {
|
||||
dim { size: 3 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
|
||||
traces: {
|
||||
key: "x@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
traces: {
|
||||
key: "x_y_sum@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 3
|
||||
}
|
||||
file_line_cols: {
|
||||
line: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
traces: {
|
||||
key: "y@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 2
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
x = value
|
||||
y = value
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
build_graph(out_dir)
|
@ -26,6 +26,7 @@ def tf_library(
|
||||
name,
|
||||
graph,
|
||||
config,
|
||||
debug_info = None,
|
||||
freeze_checkpoint = None,
|
||||
freeze_saver = None,
|
||||
cpp_class = None,
|
||||
@ -191,12 +192,15 @@ def tf_library(
|
||||
|
||||
mlir_flag = "--mlir_components=" + mlir_components
|
||||
|
||||
srcs = [tfcompile_graph, config]
|
||||
debug_info_flag = ""
|
||||
if debug_info:
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
srcs = srcs,
|
||||
outs = [
|
||||
header_file,
|
||||
metadata_object_file,
|
||||
@ -206,6 +210,7 @@ def tf_library(
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
debug_info_flag +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
@ -237,10 +242,7 @@ def tf_library(
|
||||
session_module_pb = name + "_session_module.pb"
|
||||
native.genrule(
|
||||
name = (name + "_session_module"),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
srcs = srcs,
|
||||
outs = [
|
||||
session_module_pb,
|
||||
],
|
||||
@ -248,6 +250,7 @@ def tf_library(
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
debug_info_flag +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
|
@ -65,6 +65,7 @@ int main(int argc, char** argv) {
|
||||
flags.out_metadata_object = "out_helper.o";
|
||||
flags.out_header = "out.h";
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
@ -81,12 +82,10 @@ int main(int argc, char** argv) {
|
||||
|
||||
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
|
||||
"other than flags\n\n"
|
||||
<< usage;
|
||||
"other than flags. See --help.\n\n";
|
||||
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
|
||||
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
||||
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
|
||||
<< usage;
|
||||
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n";
|
||||
return 1;
|
||||
} else {
|
||||
TF_QCHECK_OK(status);
|
||||
|
@ -184,6 +184,7 @@ XLA_DEVICE_DEPS = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
@ -18,6 +18,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -368,14 +368,20 @@ bool GraphCycles::CanContractEdge(int32 a, int32 b) {
|
||||
return !reachable;
|
||||
}
|
||||
|
||||
bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
absl::optional<int32> GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
CHECK(HasEdge(a, b));
|
||||
RemoveEdge(a, b);
|
||||
|
||||
if (IsReachableNonConst(a, b)) {
|
||||
// Restore the graph to its original state.
|
||||
InsertEdge(a, b);
|
||||
return false;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() >
|
||||
rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) {
|
||||
// Swap "a" and "b" to minimize copying.
|
||||
std::swap(a, b);
|
||||
}
|
||||
|
||||
Node* nb = rep_->nodes_[b];
|
||||
@ -399,7 +405,8 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
InsertEdge(y, a);
|
||||
}
|
||||
|
||||
return true;
|
||||
// Note, if the swap happened it might be what originally was called "b".
|
||||
return a;
|
||||
}
|
||||
|
||||
absl::Span<const int32> GraphCycles::Successors(int32 node) const {
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
// FindPath() is linear in the size of the graph.
|
||||
// The current implementation uses O(|V|+|E|) space.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -80,11 +81,11 @@ class GraphCycles {
|
||||
// Return whether there is an edge directly from source_node to dest_node.
|
||||
bool HasEdge(int32 source_node, int32 dest_node) const;
|
||||
|
||||
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
|
||||
// removed from the graph, and edges to/from 'b' are replaced with edges
|
||||
// to/from 'a'. If contracting the edge would create a cycle, does nothing
|
||||
// and returns false.
|
||||
bool ContractEdge(int32 a, int32 b);
|
||||
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
|
||||
// the nodes is removed from the graph, and edges to/from it are added to
|
||||
// the remaining one, which is returned. If contracting the edge would create
|
||||
// a cycle, does nothing and return no value.
|
||||
absl::optional<int32> ContractEdge(int32 a, int32 b);
|
||||
|
||||
// Return true if can contract edge, otherwise return false.
|
||||
bool CanContractEdge(int32 a, int32 b);
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
@ -479,19 +480,21 @@ TEST_F(GraphCyclesTest, ContractEdge) {
|
||||
ASSERT_TRUE(AddEdge(2, 4));
|
||||
ASSERT_TRUE(AddEdge(3, 4));
|
||||
|
||||
EXPECT_FALSE(g_.ContractEdge(1, 3));
|
||||
EXPECT_FALSE(g_.ContractEdge(1, 3).has_value());
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
|
||||
EXPECT_TRUE(g_.ContractEdge(1, 2));
|
||||
// Node (2) has more edges.
|
||||
EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2);
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
EXPECT_TRUE(g_.HasEdge(1, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 3));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(3, 4));
|
||||
|
||||
EXPECT_TRUE(g_.ContractEdge(1, 3));
|
||||
// Node (2) has more edges.
|
||||
EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2);
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, CanContractEdge) {
|
||||
@ -527,3 +530,26 @@ static void BM_StressTest(int iters, int num_nodes) {
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_StressTest)->Range(2048, 1048576);
|
||||
|
||||
static void BM_ContractEdge(int iters, int num_nodes) {
|
||||
while (iters-- > 0) {
|
||||
tensorflow::testing::StopTiming();
|
||||
tensorflow::GraphCycles g;
|
||||
std::vector<int32> nodes;
|
||||
nodes.reserve(num_nodes);
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
nodes.push_back(g.NewNode());
|
||||
}
|
||||
// All edges point toward the last one.
|
||||
for (int i = 0; i < num_nodes - 1; ++i) {
|
||||
g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
|
||||
}
|
||||
|
||||
tensorflow::testing::StartTiming();
|
||||
int node = num_nodes - 1;
|
||||
for (int i = 0; i < num_nodes - 1; ++i) {
|
||||
node = g.ContractEdge(nodes[i], node).value();
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000);
|
||||
|
@ -172,8 +172,9 @@ class XlaExecutableClosureStore {
|
||||
|
||||
private:
|
||||
mutex mutex_;
|
||||
int64 key_counter_ GUARDED_BY(mutex_);
|
||||
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
|
||||
int64 key_counter_ TF_GUARDED_BY(mutex_);
|
||||
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
|
||||
TF_GUARDED_BY(mutex_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
||||
};
|
||||
|
@ -165,7 +165,8 @@ class XlaCompileOp : public OpKernel {
|
||||
// error when compiling the cluster this _XlaCompile is supposed to compile.
|
||||
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
|
||||
// on any future calls to _XlaCompile.
|
||||
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
|
||||
bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) =
|
||||
false;
|
||||
|
||||
mutex cannot_compile_cluster_mu_;
|
||||
};
|
||||
|
@ -161,6 +161,11 @@ class MarkForCompilationPassImpl {
|
||||
// The ID of the cluster as represented in `cycles_graph_`.
|
||||
int cycles_graph_node_id() const { return cycles_graph_node_id_; }
|
||||
|
||||
// Sets the ID of the cluster as represented in `cycles_graph_`.
|
||||
void set_cycles_graph_node_id(int cycles_graph_node_id) {
|
||||
cycles_graph_node_id_ = cycles_graph_node_id;
|
||||
}
|
||||
|
||||
// The size of the cluster excluding constant and identity nodes.
|
||||
int effective_cluster_size() const { return effective_cluster_size_; }
|
||||
|
||||
@ -381,14 +386,16 @@ class MarkForCompilationPassImpl {
|
||||
// R, B} cluster.
|
||||
string DescribePotentialCycle(int from, int to);
|
||||
|
||||
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
|
||||
// larger combined cluster is represented by `cluster_from`'s ID in
|
||||
// `cycles_graph_`.
|
||||
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
|
||||
// larger combined cluster is represented by `cluster_from`, but can have
|
||||
// `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
|
||||
// which way will require less operations.
|
||||
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
|
||||
int from = cluster_from->cycles_graph_node_id();
|
||||
int to = cluster_to->cycles_graph_node_id();
|
||||
|
||||
if (!cycles_graph_.ContractEdge(from, to)) {
|
||||
auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
|
||||
if (!optional_merged_node.has_value()) {
|
||||
VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
|
||||
<< " -> " << cluster_to->DebugString(*graph_)
|
||||
<< " because contracting the edge would create a cycle via "
|
||||
@ -398,6 +405,8 @@ class MarkForCompilationPassImpl {
|
||||
|
||||
// Merge the clusters.
|
||||
cluster_from->Merge(cluster_to);
|
||||
// Update `cycle_graph_`'s ID.
|
||||
cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
|
||||
|
||||
// Merge the UnionFind<Cluster*>.
|
||||
cluster_for_node_[from].Merge(&cluster_for_node_[to]);
|
||||
@ -1911,6 +1920,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"LinSpace",
|
||||
"ListDiff",
|
||||
"LogMatrixDeterminant",
|
||||
"LowerBound",
|
||||
"MatMul",
|
||||
"MatrixBandPart",
|
||||
"MatrixDiag",
|
||||
@ -2037,6 +2047,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorScatterUpdate",
|
||||
"TridiagonalSolve",
|
||||
"TruncatedNormal",
|
||||
"UpperBound",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
|
@ -18,13 +18,15 @@ limitations under the License.
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// The list of all registered `XlaActivityListener`s.
|
||||
struct XlaActivityListenerList {
|
||||
absl::Mutex mutex;
|
||||
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
|
||||
std::vector<std::unique_ptr<XlaActivityListener>> listeners
|
||||
TF_GUARDED_BY(mutex);
|
||||
};
|
||||
|
||||
void FlushAllListeners();
|
||||
|
@ -50,7 +50,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
|
||||
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
|
||||
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
|
||||
}
|
||||
|
||||
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
|
||||
@ -69,7 +69,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
|
||||
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
|
||||
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
|
||||
}
|
||||
|
||||
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -33,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/metrics.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
|
||||
execution_count < kMinExecutionsPerCompile * compile_count;
|
||||
}
|
||||
|
||||
// Creates a simple graph using the specified op as the only op apart from the
|
||||
// arg and retval nodes.
|
||||
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
|
||||
absl::Span<const DataType> result_types) {
|
||||
// TODO(b/74182462): We implement this by creating a new dummy Graph including
|
||||
// _Arg nodes, and let CompileGraph walk it. This could be optimized.
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
Status status;
|
||||
// First create the actual node we care about computing.
|
||||
Node* main_node = graph->AddNode(node_def, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
// Create dummy _Arg nodes. Link these to `node` and also via a control
|
||||
// dependency edge to the _SOURCE node.
|
||||
for (int64 i = 0; i < args.size(); ++i) {
|
||||
Node* node;
|
||||
string arg_name = absl::StrCat("_arg", i);
|
||||
Status status =
|
||||
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
|
||||
.ControlInput(graph->source_node())
|
||||
.Attr("T", args[i].kind == XlaCompiler::Argument::kResource
|
||||
? DT_RESOURCE
|
||||
: args[i].type)
|
||||
.Attr("index", i)
|
||||
.Finalize(graph.get(), &node);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
graph->AddEdge(node, 0, main_node, i);
|
||||
}
|
||||
|
||||
// Similarly with return values, create dummy _Retval nodes fed by `node`.
|
||||
for (int64 i = 0; i < result_types.size(); ++i) {
|
||||
Node* node;
|
||||
string retval_name = absl::StrCat("_retval", i);
|
||||
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
|
||||
.Input(main_node, i)
|
||||
.Attr("T", result_types[i])
|
||||
.Attr("index", i)
|
||||
.Finalize(graph.get(), &node);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
}
|
||||
FixupSourceAndSinkEdges(graph.get());
|
||||
return graph;
|
||||
}
|
||||
|
||||
Status XlaCompilationCache::CompileSingleOp(
|
||||
const XlaCompiler::Options& options,
|
||||
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
|
||||
@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
for (int i = 0; i < result_dtypes.size(); ++i) {
|
||||
result_dtypes[i] = ctx->expected_output_dtype(i);
|
||||
}
|
||||
return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
|
||||
args, result_dtypes, result);
|
||||
|
||||
const NodeDef& node_def = ctx->op_kernel().def();
|
||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||
std::move(graph), args, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
|
@ -151,19 +151,19 @@ class XlaCompilationCache : public ResourceBase {
|
||||
int64 request_count = 0;
|
||||
|
||||
// Did compilation succeed?
|
||||
Status compilation_status GUARDED_BY(mu);
|
||||
Status compilation_status TF_GUARDED_BY(mu);
|
||||
|
||||
// Output of the XlaCompiler.
|
||||
XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
|
||||
XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);
|
||||
|
||||
// The XLA executable compiled from <computation>. May be null if no
|
||||
// executable has been built.
|
||||
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
|
||||
std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
|
||||
};
|
||||
|
||||
mutex compile_cache_mu_;
|
||||
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
|
||||
GUARDED_BY(compile_cache_mu_);
|
||||
TF_GUARDED_BY(compile_cache_mu_);
|
||||
|
||||
struct ClusterCompileStats {
|
||||
// Number of times the cluster has been (re-)compiled.
|
||||
@ -185,7 +185,7 @@ class XlaCompilationCache : public ResourceBase {
|
||||
|
||||
// Maps cluster names to compilation statistics for said cluster.
|
||||
absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
|
||||
GUARDED_BY(cluster_compile_stats_mu_);
|
||||
TF_GUARDED_BY(cluster_compile_stats_mu_);
|
||||
|
||||
// The number of times a lazy compilation must be requested for a specific
|
||||
// signature before we attempt to compile it.
|
||||
|
@ -83,7 +83,7 @@ class XlaDeviceAllocatorState {
|
||||
std::unordered_map<std::pair<const xla::Backend*, int>,
|
||||
std::unique_ptr<XlaDeviceAllocator>,
|
||||
hash<std::pair<const xla::Backend*, int>>>
|
||||
allocators_ GUARDED_BY(allocator_mutex_);
|
||||
allocators_ TF_GUARDED_BY(allocator_mutex_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
|
||||
};
|
||||
|
@ -137,7 +137,7 @@ class XlaDevice : public LocalDevice {
|
||||
~XlaDevice() override;
|
||||
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
|
||||
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) override;
|
||||
@ -145,18 +145,18 @@ class XlaDevice : public LocalDevice {
|
||||
void Sync(const DoneCallback& done) override;
|
||||
|
||||
Status TryGetDeviceContext(DeviceContext** out_context) override
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) override LOCKS_EXCLUDED(mu_);
|
||||
Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Allocate tensor on fast memory space. This is only applied to the new TPU
|
||||
// hardware which has faster read/write memory. If the hardware doesn't
|
||||
// have such memory space, we fallback to the ordinary memory space.
|
||||
Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) LOCKS_EXCLUDED(mu_);
|
||||
Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
const Metadata& metadata() { return xla_metadata_; }
|
||||
|
||||
@ -166,34 +166,35 @@ class XlaDevice : public LocalDevice {
|
||||
//
|
||||
// TODO(b/111859745): The Eager context needs to call this method to recover
|
||||
// from failures.
|
||||
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
|
||||
Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
|
||||
// information for GPU and TPU devices.
|
||||
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
|
||||
Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Instructs this XlaDevice to return 'sync_on_completion' for
|
||||
// AllowsSyncOnCompletion().
|
||||
void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
|
||||
bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
|
||||
void SetAllowsSyncOnCompletion(bool sync_on_completion)
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Installs an error handling callback when RefreshStatus sees !status.ok().
|
||||
void SetHandleDeviceErrorCallback(std::function<Status()> callback);
|
||||
|
||||
Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
|
||||
Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
|
||||
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
|
||||
std::shared_ptr<se::Stream>* stream,
|
||||
bool* stream_was_changed)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Return a pair of device context, the second one is fast_mem device context.
|
||||
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
|
||||
GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
@ -218,13 +219,13 @@ class XlaDevice : public LocalDevice {
|
||||
// Intra-op threads to spawn (from SessionOptions).
|
||||
const int intra_op_parallelism_threads_;
|
||||
// Memory allocator associated with this device.
|
||||
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
|
||||
// Stream associated with this device. Operations enqueued on this
|
||||
// stream are executed on the device. Operations include data
|
||||
// copying back and forth between CPU and the device, and
|
||||
// computations enqueued by XLA.
|
||||
std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
|
||||
std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
|
||||
// If false, only stream_ is valid and all computation and transfers use
|
||||
// stream_. If true, computation is performed by stream_ and transfers are
|
||||
// performed by host_to_device/device_to_device stream or borrowing a stream
|
||||
@ -232,36 +233,36 @@ class XlaDevice : public LocalDevice {
|
||||
const bool use_multiple_streams_;
|
||||
// If use_multiple_streams_, host to device transfers are performed using this
|
||||
// stream.
|
||||
std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
|
||||
std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
|
||||
// If use_multiple_streams_, transfers between different devices are performed
|
||||
// using these streams.
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
|
||||
GUARDED_BY(mu_);
|
||||
TF_GUARDED_BY(mu_);
|
||||
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// The device context accessed by all users of the XlaDevice, set by calls to
|
||||
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
|
||||
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
|
||||
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
|
||||
XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
// The device context will allocate memory on fast memory space on TPU.
|
||||
// XlaDeviceContext is a ref-counted object.
|
||||
XlaDeviceContext* fast_mem_device_context_ GUARDED_BY(mu_) = nullptr;
|
||||
XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
// Holds extra information for GPU and TPU devices, e.g. the device context.
|
||||
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
|
||||
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
|
||||
bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
|
||||
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// Thread pool used for running closures
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
|
||||
// True if the device allows XlaDevice::Sync to be called on completion
|
||||
// regardless of status.
|
||||
bool sync_on_completion_ GUARDED_BY(mu_) = true;
|
||||
bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;
|
||||
|
||||
// A callback that will be invoked when RefreshStatus sees a status error.
|
||||
std::function<Status()> device_error_callback_ GUARDED_BY(mu_);
|
||||
std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// Set of devices to use. This controls which of the devices on the given
|
||||
// platform will have resources allocated. For GPUs this will be
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
|
@ -117,7 +117,7 @@ class XlaDeviceContext : public DeviceContext {
|
||||
bool use_fast_mem_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
int next_stream_ GUARDED_BY(mu_) = 0;
|
||||
int next_stream_ TF_GUARDED_BY(mu_) = 0;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
@ -30,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -102,7 +102,7 @@ class VariableInfo {
|
||||
// `variables` is allowed to contain instances that don't track a resource
|
||||
// variable (i.e. variables[i].var() can be null for some i).
|
||||
Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
EXCLUSIVE_LOCK_FUNCTION();
|
||||
TF_EXCLUSIVE_LOCK_FUNCTION();
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
|
@ -122,7 +122,7 @@ class XlaTensor {
|
||||
std::shared_ptr<se::Event> definition_event_;
|
||||
// A list of all streams for which the tensor's content is defined for any
|
||||
// newly enqueued command.
|
||||
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
|
||||
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ TF_GUARDED_BY(mu_);
|
||||
mutex mu_;
|
||||
};
|
||||
|
||||
|
@ -74,12 +74,15 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
|
||||
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
@ -100,6 +103,38 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_graph_optimization_pass",
|
||||
srcs = ["mlir_graph_optimization_pass.cc"],
|
||||
hdrs = ["mlir_graph_optimization_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_graph_optimization_pass_registration",
|
||||
srcs = [
|
||||
"mlir_graph_optimization_pass_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":mlir_graph_optimization_pass",
|
||||
"//tensorflow/core:core_cpu",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-opt",
|
||||
deps = [
|
||||
|
@ -30,7 +30,8 @@ filegroup(
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -221,19 +222,21 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LoopLikeInterface",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -325,8 +328,8 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
@ -436,7 +439,7 @@ genrule(
|
||||
srcs = [
|
||||
"ir/tfl_ops.td",
|
||||
"ir/tfl_op_interfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
],
|
||||
outs = [
|
||||
@ -516,6 +519,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
@ -580,7 +584,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/tools/versioning:op_version",
|
||||
"//tensorflow/lite/tools/versioning",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -697,11 +701,12 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
@ -725,6 +730,7 @@ cc_library(
|
||||
":tensorflow_lite_quantize",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
@ -734,11 +740,10 @@ cc_library(
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
|
@ -1,9 +1,9 @@
|
||||
# Experimental code for the new TF-Lite convertor, and MLIR dialects and utilities for TensorFlow Lite.
|
||||
# The new [MLIR](https://github.com/llvm/llvm-project/tree/master/mlir) based
|
||||
TensorFlow to TensorFlow Lite converter
|
||||
|
||||
This directory contains:
|
||||
|
||||
1. Experimental code for the new TF-Lite convertor.
|
||||
2. Code for the TF-lite dialect [MLIR](https://github.com/tensorflow/mlir).
|
||||
1. MLIR dialects, transformation passes and utilities for TensorFlow Lite.
|
||||
|
||||
## API:
|
||||
|
||||
@ -11,7 +11,8 @@ The API for converting TensorFlow models to TensorFlow Lite will be through
|
||||
`tf.lite.TFLiteConverter`. All the conversion code is open sourced, and
|
||||
the API will be integrated soon.
|
||||
|
||||
### The conversion process from TensorFlow to TensorFlow Lite includes the following major passes:
|
||||
### The conversion process from TensorFlow to TensorFlow Lite includes the
|
||||
following major passes:
|
||||
|
||||
- Import from GraphDef, in .pb or .pbtxt format, into MLIR.
|
||||
- Raise to Control-flow-graph. Converts TF Control Flow dialect to TF dialect.
|
||||
@ -28,3 +29,6 @@ TensorFlow Lite models).
|
||||
- The Export pass writes out TensorFlow Lite FlatBuffer format. This pass
|
||||
operates on MLIR TensorFlow Lite dialect and is simple/direct translation.
|
||||
|
||||
See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
|
||||
for the full list of MLIR passes for conversion from TensorFlow to
|
||||
TensorFlow Lite.
|
||||
|
@ -34,9 +34,9 @@ struct PassConfig {
|
||||
quant_specs(std::move(specs)),
|
||||
skip_control_dialect(false),
|
||||
form_clusters(false),
|
||||
inline_functions(true),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true) {}
|
||||
legalize_tf_while(true),
|
||||
shape_inference(false) {}
|
||||
|
||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||
// added, which produces TF Lite ops.
|
||||
@ -56,9 +56,6 @@ struct PassConfig {
|
||||
// are formed by grouping consecutive ops of the same device, under a
|
||||
// `tf_device.launch` op.
|
||||
bool form_clusters;
|
||||
// Inline function calls within the main function in the MLIR module, prior
|
||||
// to legalization to TFLite.
|
||||
bool inline_functions;
|
||||
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||
// of tfl.fully_connected ops.
|
||||
bool unfold_batch_matmul;
|
||||
@ -66,6 +63,8 @@ struct PassConfig {
|
||||
// Note: This is staging step and will be removed.
|
||||
// TODO(b/137395003): Remove post switching legalization.
|
||||
bool legalize_tf_while;
|
||||
// Whether to do shape inference.
|
||||
bool shape_inference;
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -119,6 +119,12 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
|
||||
// conversion generation and so the simplicity was chosen over the
|
||||
// flexibility.
|
||||
StringRef arg_name = arg_values->getArgNameStr(i);
|
||||
// Skip any "intermiadiateXXX" attribute as they are specially handled
|
||||
// in the exporter. They are special because though they are attributes
|
||||
// in the MLIR they are expressed as tensors in the flatbuffer instead
|
||||
// of option.
|
||||
if (op_name == "LSTMOp" && arg_name.take_back(12) == "intermediate")
|
||||
continue;
|
||||
os << formatv(
|
||||
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
|
||||
arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName());
|
||||
@ -164,17 +170,24 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
|
||||
for (const auto *def : defs) {
|
||||
StringRef op_name = def->getName().drop_front(4);
|
||||
|
||||
const bool has_intermediates = op_name == "LSTMOp";
|
||||
// Signature
|
||||
os << "static flatbuffers::Offset<tflite::Operator> "
|
||||
<< GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name
|
||||
<< " tflOp, uint32_t opcode_index, "
|
||||
<< "const std::vector<int32_t>& operands,"
|
||||
<< "const std::vector<int32_t>& results,"
|
||||
<< (has_intermediates ? "const std::vector<int32_t>& intermediate_index,"
|
||||
: "")
|
||||
<< "flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
// Inputs & outputs
|
||||
os << " auto inputs = fbb->CreateVector(operands);\n"
|
||||
" auto outputs = fbb->CreateVector(results);\n\n";
|
||||
// Intermediates for LSTM.
|
||||
if (has_intermediates) {
|
||||
os << " auto intermediates = fbb->CreateVector(intermediate_index);\n";
|
||||
}
|
||||
|
||||
// Build the FlatBuffer operator
|
||||
os << " return tflite::CreateOperator(\n"
|
||||
@ -191,9 +204,9 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
|
||||
// Only builtin ops' builders are auto-generated. custom_options are only
|
||||
// used by custom or flex ops and those ops are handled manually.
|
||||
os << " /*custom_options=*/0, "
|
||||
"tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
|
||||
" /*mutating_variable_inputs=*/0);\n"
|
||||
"}\n\n";
|
||||
<< "tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
|
||||
<< " /*mutating_variable_inputs=*/0"
|
||||
<< (has_intermediates ? ", intermediates" : "") << ");\n}\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,6 +257,7 @@ static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
|
||||
// uint32_t opcode_index,
|
||||
// const std::vector<int32_t>& operands,
|
||||
// const std::vector<int32_t>& results,
|
||||
// const std::vector<int32_t>& intermediates,
|
||||
// flatbuffers::FlatBufferBuilder *fbb);
|
||||
static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
raw_ostream *ostream) {
|
||||
@ -255,6 +269,7 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
"uint32_t opcode_index, "
|
||||
"const std::vector<int32_t>& operands,"
|
||||
"const std::vector<int32_t>& results,"
|
||||
"const std::vector<int32_t>& intermediates,"
|
||||
"flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
for (const auto *def : defs) {
|
||||
@ -264,7 +279,8 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_name
|
||||
<< ">(op))\n"
|
||||
<< " return " << GetOperatorBuilderName(def->getName())
|
||||
<< "(tflOp, opcode_index, operands, results, fbb);\n";
|
||||
<< "(tflOp, opcode_index, operands, results, "
|
||||
<< (op_name == "LSTMOp" ? "intermediates, " : "") << "fbb);\n";
|
||||
}
|
||||
|
||||
os << " return llvm::None;\n"
|
||||
@ -307,6 +323,10 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper,
|
||||
if (!arg_def) continue;
|
||||
if (arg_def->getDef()->isSubClassOf(attr_type)) {
|
||||
StringRef arg_name = arg_values->getArgNameStr(i);
|
||||
// Already handle this case in flatbuffer_import.cc.
|
||||
if (option_name == "LSTMOptions" &&
|
||||
arg_name.take_back(12) == "intermediate")
|
||||
continue;
|
||||
StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName();
|
||||
os << formatv(
|
||||
" attributes.emplace_back(builder.getNamedAttr(\"{0}\","
|
||||
|
@ -547,6 +547,7 @@ bool IsCustomOp(const std::string& op_name) {
|
||||
// TODO(krzysd) Handle function calls
|
||||
StatusOr<Operation*> ConvertOp(
|
||||
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
|
||||
const std::vector<mlir::TensorType>& intermediate_types,
|
||||
Value optional_arg_marker, const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
|
||||
@ -608,6 +609,28 @@ StatusOr<Operation*> ConvertOp(
|
||||
if (op_name == "tfl.lstm") {
|
||||
// TODO(b/147587779): add the right region if region is empty.
|
||||
op_state.addRegion();
|
||||
if (!op.intermediates.empty()) {
|
||||
if (op.intermediates.size() != 5) {
|
||||
auto err = errors::InvalidArgument(
|
||||
"operator has intermediate tensors but the number of them is not "
|
||||
"five.");
|
||||
return emitError(loc, err.ToString()), err;
|
||||
}
|
||||
// Create intermediate value
|
||||
|
||||
const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
|
||||
"input_to_input_intermediate", "input_to_forget_intermediate",
|
||||
"input_to_cell_intermediate", "input_to_output_intermediate",
|
||||
"effective_hidden_scale_intermediate"};
|
||||
for (auto type_and_name :
|
||||
llvm::zip(intermediate_types, kIntermediateNames)) {
|
||||
mlir::TypeAttr type_attr =
|
||||
mlir::TypeAttr::get(std::get<0>(type_and_name));
|
||||
auto named_attr =
|
||||
builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
|
||||
op_state.addAttribute(named_attr.first, named_attr.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
||||
@ -893,6 +916,18 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
}
|
||||
|
||||
// Intermediate tensors for tfl.lstm are used to carry quantization range
|
||||
// in their types, so we only need and extract their types.
|
||||
std::vector<mlir::TensorType> intermediate_types;
|
||||
intermediate_types.reserve(5);
|
||||
for (auto intermediate : op->intermediates) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
intermediate_types.emplace_back(type);
|
||||
}
|
||||
|
||||
// The NameLoc corresponding to the name of the first output tensor
|
||||
auto op_loc =
|
||||
op->outputs.empty()
|
||||
@ -902,8 +937,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
// to a valid Value
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto* mlir_op,
|
||||
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
|
||||
func_names, subgraph.tensors, op_loc, op_builder));
|
||||
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
|
||||
op_names, func_names, subgraph.tensors, op_loc, op_builder));
|
||||
|
||||
// Add the results to the value maps. There are two cases: 1. the result
|
||||
// tensor does not have min/max values, the original op result is used
|
||||
|
@ -44,6 +44,7 @@ llvm::Optional<tflite::BuiltinOperator> GetBuiltinOpCode(Operation *mlir_op);
|
||||
llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
|
||||
Operation *mlir_op, uint32_t opcode_index,
|
||||
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
|
||||
const std::vector<int32_t> &intermediates,
|
||||
flatbuffers::FlatBufferBuilder *fbb);
|
||||
|
||||
// Populates the array of mlir::NamedAttributes corresponding to the given
|
||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -75,6 +76,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
#include "tensorflow/lite/tools/versioning/runtime_version.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
using llvm::dyn_cast;
|
||||
@ -179,8 +181,6 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case mlir::TF::TensorFlowTypes::STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case mlir::TF::TensorFlowTypes::UINT8:
|
||||
return tflite::TensorType_UINT8;
|
||||
case mlir::TF::TensorFlowTypes::QUINT8:
|
||||
return tflite::TensorType_UINT8;
|
||||
case mlir::StandardTypes::Complex: {
|
||||
@ -196,7 +196,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
case 1:
|
||||
return tflite::TensorType_BOOL;
|
||||
case 8:
|
||||
return tflite::TensorType_INT8;
|
||||
return itype.isUnsigned() ? tflite::TensorType_UINT8
|
||||
: tflite::TensorType_INT8;
|
||||
case 16:
|
||||
return tflite::TensorType_INT16;
|
||||
case 32:
|
||||
@ -404,6 +405,11 @@ class Translator {
|
||||
// and returns llvm::None on failure.
|
||||
Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
|
||||
|
||||
// Build TFLite tensor from the given type. This function is for tfl.lstm
|
||||
// intermediates, which should have UniformQuantizedType.
|
||||
Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
|
||||
mlir::Type type, const std::string& name);
|
||||
|
||||
// Builds TFLite tensor from the given value. `buffer_idx` is index of the
|
||||
// corresponding buffer. Emits error and returns llvm::None on failure.
|
||||
Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value value,
|
||||
@ -469,7 +475,8 @@ class Translator {
|
||||
// tensor indices. Emits an error and returns llvm::None on failure.
|
||||
Optional<BufferOffset<tflite::Operator>> BuildOperator(
|
||||
Operation* inst, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
const std::vector<int32_t>& results,
|
||||
const std::vector<int32_t>& intermediates);
|
||||
|
||||
// Build a subgraph with a given name out of the region either corresponding
|
||||
// to a function's body or while op.
|
||||
@ -581,6 +588,34 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
|
||||
return tflite::CreateBuffer(builder_, buffer_data);
|
||||
}
|
||||
|
||||
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
|
||||
mlir::Type type, const std::string& name) {
|
||||
auto tensor_type = type.cast<TensorType>();
|
||||
|
||||
if (!tensor_type.hasStaticShape()) {
|
||||
return llvm::None;
|
||||
}
|
||||
llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
|
||||
std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
|
||||
|
||||
auto element_type = tensor_type.getElementType();
|
||||
tflite::TensorType tflite_element_type =
|
||||
GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
|
||||
BufferOffset<tflite::QuantizationParameters> q_params;
|
||||
auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>();
|
||||
if (!qtype) {
|
||||
return llvm::None;
|
||||
}
|
||||
q_params = tflite::CreateQuantizationParameters(
|
||||
builder_, /*min=*/0, /*max=*/0,
|
||||
builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
|
||||
builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
|
||||
return tflite::CreateTensor(
|
||||
builder_, builder_.CreateVector(shape), tflite_element_type,
|
||||
/*buffer=*/0, builder_.CreateString(name), q_params,
|
||||
/*is_variable=*/false);
|
||||
}
|
||||
|
||||
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||
Value value, const std::string& name, unsigned buffer_idx) {
|
||||
auto type = value.getType().cast<TensorType>();
|
||||
@ -933,7 +968,8 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
|
||||
|
||||
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
Operation* inst, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
const std::vector<int32_t>& results,
|
||||
const std::vector<int32_t>& intermediates) {
|
||||
const auto* dialect = inst->getDialect();
|
||||
if (!dialect) {
|
||||
inst->emitOpError("dialect is not registered");
|
||||
@ -986,7 +1022,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
std::string op_name = inst->getName().getStringRef().str();
|
||||
uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);
|
||||
auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
|
||||
results, &builder_);
|
||||
results, intermediates, &builder_);
|
||||
if (!offset) {
|
||||
inst->emitOpError("is not a supported TFLite op");
|
||||
}
|
||||
@ -1171,6 +1207,29 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
|
||||
bool failed_once = false;
|
||||
for (auto& inst : bb) {
|
||||
if (inst.isKnownTerminator()) break;
|
||||
std::vector<int32_t> intermediates;
|
||||
// Build intermediate tensors for tfl.lstm and insert these tensors into
|
||||
// flatbuffer.
|
||||
if (llvm::isa<mlir::TFL::LSTMOp>(inst)) {
|
||||
std::vector<std::string> intermediate_names = {
|
||||
"input_to_input_intermediate", "input_to_forget_intermediate",
|
||||
"input_to_cell_intermediate", "input_to_output_intermediate",
|
||||
"effective_hidden_scale_intermediate"};
|
||||
for (const std::string& intermediate : intermediate_names) {
|
||||
auto intermediate_attr = inst.getAttr(intermediate);
|
||||
if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
|
||||
Type qtype = attr.getValue();
|
||||
auto tensor_or = BuildTensorFromType(
|
||||
qtype, name_mapper_.GetUniqueName(intermediate).str());
|
||||
if (!tensor_or.hasValue()) {
|
||||
continue;
|
||||
} else {
|
||||
intermediates.push_back(tensors.size());
|
||||
tensors.push_back(tensor_or.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto val : inst.getResults()) {
|
||||
std::string name = UniqueName(val);
|
||||
@ -1195,7 +1254,8 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
|
||||
results.push_back(tensor_index_map.lookup(result));
|
||||
}
|
||||
|
||||
if (auto tfl_operator = BuildOperator(&inst, operands, results))
|
||||
if (auto tfl_operator =
|
||||
BuildOperator(&inst, operands, results, intermediates))
|
||||
operators.push_back(*tfl_operator);
|
||||
else
|
||||
failed_once = true;
|
||||
@ -1230,27 +1290,58 @@ BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
|
||||
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
|
||||
Translator::CreateMetadataVector() {
|
||||
auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
|
||||
if (!dict_attr) return VectorBufferOffset<BufferOffset<tflite::Metadata>>();
|
||||
|
||||
std::vector<BufferOffset<tflite::Metadata>> metadata;
|
||||
for (const auto& named_attr : dict_attr) {
|
||||
StringRef name = named_attr.first;
|
||||
mlir::Attribute attr = named_attr.second;
|
||||
if (auto content = attr.dyn_cast<StringAttr>()) {
|
||||
metadata.push_back(BuildMetadata(name, content.getValue()));
|
||||
} else {
|
||||
module_.emitError(
|
||||
"all values in tfl.metadata's dictionary key-value pairs should be "
|
||||
"string attributes");
|
||||
return llvm::None;
|
||||
if (dict_attr) {
|
||||
for (const auto& named_attr : dict_attr) {
|
||||
StringRef name = named_attr.first;
|
||||
mlir::Attribute attr = named_attr.second;
|
||||
if (auto content = attr.dyn_cast<StringAttr>()) {
|
||||
metadata.push_back(BuildMetadata(name, content.getValue()));
|
||||
} else {
|
||||
module_.emitError(
|
||||
"all values in tfl.metadata's dictionary key-value pairs should be "
|
||||
"string attributes");
|
||||
return llvm::None;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Runtime version string is generated after we update the op
|
||||
// versions. Here we put a 16-byte dummy string as a placeholder. We choose
|
||||
// 16-byte because it's the alignment of buffers in flatbuffer, so it won't
|
||||
// cause any waste of space if the actual string is shorter than 16 bytes.
|
||||
metadata.push_back(
|
||||
BuildMetadata("min_runtime_version", std::string(16, '\0')));
|
||||
return builder_.CreateVector(metadata);
|
||||
}
|
||||
|
||||
bool UpdateEntryFunction(ModuleOp module) {
|
||||
if (module.lookupSymbol<FuncOp>("main") != nullptr) {
|
||||
// We already have an entry function.
|
||||
return true;
|
||||
}
|
||||
|
||||
int entry_func_count = 0;
|
||||
FuncOp entry_func = nullptr;
|
||||
for (auto fn : module.getOps<FuncOp>()) {
|
||||
auto attrs = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
|
||||
if (attrs && !attrs.empty()) {
|
||||
entry_func_count++;
|
||||
entry_func = fn;
|
||||
}
|
||||
}
|
||||
|
||||
// We should have one & only have one entry function.
|
||||
if (entry_func_count != 1) return false;
|
||||
|
||||
// Update the entry func to main.
|
||||
entry_func.setName("main");
|
||||
return true;
|
||||
}
|
||||
|
||||
Optional<std::string> Translator::Translate(
|
||||
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
||||
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
if (!UpdateEntryFunction(module)) return llvm::None;
|
||||
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
|
||||
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops, op_or_arg_name_mapper);
|
||||
@ -1334,6 +1425,7 @@ Optional<std::string> Translator::TranslateInternal() {
|
||||
builder_.CreateVector(buffers_), metadata_buffer, *metadata);
|
||||
tflite::FinishModelBuffer(builder_, model);
|
||||
tflite::UpdateOpVersion(builder_.GetBufferPointer());
|
||||
tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
|
||||
|
||||
// Return serialized string for the built FlatBuffer.
|
||||
return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
@ -66,13 +67,29 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorFlowLiteOpFolderDialectInterface
|
||||
: public OpFolderDialectInterface {
|
||||
using OpFolderDialectInterface::OpFolderDialectInterface;
|
||||
|
||||
// Registered hook to check if the given region, which is attached to an
|
||||
// operation that is *not* isolated from above (i.e. no internal regions
|
||||
// reference values defined in an enclosing region), should be used when
|
||||
// materializing constants.
|
||||
// In the TFLite dialect we materialize inside a while regions as slightly
|
||||
// more efficient computationally.
|
||||
bool shouldMaterializeInto(Region *region) const final {
|
||||
return isa<WhileOp>(region->getParentOp());
|
||||
}
|
||||
};
|
||||
|
||||
TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
|
||||
: Dialect(/*name=*/"tfl", context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
>();
|
||||
addInterfaces<TensorFlowLiteInlinerInterface>();
|
||||
addInterfaces<TensorFlowLiteInlinerInterface,
|
||||
TensorFlowLiteOpFolderDialectInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1269,6 +1286,20 @@ static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
|
||||
"UnidirectionalSequenceLSTMOp expected to have two stateful operands");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BidirectionalSequenceLSTMOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(BidirectionalSequenceLSTMOp op) {
|
||||
auto operands = op.GetStatefulOperands();
|
||||
if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
|
||||
operands[2] == 37 && operands[3] == 38) {
|
||||
return success();
|
||||
}
|
||||
return op.emitError(
|
||||
"BidirectionalSequenceLSTMOp expected to have four stateful operands");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnidirectionalSequenceRNNOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -25,9 +25,11 @@ limitations under the License.
|
||||
#include "mlir/IR/Dialect.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // TF:llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@ -39,6 +41,8 @@ class TensorFlowLiteDialect : public Dialect {
|
||||
public:
|
||||
explicit TensorFlowLiteDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "tfl"; }
|
||||
|
||||
// Registered hook to materialize a constant operation from a given attribute
|
||||
// value with the desired resultant type.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
|
@ -19,7 +19,8 @@ limitations under the License.
|
||||
#define TFL_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Transforms/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
|
||||
|
||||
@ -47,13 +48,6 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
|
||||
"TFLite string type">,
|
||||
BuildableType<"getType<mlir::TF::StringType>()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFLite dialect uint8 type - uses the TF uint8 type as implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
|
||||
"TFLite uint8 type">,
|
||||
BuildableType<"getType<mlir::TF::Uint8Type>()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFLite dialect quint8 type - uses the TF quint8 type as implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -141,7 +135,8 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
|
||||
Variadic<TensorOf<allowedOpTypes>>,
|
||||
TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
|
||||
|
||||
def TFL_Int32Or64 : IntOfWidths<[32, 64]>;
|
||||
def TFL_Uint8 : UI<8>;
|
||||
def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;
|
||||
|
||||
def TFL_BoolTensor : TFL_TensorOf<[I1]>;
|
||||
def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>;
|
||||
@ -223,9 +218,9 @@ class TFL_Operand0DOr1ElementTensor<int x> :
|
||||
class TFL_TFTypesWithSameBits<int i, int j, int num> :
|
||||
And<[
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Uint" # num # "Type>()">]>,
|
||||
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>,
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Uint" # num # "Type>()">]>]>;
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||
|
||||
class TFL_OperandHasRankLessThan<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
|
||||
@ -602,7 +597,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
|
||||
def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [ConstantLike, NoSideEffect,
|
||||
FirstAttrDerivedResultType]> {
|
||||
let summary = "Constant pseudo op.";
|
||||
|
||||
@ -1863,11 +1858,11 @@ def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutat
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs,
|
||||
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
|
||||
TFL_AFAttr:$fused_activation_function);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
@ -1887,9 +1882,9 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
Computes element-wise negation of input
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, I32, I64]>:$x);
|
||||
|
||||
let results = (outs AnyTensor:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64]>:$y);
|
||||
|
||||
let hasOptions = 0b1;
|
||||
|
||||
@ -2039,10 +2034,10 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs);
|
||||
ins TFL_TensorOf<[F32, I32]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32]>:$rhs);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
|
||||
|
||||
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
||||
|
||||
@ -2716,7 +2711,7 @@ def TFL_SplitOp : TFL_Op<"split", [
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[I32]>:$split_dim,
|
||||
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
|
||||
PositiveI32Attr:$num_splits
|
||||
Confined<I32Attr, [IntPositive]>:$num_splits
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -2741,7 +2736,7 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
|
||||
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
|
||||
TFL_1DTensorOf<[I32], [I32]>:$size_splits,
|
||||
TFL_0DTensorOf<[I32], [I32]>:$split_dim,
|
||||
PositiveI32Attr:$num_splits
|
||||
Confined<I32Attr, [IntPositive]>:$num_splits
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -3246,7 +3241,15 @@ Ba et al. 'Layer Normalization'
|
||||
// Since this op is the FULL kernel only, constrain it.
|
||||
Confined<
|
||||
DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">,
|
||||
[TFL_LSTM_KT_FULL]>:$kernel_type
|
||||
[TFL_LSTM_KT_FULL]>:$kernel_type,
|
||||
|
||||
// Types of the optional intermediate tensors, which exist for fully
|
||||
// quantized LSTM op and hold the ranges of the intermediate tensors.
|
||||
OptionalAttr<TypeAttr>:$input_to_input_intermediate,
|
||||
OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
|
||||
OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
|
||||
OptionalAttr<TypeAttr>:$input_to_output_intermediate,
|
||||
OptionalAttr<TypeAttr>:$effective_hidden_scale_intermediate
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
@ -3350,6 +3353,156 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
}];
|
||||
}
|
||||
|
||||
def BidiLstmMandatoryInputsConstraint : PredOpTrait<
|
||||
"mandatory operands element types should match",
|
||||
// TODO(ashwinm): Replace the indices with input tensor names when that
|
||||
// support is available.
|
||||
Or<[
|
||||
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 19, 20, 21, 23, 24, 25,
|
||||
30, 31, 32, 35, 36, 37, 38]>,
|
||||
Neg<TypeIsPred<"input", F32>>]>>;
|
||||
|
||||
def BidiLstmOptionalPeepholeWeightConstraint : PredOpTrait<
|
||||
"the optional peephole weights should all be specified or none",
|
||||
TCopVTEtAreSameAt<[9, 10, 11, 26, 27, 28]>>;
|
||||
|
||||
def BidiLstmProjectionWeightBiasConstraint : PredOpTrait<
|
||||
"either projection weight must be specified or both projection weight and "
|
||||
"projection bias must not be specified",
|
||||
Or<[
|
||||
And<[TypeIsPred<"fw_projection_weights", NoneType>,
|
||||
TypeIsPred<"fw_projection_bias", NoneType>,
|
||||
TypeIsPred<"bw_projection_weights", NoneType>,
|
||||
TypeIsPred<"bw_projection_bias", NoneType>]>,
|
||||
And<[
|
||||
Neg<TypeIsPred<"fw_projection_weights", NoneType>>,
|
||||
Neg<TypeIsPred<"bw_projection_weights", NoneType>>,
|
||||
]>
|
||||
]>>;
|
||||
|
||||
// BidirectionalSequenceLstm op.
|
||||
// TODO(ashwinm): Add constraint to validate the combination of operands
|
||||
// that are valid for hybrid vs fully quantized vs float only semantics
|
||||
def TFL_BidirectionalSequenceLSTMOp :
|
||||
TFL_Op<"bidirectional_sequence_lstm",
|
||||
[BidiLstmMandatoryInputsConstraint,
|
||||
BidiLstmOptionalPeepholeWeightConstraint,
|
||||
BidiLstmProjectionWeightBiasConstraint,
|
||||
LstmResultConstraint,
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "Bidirectional sequence lstm operator";
|
||||
|
||||
let description = [{
|
||||
Bidirectional lstm is essentiallay two lstms, one running forward & the
|
||||
other running backward. And the output is the concatenation of the two
|
||||
lstms.
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I8]>:$input,
|
||||
|
||||
// Forward LSTM Weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_input_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$fw_input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$fw_input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$fw_input_to_output_weights,
|
||||
|
||||
// Forward Recurrent weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_output_weights,
|
||||
|
||||
// Forward Cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_input_weights,
|
||||
// Optional Forward cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_forget_weights,
|
||||
// Optional Forward cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_output_weights,
|
||||
|
||||
// Forward Bias
|
||||
TFL_TensorOfOrNone<[F32]>:$fw_input_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$fw_forget_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$fw_cell_bias,
|
||||
TFL_TensorOf<[F32]>:$fw_output_gate_bias,
|
||||
|
||||
// Forward Projection weight and bias
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_projection_weights,
|
||||
// Forward Optional input
|
||||
TFL_TensorOfOrNone<[F32]>:$fw_projection_bias,
|
||||
|
||||
// Backward LSTM Weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_input_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$bw_input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$bw_input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$bw_input_to_output_weights,
|
||||
|
||||
// Backward Recurrent weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_output_weights,
|
||||
|
||||
// Backward Cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_input_weights,
|
||||
// Optional Forward cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_forget_weights,
|
||||
// Optional Forward cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_output_weights,
|
||||
|
||||
// Backward Bias
|
||||
TFL_TensorOfOrNone<[F32]>:$bw_input_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$bw_forget_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$bw_cell_bias,
|
||||
TFL_TensorOf<[F32]>:$bw_output_gate_bias,
|
||||
|
||||
// Backward Projection weight and bias
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_projection_weights,
|
||||
// Backward Optional input
|
||||
TFL_TensorOfOrNone<[F32]>:$bw_projection_bias,
|
||||
|
||||
// Stateful activation and cell states.
|
||||
TFL_StatefulTensor:$fw_input_activation_state,
|
||||
TFL_StatefulTensor:$fw_input_cell_state,
|
||||
TFL_StatefulTensor:$bw_input_activation_state,
|
||||
TFL_StatefulTensor:$bw_input_cell_state,
|
||||
|
||||
// Auxiliary input & weights.
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$aux_input,
|
||||
// Auxiliary fw weights.
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_input_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_forget_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_cell_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_output_weights,
|
||||
// Auxiliary bw weights.
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_input_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_forget_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_cell_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_output_weights,
|
||||
|
||||
// Attributes
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
|
||||
BoolAttr:$merge_outputs,
|
||||
BoolAttr:$time_major
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
AnyTensor:$fw_output,
|
||||
AnyTensor:$bw_output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
std::vector<int> GetStatefulOperands() { return {35, 36, 37, 38}; }
|
||||
}];
|
||||
}
|
||||
|
||||
def RnnResultConstraint : PredOpTrait<
|
||||
"the input and result tensor elemental types must be same",
|
||||
TCresVTEtIsSameAsOp<0, 0>>;
|
||||
|
@ -10,11 +10,9 @@ package_group(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graphdef_to_tfl_flatbuffer",
|
||||
srcs = ["graphdef_to_tfl_flatbuffer.cc"],
|
||||
hdrs = [
|
||||
"graphdef_to_tfl_flatbuffer.h",
|
||||
],
|
||||
name = "tf_tfl_flatbuffer_helpers",
|
||||
srcs = ["tf_tfl_flatbuffer_helpers.cc"],
|
||||
hdrs = ["tf_tfl_flatbuffer_helpers.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:common",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
@ -36,3 +34,61 @@ cc_library(
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graphdef_to_tfl_flatbuffer",
|
||||
srcs = ["graphdef_to_tfl_flatbuffer.cc"],
|
||||
hdrs = [
|
||||
"graphdef_to_tfl_flatbuffer.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_tfl_flatbuffer_helpers",
|
||||
"//tensorflow/compiler/mlir/lite:common",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
|
||||
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:types_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_to_tfl_flatbuffer",
|
||||
srcs = ["saved_model_to_tfl_flatbuffer.cc"],
|
||||
hdrs = [
|
||||
"saved_model_to_tfl_flatbuffer.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_tfl_flatbuffer_helpers",
|
||||
"//tensorflow/compiler/mlir/lite:common",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
|
||||
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:types_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -40,288 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/toco/types.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Op def string for TFLite_Detection_PostProcess Op.
|
||||
const char kDetectionPostProcessOp[] =
|
||||
"name: 'TFLite_Detection_PostProcess' input_arg: { name: "
|
||||
"'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
|
||||
"'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
|
||||
"'anchors' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
|
||||
"'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
|
||||
"type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
|
||||
"name: 'nms_iou_threshold' type: 'float'} attr : { name: "
|
||||
"'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
|
||||
"'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
|
||||
"type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
|
||||
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
||||
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
||||
|
||||
const char kUnidirectionalSequenceLstmOp[] =
|
||||
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
|
||||
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
|
||||
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
|
||||
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
|
||||
"name: 'InputCellStateTensor' type: DT_FLOAT } "
|
||||
"output_arg: { name: 'Concat' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
const char kUnidirectionalSequenceRnnOp[] =
|
||||
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'Bias' type: DT_FLOAT} "
|
||||
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
|
||||
"DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
switch (dtype) {
|
||||
case toco::IODataType::FLOAT:
|
||||
return DT_FLOAT;
|
||||
case toco::IODataType::QUANTIZED_UINT8:
|
||||
return DT_QUINT8;
|
||||
case toco::IODataType::INT8:
|
||||
return DT_QINT8;
|
||||
case toco::IODataType::INT32:
|
||||
return DT_INT32;
|
||||
case toco::IODataType::INT64:
|
||||
return DT_INT64;
|
||||
case toco::IODataType::STRING:
|
||||
return DT_STRING;
|
||||
case toco::IODataType::BOOL:
|
||||
return DT_BOOL;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
|
||||
DataType type) {
|
||||
// Only qint8 and quint8 are considered here.
|
||||
double qmin, qmax;
|
||||
if (type == DT_QUINT8) {
|
||||
qmin = 0.0;
|
||||
qmax = 255.0;
|
||||
} else if (type == DT_QINT8) {
|
||||
qmin = -128.0;
|
||||
qmax = 127.0;
|
||||
} else {
|
||||
return errors::InvalidArgument("Only int8 and uint8 are considered.");
|
||||
}
|
||||
return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
|
||||
}
|
||||
|
||||
// Give a warning for any unused flags that have been specified.
|
||||
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags) {
|
||||
if (toco_flags.output_format()) {
|
||||
LOG(WARNING) << "Ignored output_format.";
|
||||
}
|
||||
if (toco_flags.drop_control_dependency()) {
|
||||
LOG(WARNING) << "Ignored drop_control_dependency.";
|
||||
}
|
||||
if (toco_flags.reorder_across_fake_quant()) {
|
||||
LOG(WARNING) << "Ignored reorder_across_fake_quant.";
|
||||
}
|
||||
if (model_flags.change_concat_input_ranges()) {
|
||||
LOG(WARNING) << "Ignored change_concat_input_ranges.";
|
||||
}
|
||||
if (toco_flags.dump_graphviz_include_video()) {
|
||||
LOG(WARNING) << "Ignored dump_graphviz_video.";
|
||||
}
|
||||
if (model_flags.allow_nonexistent_arrays()) {
|
||||
LOG(WARNING) << "Allow allow_nonexistent_arrays.";
|
||||
}
|
||||
}
|
||||
|
||||
// Dumps the op graph of the `module` to `filename` in DOT format.
|
||||
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
std::string error_message;
|
||||
auto output = mlir::openOutputFile(filename, &error_message);
|
||||
if (!error_message.empty()) {
|
||||
return errors::InvalidArgument("Failed to open file in %s.", filename);
|
||||
}
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
|
||||
if (failed(pm.run(module))) {
|
||||
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
|
||||
}
|
||||
output->keep();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
|
||||
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
|
||||
tensorflow::OpDef opdef;
|
||||
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
|
||||
&opdef)) {
|
||||
return errors::InvalidArgument("fail to parse extra OpDef");
|
||||
}
|
||||
// Make sure the op is not already registered. If registered continue.
|
||||
const OpRegistrationData* op_reg =
|
||||
tensorflow::OpRegistry::Global()->LookUp(opdef.name());
|
||||
if (op_reg) continue;
|
||||
|
||||
tensorflow::OpRegistry::Global()->Register(
|
||||
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
*op_reg_data = tensorflow::OpRegistrationData(opdef);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
|
||||
// Register any custom OpDefs.
|
||||
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
||||
toco_flags.custom_opdefs().end());
|
||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
|
||||
return RegisterCustomBuiltinOps(extra_tf_opdefs);
|
||||
}
|
||||
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
|
||||
tensorflow::DataType inference_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_type());
|
||||
// Use non-float flag `inference_input_type` to override the `inference_type`
|
||||
// because we have to apply quantization to satisfy that.
|
||||
if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
|
||||
inference_type = quant_specs->inference_input_type;
|
||||
}
|
||||
|
||||
for (auto& flag : model_flags.input_arrays()) {
|
||||
node_names->push_back(flag.name());
|
||||
// TOCO doesn't required `data_type` to be filled for every input.
|
||||
// If it's not filled, make it an empty string so the importer will use
|
||||
// the data type in the NodeDef.
|
||||
auto toco_data_type = flag.data_type();
|
||||
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
|
||||
node_dtypes->push_back("");
|
||||
} else {
|
||||
node_dtypes->push_back(
|
||||
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
|
||||
}
|
||||
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
|
||||
inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
}
|
||||
}
|
||||
|
||||
if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
|
||||
inference_type, quant_specs)) {
|
||||
return errors::InvalidArgument("Failed to get input quant spec.");
|
||||
}
|
||||
|
||||
// Some extra flag related to post training quantization. If post-training
|
||||
// quantization is enabled, `inference_type` and `inference_input_type` are
|
||||
// not used by MLIR passes.
|
||||
if (toco_flags.post_training_quantize()) {
|
||||
quant_specs->weight_quantization = true;
|
||||
if (toco_flags.quantize_to_float16()) {
|
||||
quant_specs->inference_type = tensorflow::DT_HALF;
|
||||
quant_specs->inference_input_type = tensorflow::DT_HALF;
|
||||
} else {
|
||||
quant_specs->inference_type = tensorflow::DT_QINT8;
|
||||
quant_specs->inference_input_type = tensorflow::DT_QINT8;
|
||||
}
|
||||
}
|
||||
|
||||
// Other flags.
|
||||
if (toco_flags.has_default_ranges_min()) {
|
||||
quant_specs->default_ranges.first = toco_flags.default_ranges_min();
|
||||
}
|
||||
if (toco_flags.has_default_ranges_max()) {
|
||||
quant_specs->default_ranges.second = toco_flags.default_ranges_max();
|
||||
}
|
||||
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
string* result) {
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
bool emit_custom_ops = toco_flags.allow_custom_ops();
|
||||
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
module.get(),
|
||||
// rename once we enable the new converter feature flag.
|
||||
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
|
||||
"/toco_AFTER_TRANSFORMATIONS.dot")));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
const GraphDebugInfo& debug_info,
|
||||
@ -339,7 +59,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
std::vector<double> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(PopulateQuantizationSpecs(
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
|
||||
&node_shapes, &node_mins, &node_maxs));
|
||||
|
||||
@ -356,16 +76,16 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
specs.convert_legacy_fed_inputs = true;
|
||||
specs.graph_as_function = false;
|
||||
specs.upgrade_legacy = true;
|
||||
WarningUnusedFlags(model_flags, toco_flags);
|
||||
internal::WarningUnusedFlags(model_flags, toco_flags);
|
||||
|
||||
// Register all custom ops, including user-specified custom ops.
|
||||
TF_RETURN_IF_ERROR(RegisterAllCustomOps(toco_flags));
|
||||
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||
|
||||
return ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/types.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
const string& saved_model_dir, bool saved_model_v1,
|
||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::TFL::QuantizationSpecs quant_specs;
|
||||
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
|
||||
&node_shapes, &node_mins, &node_maxs));
|
||||
|
||||
internal::WarningUnusedFlags(model_flags, toco_flags);
|
||||
|
||||
// Register all custom ops, including user-specified custom ops.
|
||||
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
|
||||
|
||||
const bool import_saved_model = !saved_model_v1;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module,
|
||||
ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
|
||||
saved_model_tags, saved_model_exported_names, &context));
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts the given saved_model(either v1 or v2) to a TF Lite FlatBuffer
|
||||
// string according to the given model flags, toco flags and tags. Returns error
|
||||
// status if it fails to convert the input.
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
const string& saved_model_dir, bool saved_model_v1,
|
||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
||||
string* result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_
|
@ -0,0 +1,325 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
|
||||
|
||||
#include <ostream>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/types.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
// Op def string for TFLite_Detection_PostProcess Op.
|
||||
const char kDetectionPostProcessOp[] =
|
||||
"name: 'TFLite_Detection_PostProcess' input_arg: { name: "
|
||||
"'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
|
||||
"'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
|
||||
"'anchors' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
|
||||
"'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
|
||||
"'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
|
||||
"type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
|
||||
"name: 'nms_iou_threshold' type: 'float'} attr : { name: "
|
||||
"'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
|
||||
"'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
|
||||
"type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
|
||||
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
||||
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
||||
|
||||
const char kUnidirectionalSequenceLstmOp[] =
|
||||
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
|
||||
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
|
||||
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
|
||||
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
|
||||
"name: 'InputCellStateTensor' type: DT_FLOAT } "
|
||||
"output_arg: { name: 'Concat' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
const char kUnidirectionalSequenceRnnOp[] =
|
||||
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'Bias' type: DT_FLOAT} "
|
||||
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
|
||||
"DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
switch (dtype) {
|
||||
case toco::IODataType::FLOAT:
|
||||
return DT_FLOAT;
|
||||
case toco::IODataType::QUANTIZED_UINT8:
|
||||
return DT_QUINT8;
|
||||
case toco::IODataType::INT8:
|
||||
return DT_QINT8;
|
||||
case toco::IODataType::INT32:
|
||||
return DT_INT32;
|
||||
case toco::IODataType::INT64:
|
||||
return DT_INT64;
|
||||
case toco::IODataType::STRING:
|
||||
return DT_STRING;
|
||||
case toco::IODataType::BOOL:
|
||||
return DT_BOOL;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
|
||||
DataType type) {
|
||||
// Only qint8 and quint8 are considered here.
|
||||
double qmin, qmax;
|
||||
if (type == DT_QUINT8) {
|
||||
qmin = 0.0;
|
||||
qmax = 255.0;
|
||||
} else if (type == DT_QINT8) {
|
||||
qmin = -128.0;
|
||||
qmax = 127.0;
|
||||
} else {
|
||||
return errors::InvalidArgument("Only int8 and uint8 are considered.");
|
||||
}
|
||||
return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
|
||||
}
|
||||
|
||||
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
|
||||
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
|
||||
tensorflow::OpDef opdef;
|
||||
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
|
||||
&opdef)) {
|
||||
return errors::InvalidArgument("fail to parse extra OpDef");
|
||||
}
|
||||
// Make sure the op is not already registered. If registered continue.
|
||||
const OpRegistrationData* op_reg =
|
||||
tensorflow::OpRegistry::Global()->LookUp(opdef.name());
|
||||
if (op_reg) continue;
|
||||
|
||||
tensorflow::OpRegistry::Global()->Register(
|
||||
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
*op_reg_data = tensorflow::OpRegistrationData(opdef);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
|
||||
// Register any custom OpDefs.
|
||||
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
||||
toco_flags.custom_opdefs().end());
|
||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
|
||||
return RegisterCustomBuiltinOps(extra_tf_opdefs);
|
||||
}
|
||||
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
|
||||
tensorflow::DataType inference_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_type());
|
||||
// Use non-float flag `inference_input_type` to override the `inference_type`
|
||||
// because we have to apply quantization to satisfy that.
|
||||
if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
|
||||
inference_type = quant_specs->inference_input_type;
|
||||
}
|
||||
|
||||
for (auto& flag : model_flags.input_arrays()) {
|
||||
node_names->push_back(flag.name());
|
||||
// TOCO doesn't required `data_type` to be filled for every input.
|
||||
// If it's not filled, make it an empty string so the importer will use
|
||||
// the data type in the NodeDef.
|
||||
auto toco_data_type = flag.data_type();
|
||||
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
|
||||
node_dtypes->push_back("");
|
||||
} else {
|
||||
node_dtypes->push_back(
|
||||
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
|
||||
}
|
||||
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
|
||||
inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
}
|
||||
}
|
||||
|
||||
if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
|
||||
inference_type, quant_specs)) {
|
||||
return errors::InvalidArgument("Failed to get input quant spec.");
|
||||
}
|
||||
|
||||
// Some extra flag related to post training quantization. If post-training
|
||||
// quantization is enabled, `inference_type` and `inference_input_type` are
|
||||
// not used by MLIR passes.
|
||||
if (toco_flags.post_training_quantize()) {
|
||||
quant_specs->weight_quantization = true;
|
||||
if (toco_flags.quantize_to_float16()) {
|
||||
quant_specs->inference_type = tensorflow::DT_HALF;
|
||||
quant_specs->inference_input_type = tensorflow::DT_HALF;
|
||||
} else {
|
||||
quant_specs->inference_type = tensorflow::DT_QINT8;
|
||||
quant_specs->inference_input_type = tensorflow::DT_QINT8;
|
||||
}
|
||||
}
|
||||
|
||||
// Other flags.
|
||||
if (toco_flags.has_default_ranges_min()) {
|
||||
quant_specs->default_ranges.first = toco_flags.default_ranges_min();
|
||||
}
|
||||
if (toco_flags.has_default_ranges_max()) {
|
||||
quant_specs->default_ranges.second = toco_flags.default_ranges_max();
|
||||
}
|
||||
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
// Dumps the op graph of the `module` to `filename` in DOT format.
|
||||
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
std::string error_message;
|
||||
auto output = mlir::openOutputFile(filename, &error_message);
|
||||
if (!error_message.empty()) {
|
||||
return errors::InvalidArgument("Failed to open file in %s.", filename);
|
||||
}
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
|
||||
if (failed(pm.run(module))) {
|
||||
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
|
||||
}
|
||||
output->keep();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
string* result) {
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
bool emit_custom_ops = toco_flags.allow_custom_ops();
|
||||
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
module.get(),
|
||||
// rename once we enable the new converter feature flag.
|
||||
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
|
||||
"/toco_AFTER_TRANSFORMATIONS.dot")));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags) {
|
||||
if (toco_flags.output_format()) {
|
||||
LOG(WARNING) << "Ignored output_format.";
|
||||
}
|
||||
if (toco_flags.drop_control_dependency()) {
|
||||
LOG(WARNING) << "Ignored drop_control_dependency.";
|
||||
}
|
||||
if (toco_flags.reorder_across_fake_quant()) {
|
||||
LOG(WARNING) << "Ignored reorder_across_fake_quant.";
|
||||
}
|
||||
if (model_flags.change_concat_input_ranges()) {
|
||||
LOG(WARNING) << "Ignored change_concat_input_ranges.";
|
||||
}
|
||||
if (toco_flags.dump_graphviz_include_video()) {
|
||||
LOG(WARNING) << "Ignored dump_graphviz_video.";
|
||||
}
|
||||
if (model_flags.allow_nonexistent_arrays()) {
|
||||
LOG(WARNING) << "Allow allow_nonexistent_arrays.";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
|
||||
|
||||
#include <ostream>
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/types.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Register all custom ops including user specified custom ops.
|
||||
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
|
||||
|
||||
// Populate quantization specs (or not) given user specified ranges for each
|
||||
// input arrays.
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs);
|
||||
|
||||
// Convert imported MLIR file to TfLite flatbuffer.
|
||||
// This will also run relevant passes as well.
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
string* result);
|
||||
|
||||
// Give a warning for any unused flags that have been specified.
|
||||
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags);
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#define TF_Quantization
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/QuantOps/QuantPredicates.td"
|
||||
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// QuantizedType definitions.
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
@ -147,29 +148,45 @@ static bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
|
||||
// Changes the axis of the input per-channel quantized type to match the
|
||||
// dimension of the target type. Returns nullptr if it fails.
|
||||
static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
|
||||
quant::UniformQuantizedPerAxisType qtype, Type target, int quant_dim) {
|
||||
ArrayRef<int64_t> shape, quant::UniformQuantizedPerAxisType qtype,
|
||||
Type target, int quant_dim) {
|
||||
auto shaped = target.dyn_cast<RankedTensorType>();
|
||||
if (!shaped) return {};
|
||||
ArrayRef<int64_t> new_shape = shaped.getShape();
|
||||
|
||||
SmallVector<double, 4> scales(qtype.getScales().begin(),
|
||||
qtype.getScales().end());
|
||||
SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
|
||||
qtype.getZeroPoints().end());
|
||||
// Broadcast the scales and zero points to match the target size, which is
|
||||
// usually the axis-th dimension of the target type. Currently, it covers two
|
||||
// cases:
|
||||
// - for Transpose, the data layout is changed so the `dim[axis]` still equals
|
||||
// to the `scales_size`. The broadcast skips;
|
||||
// - for Reshape, the data layout isn't changed but the innermost dimension is
|
||||
// expand to cover the last two original dimensions. Thus we just need to be
|
||||
// repeated the `scales` dim[2] times to covers the new dim length.
|
||||
//
|
||||
// TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
|
||||
// have to repeat each elements in the `scales` locally dim[3] times.
|
||||
if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
|
||||
BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
|
||||
|
||||
if (new_shape.size() == shape.size()) { // same rank
|
||||
// Broadcast the scales and zero points to match the target size, which is
|
||||
// usually the axis-th dimension of the target type. Currently, it covers
|
||||
// two cases:
|
||||
// - for Transpose, the data layout is changed so the `dim[axis]` still
|
||||
// equals to the `scales_size`. The broadcast skips;
|
||||
// - for Reshape, the data layout isn't changed but the innermost dimension
|
||||
// is expand to cover the last two original dimensions. Thus we just need to
|
||||
// be repeated the `scales` dim[2] times to covers the new dim length.
|
||||
//
|
||||
// TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
|
||||
// have to repeat each elements in the `scales` locally dim[3] times.
|
||||
if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
|
||||
BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
|
||||
return {};
|
||||
}
|
||||
} else if ((new_shape.size() == shape.size() + 1) && new_shape.back() == 1) {
|
||||
// This is a trivial shift left, then we shift the quant_dim as well.
|
||||
if (std::equal(shape.begin(), shape.end(), new_shape.begin()) &&
|
||||
quant_dim == -1) {
|
||||
quant_dim = shape.size() + quant_dim;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
|
||||
return quant::UniformQuantizedPerAxisType::get(
|
||||
qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
|
||||
scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
|
||||
@ -179,20 +196,21 @@ static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
|
||||
TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
|
||||
TypeAttr source, Type target,
|
||||
int axis) {
|
||||
if (auto source_type = source.getValue().dyn_cast_or_null<ShapedType>()) {
|
||||
auto src_ele_type = source_type.getElementType();
|
||||
if (auto quantized_type = src_ele_type.dyn_cast<quant::QuantizedType>()) {
|
||||
if (auto qtype =
|
||||
quantized_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
|
||||
quantized_type = ResetAxisAndBroadcast(qtype, target, axis);
|
||||
if (!src_ele_type) return {};
|
||||
}
|
||||
Type final_type = quantized_type.castFromExpressedType(target);
|
||||
if (!final_type) return {};
|
||||
return TypeAttr::get(final_type);
|
||||
}
|
||||
auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
|
||||
if (!source_type) return {};
|
||||
auto src_ele_type = source_type.getElementType();
|
||||
auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();
|
||||
|
||||
// Reset the quantization dimensions if it is per-axis.
|
||||
if (auto per_axis =
|
||||
qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
|
||||
qtype =
|
||||
ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
|
||||
}
|
||||
return {};
|
||||
if (!qtype) return {};
|
||||
Type final_type = qtype.castFromExpressedType(target);
|
||||
if (!final_type) return {};
|
||||
return TypeAttr::get(final_type);
|
||||
}
|
||||
|
||||
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
|
||||
|
@ -9,6 +9,7 @@ package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//tensorflow/compiler/aot/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
"//tensorflow/compiler/mlir/lite/...",
|
||||
],
|
||||
@ -38,3 +39,29 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "quantize",
|
||||
srcs = [
|
||||
"quantize.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"quantize.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
62
tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
Normal file
62
tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
// Quantizes the model in the computation.
|
||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
||||
xla::XlaComputation* computation) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
|
||||
computation->Snapshot());
|
||||
|
||||
MLIRContext context;
|
||||
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
|
||||
auto status = xla::ConvertHloToMlirHlo(
|
||||
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Hlo module import failed: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
PassManager pm(&context);
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createInlinerPass());
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
|
||||
LogicalResult result = pm.run(module.get());
|
||||
(void)result;
|
||||
|
||||
module->dump();
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
33
tensorflow/compiler/mlir/lite/quantization/xla/quantize.h
Normal file
33
tensorflow/compiler/mlir/lite/quantization/xla/quantize.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
// Quantizes the model in the computation.
|
||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
||||
xla::XlaComputation* computation);
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
@ -3,8 +3,14 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
data = [
|
||||
":graph_config_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||
},
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
@ -13,7 +19,17 @@ filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/aot:tfcompile",
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all the graph files that are used by the tests.
|
||||
filegroup(
|
||||
name = "graph_config_files",
|
||||
srcs = glob(
|
||||
["**/*.pbtxt"],
|
||||
),
|
||||
)
|
||||
|
@ -0,0 +1,15 @@
|
||||
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
|
||||
# TODO(fengliuai): update this file with the progress of the implementation
|
||||
// CHECK: func @main
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
|
||||
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
|
||||
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
|
||||
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
|
||||
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
|
||||
// CHECK: return %4 : tuple<tensor<2x4xf32>>
|
||||
// CHECK: }
|
@ -0,0 +1,26 @@
|
||||
feed {
|
||||
id { node_name: "input0" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "input1" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
|
||||
fetch {
|
||||
id { node_name: "Add/FakeQuantWithMinMaxVars" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
|
||||
conversion_options {
|
||||
custom_fake_quant_op_calls: true
|
||||
}
|
@ -0,0 +1,218 @@
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "Add"
|
||||
input: "Add/FakeQuantWithMinMaxVars/min"
|
||||
input: "Add/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0/FakeQuantWithMinMaxVars"
|
||||
input: "input1/FakeQuantWithMinMaxVars"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "input0"
|
||||
input: "input0/FakeQuantWithMinMaxVars/min"
|
||||
input: "input0/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "input1"
|
||||
input: "input1/FakeQuantWithMinMaxVars/min"
|
||||
input: "input1/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
# Add two tensor<4xi32> inputs and return the result
|
||||
|
||||
@ -90,5 +90,11 @@ versions {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
# CHECK-NEXT: } ],
|
||||
# CHECK-NEXT: metadata: [ {
|
||||
# CHECK-NEXT: name: "min_runtime_version",
|
||||
# CHECK-NEXT: buffer: 4
|
||||
# CHECK-NEXT: } ]
|
||||
# CHECK-NEXT: }
|
||||
|
@ -61,11 +61,11 @@ func @i64() -> tensor<4xi64> {
|
||||
// the same sort of opaque round-trip we get for complex64, but it might be good
|
||||
// to check
|
||||
|
||||
func @uint8() -> tensor<4x!tf.uint8> {
|
||||
func @uint8() -> tensor<4xui8> {
|
||||
// CHECK-LABEL: @uint8
|
||||
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8>
|
||||
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> } : () -> tensor<4x!tf.uint8>
|
||||
return %0 : tensor<4x!tf.uint8>
|
||||
// CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8>
|
||||
%0 = "tfl.pseudo_const"() {value = dense<[222, 173, 190, 239]> : tensor<4xui8>} : () -> tensor<4xui8>
|
||||
return %0 : tensor<4xui8>
|
||||
}
|
||||
|
||||
func @qi32_per_axis() -> tensor<3x3x!quant.uniform<i32:f32:1, {1.0, 0.5:1, 0.25:1}>> {
|
||||
|
@ -13,3 +13,16 @@ func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32
|
||||
// CHECK: return %[[RES0]]
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
|
||||
%cst = constant unit
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
// CHECK-LABEL: testFullyQuantizedLSTM
|
||||
// CHECK: %cst = constant unit
|
||||
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
}
|
||||
|
||||
|
@ -58,15 +58,15 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t
|
||||
// CANON-SAME: (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>)
|
||||
// CANON: [[VAL_1:%.*]] = constant dense<1.000000e+00> : tensor<256x256xf32>
|
||||
// CANON: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
|
||||
// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor<i32>
|
||||
// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
|
||||
// CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) ( {
|
||||
// CANON: ^bb0([[VAL_7:%.*]]: tensor<*xi32>, [[VAL_8:%.*]]: tensor<*xi32>, [[VAL_9:%.*]]: tensor<*xf32>):
|
||||
// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor<i32>
|
||||
// CANON: [[VAL_10:%.*]] = "tf.Less"([[VAL_8]], [[VAL_3]])
|
||||
// CANON: "tfl.yield"([[VAL_10]]) : (tensor<*xi1>) -> ()
|
||||
// CANON: }, {
|
||||
// CANON: ^bb0([[VAL_11:%.*]]: tensor<*xi32>, [[VAL_12:%.*]]: tensor<*xi32>, [[VAL_13:%.*]]: tensor<*xf32>):
|
||||
// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
|
||||
// CANON: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_4]])
|
||||
// CANON: [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]])
|
||||
// CANON: [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]])
|
||||
|
@ -1,22 +1,11 @@
|
||||
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
|
||||
|
||||
func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %7: tensor<1xf32>
|
||||
return %0: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: addRelu
|
||||
// CHECK-LABEL: add
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
|
||||
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -30,13 +19,10 @@ func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
|
||||
func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> {
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
|
||||
%1 = "tf.BiasAdd"(%0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
|
||||
%2 = "tf.Relu6"(%1) : (tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32>
|
||||
return %2 : tensor<1x10x10x32xf32>
|
||||
return %0 : tensor<1x10x10x32xf32>
|
||||
|
||||
// CHECK-LABEL: biasAdd
|
||||
// CHECK: "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
|
||||
// CHECK: %1 = "tfl.add"(%0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
|
||||
}
|
||||
|
||||
func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> {
|
||||
@ -55,9 +41,9 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
|
||||
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
|
||||
return %4 : i32
|
||||
// CHECK-LABEL: squeezeAndReshape
|
||||
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
|
||||
// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
|
||||
// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
|
||||
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
|
||||
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
|
||||
// CHECK: return
|
||||
@ -88,7 +74,7 @@ func @dynamicReshapeI64Fold(%arg0: tensor<*xf32>) -> tensor<1x2xf32> {
|
||||
return %0 : tensor<1x2xf32>
|
||||
|
||||
// CHECK-LABEL: dynamicReshapeI64Fold
|
||||
// CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32>
|
||||
// CHECK: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32>
|
||||
// CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%arg0, %[[cst]]) : (tensor<*xf32>, tensor<2xi32>) -> tensor<1x2xf32>
|
||||
// CHECK-NEXT: return %[[reshape]] : tensor<1x2xf32>
|
||||
}
|
||||
@ -128,10 +114,10 @@ func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: softplus
|
||||
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
|
||||
// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
|
||||
// CHECK: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
||||
@ -255,20 +241,12 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
// CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
func @div(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %5: tensor<1xf32>
|
||||
return %0: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: divRelu
|
||||
// CHECK-LABEL: div
|
||||
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
|
||||
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -698,8 +676,9 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
|
||||
}
|
||||
|
||||
@ -731,8 +710,9 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
|
||||
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
|
||||
}
|
||||
|
||||
@ -1295,7 +1275,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
|
||||
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
|
||||
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>
|
||||
@ -1340,8 +1321,8 @@ func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> {
|
||||
return %0: tensor<8xf16>
|
||||
|
||||
// CHECK-LABEL: reciprocal_f16
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf16>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf16>, tensor<8xf16>) -> tensor<8xf16>
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<f16>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f16>, tensor<8xf16>) -> tensor<8xf16>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -1350,8 +1331,8 @@ func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
|
||||
return %0: tensor<8xf32>
|
||||
|
||||
// CHECK-LABEL: reciprocal_f32
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -1360,8 +1341,8 @@ func @reciprocal_complex_f32(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<
|
||||
return %0: tensor<8xcomplex<f32>>
|
||||
|
||||
// CHECK-LABEL: reciprocal_complex_f32
|
||||
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<1xcomplex<f32>>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xcomplex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
|
||||
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<complex<f32>>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<complex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -1370,8 +1351,8 @@ func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
|
||||
return %0: tensor<8xi32>
|
||||
|
||||
// CHECK-LABEL: reciprocal_i32
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<8xi32>) -> tensor<8xi32>
|
||||
// CHECK: %cst = constant dense<1> : tensor<i32>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<8xi32>) -> tensor<8xi32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -1380,8 +1361,8 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
|
||||
return %0: tensor<8xi64>
|
||||
|
||||
// CHECK-LABEL: reciprocal_i64
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi64>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
|
||||
// CHECK: %cst = constant dense<1> : tensor<i64>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<i64>, tensor<8xi64>) -> tensor<8xi64>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -1436,7 +1417,7 @@ func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>)
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
|
||||
// CHECK: [[VAL_5:%.*]] = constant unit
|
||||
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
|
||||
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
|
||||
// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
|
||||
// CHECK: }
|
||||
|
||||
@ -1461,7 +1442,7 @@ func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
|
||||
// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32>
|
||||
// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32>
|
||||
// CHECK: [[VAL_14:%.*]] = constant unit
|
||||
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
|
||||
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
|
||||
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
|
||||
// CHECK: }
|
||||
|
||||
@ -1480,3 +1461,25 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
|
||||
// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
|
||||
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
|
||||
// CHECK: }
|
||||
|
||||
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
@ -12,11 +12,9 @@
|
||||
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 4 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 5 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 6 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 7 tfl.while:3 kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 5 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
|
||||
// Verify while was not folded away:
|
||||
// ------------------------------------
|
||||
|
@ -108,6 +108,12 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 10
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||
|
||||
|
||||
@ -61,6 +61,12 @@ func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2:
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
@ -90,6 +90,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -82,6 +82,12 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -84,6 +84,12 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
@ -88,6 +88,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -tflite-flatbuffer-to-mlir - -o - | FileCheck --check-prefix=IMPORT %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
@ -46,6 +46,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 3
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
|
||||
// CHECK: {
|
||||
@ -39,6 +39,12 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 3
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
@ -89,6 +89,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -61,6 +61,12 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -61,6 +61,12 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -156,6 +156,12 @@
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 11
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xi1>) -> tensor<4xi1> {
|
||||
^bb0(%arg0: tensor<4xi1>):
|
||||
@ -78,6 +78,12 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 49, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
|
||||
// CHECK: {
|
||||
@ -192,7 +192,8 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
|
||||
// CHECK-NEXT: builtin_options_type: LSTMOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: intermediates: [ ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
@ -249,6 +250,12 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 26
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
@ -0,0 +1,323 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
|
||||
%cst = constant unit
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LSTM,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1, 528 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.037248 ],
|
||||
// CHECK-NEXT: zero_point: [ -19 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 528 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.059802 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 528 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "arg2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.031926 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 528 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "arg3",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.056272 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 528 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "arg4",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.063764 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 640 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "arg5",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.013359 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 640 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "arg6",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.02283 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 640 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 8,
|
||||
// CHECK-NEXT: name: "arg7",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.032276 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048, 640 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 9,
|
||||
// CHECK-NEXT: name: "arg8",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.035427 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 10,
|
||||
// CHECK-NEXT: name: "arg9",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.0 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 11,
|
||||
// CHECK-NEXT: name: "arg10",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.0 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 12,
|
||||
// CHECK-NEXT: name: "arg11",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.0 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 13,
|
||||
// CHECK-NEXT: name: "arg12",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.0 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 640, 2048 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 14,
|
||||
// CHECK-NEXT: name: "arg13",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.021174 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 640 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 15,
|
||||
// CHECK-NEXT: name: "arg14",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.00016 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: buffer: 16,
|
||||
// CHECK-NEXT: name: "arg15",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.000437 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: buffer: 17,
|
||||
// CHECK-NEXT: name: "arg16",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.00011 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: buffer: 18,
|
||||
// CHECK-NEXT: name: "arg17",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.000168 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 2048 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: buffer: 19,
|
||||
// CHECK-NEXT: name: "arg18",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.000156 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 640 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: name: "arg19",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.096711 ],
|
||||
// CHECK-NEXT: zero_point: [ 10 ]
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: is_variable: true
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 2048 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: name: "arg20",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.000488 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: is_variable: true
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 0 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: name: "input_to_input_intermediate",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.004989 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 0 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: name: "input_to_forget_intermediate",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.007885 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 0 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: name: "input_to_cell_intermediate",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.008763 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 0 ],
|
||||
// CHECK-NEXT: type: INT16,
|
||||
// CHECK-NEXT: name: "input_to_output_intermediate",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.005753 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 0 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: name: "effective_hidden_scale_intermediate",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.007563 ],
|
||||
// CHECK-NEXT: zero_point: [ 2 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 640 ],
|
||||
// CHECK-NEXT: type: INT8,
|
||||
// CHECK-NEXT: buffer: 22,
|
||||
// CHECK-NEXT: name: "tfl.lstm",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.096711 ],
|
||||
// CHECK-NEXT: zero_point: [ 10 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ],
|
||||
// CHECK-NEXT: outputs: [ 26 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, 9, 10, 11, 12, 13, 14, 19, 20, 15, 16, 17, 18 ],
|
||||
// CHECK-NEXT: outputs: [ 26 ],
|
||||
// CHECK-NEXT: builtin_options_type: LSTMOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: fused_activation_function: TANH,
|
||||
// CHECK-NEXT: cell_clip: 10.0,
|
||||
// CHECK-NEXT: proj_clip: 0.01
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: intermediates: [ 21, 22, 23, 24, 25 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 23
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
}
|
@ -128,6 +128,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 8
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||
|
||||
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
|
||||
@ -50,6 +50,12 @@ func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x3
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||
|
||||
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
|
||||
@ -50,6 +50,12 @@ func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
@ -20,6 +20,8 @@ module attributes {
|
||||
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "key1",
|
||||
@ -27,4 +29,8 @@ module attributes {
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: name: "key2",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user