Merge branch 'master' into interface_16x8

Commit 5b9a467e3d by Elena Zhelezina, 2020-03-17 16:17:33 +00:00 (committed via GitHub)
1892 changed files with 65735 additions and 24625 deletions

View File

@ -46,7 +46,6 @@
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
# numa: Enable numa using hwloc.
@ -140,13 +139,6 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -248,6 +240,7 @@ build:windows --copt=/w
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
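
A hedged C++ illustration (hypothetical translation unit, not part of this change) of why /D_USE_MATH_DEFINES is added above: MSVC's <cmath> only exposes the M_* constants when the macro is defined before inclusion.

// Hypothetical example; on MSVC, M_PI is only visible with _USE_MATH_DEFINES.
#define _USE_MATH_DEFINES
#include <cmath>

double Circumference(double radius) { return 2.0 * M_PI * radius; }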

.gitignore vendored
View File

@ -38,7 +38,9 @@ gradleBuild
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
/tensorflow/lite/**/ios/BUILD
/tensorflow/lite/**/objc/BUILD
/tensorflow/lite/**/swift/BUILD
/tensorflow/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/examples/ios/simple/data/*.txt
Podfile.lock

View File

@ -154,7 +154,10 @@ tf_cuda_library(
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
visibility = [
"//tensorflow/c:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [
":c_api_internal",
":tf_attrtype",
@ -698,4 +701,5 @@ tf_cuda_library(
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
"//tensorflow/python:cpp_shape_inference_proto_cc",
],
alwayslink = 1,
)

View File

@ -774,7 +774,7 @@ extern "C" {
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
const char* op_type,
const char* oper_name)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;
if (desc->graph->name_map.count(desc->node_builder.node_name())) {
@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
const TF_ImportGraphDefOptions* opts,
TF_ImportGraphDefResults* tf_results,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,

View File

@ -51,7 +51,7 @@ Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
@ -87,7 +87,7 @@ Status ProcessInputs(
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
@ -111,7 +111,7 @@ Status ComputeBodyNodes(
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);

View File

@ -71,14 +71,14 @@ struct TF_Graph {
TF_Graph();
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);
tensorflow::Graph graph TF_GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
TF_GUARDED_BY(mu);
// The keys of this map are all the active sessions using this graph. Each
// value records whether the graph has been mutated since the corresponding
@ -94,8 +94,8 @@ struct TF_Graph {
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
// status, this should be reverted when possible.
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
TF_GUARDED_BY(mu);
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
@ -123,7 +123,7 @@ struct TF_Session {
tensorflow::Session* session;
TF_Graph* const graph;
tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call
@ -169,9 +169,9 @@ struct TF_ApiDefMap {
}
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
bool update_docs_called GUARDED_BY(lock);
bool update_docs_called TF_GUARDED_BY(lock);
tensorflow::mutex lock;
};
@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
LOCKS_EXCLUDED(session->graph->mu, session->mu);
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
std::string getTF_OutputDebugString(TF_Output node);
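
Since this file shows most of the renamed thread-safety macros at once, here is a minimal sketch (hypothetical class, assuming the TF_-prefixed annotations from tensorflow/core/platform/thread_annotations.h and tensorflow::mutex) of how they are typically combined:

#include <cstdint>
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

// Hypothetical example: `value_` is protected by `mu_`.
class Counter {
 public:
  void Increment() TF_LOCKS_EXCLUDED(mu_) {
    tensorflow::mutex_lock lock(mu_);
    IncrementLocked();
  }

 private:
  void IncrementLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { ++value_; }

  tensorflow::mutex mu_;
  int64_t value_ TF_GUARDED_BY(mu_) = 0;
};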

View File

@ -354,6 +354,7 @@ cc_library(
"//tensorflow/core:lib",
"@dlpack",
],
alwayslink = 1,
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime

View File

@ -1710,8 +1710,9 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
: device_(device), info_(info), name_(name) {}
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
string name)
: context_(context), device_(device), info_(info), name_(name) {}
~CustomDeviceAPI() override { device_.delete_device(info_); }
@ -1725,7 +1726,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
@ -1744,7 +1745,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
&tensor_handle, target_device_name.c_str(), &status, info_);
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
@ -1768,7 +1769,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
@ -1787,6 +1788,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
}
private:
TFE_Context* context_;
TFE_CustomDevice device_;
void* info_;
string name_;
@ -1794,8 +1796,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
} // namespace
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info) {
const char* device_name, void* device_info,
TF_Status* status) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
status->status =
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
}

View File

@ -458,27 +458,29 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
size_t proto_len,
TF_Status* status);
#define TFE_CUSTOM_DEVICE_VERSION 1
#define TFE_CUSTOM_DEVICE_VERSION 2
// Struct to be filled in
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation.
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
void* device_info);
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
@ -503,11 +505,21 @@ typedef struct TFE_CustomDevice {
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// `device_name` must not name an existing physical or custom device. It must
// follow the format:
//
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
//
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
// the device is not usable. In case of a bad status, `device.delete_device` is
// still called on `device_info` (i.e. the caller does not retain ownership).
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
const char* device_name, void* device_info,
TF_Status* status);
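
As a sketch of the version-2 contract described above (the logging device in the custom device test below is the authoritative in-tree example), registration with the new status out-parameter might look like:

// Hypothetical minimal registration; MyCopyTo, MyCopyFrom, MyExecute and
// MyDelete are assumed to match the version-2 callback signatures above.
TFE_CustomDevice device;
device.copy_tensor_to_device = &MyCopyTo;
device.copy_tensor_from_device = &MyCopyFrom;
device.execute = &MyExecute;
device.delete_device = &MyDelete;
TF_Status* status = TF_NewStatus();
TFE_RegisterCustomDevice(ctx, device,
                         "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                         /*device_info=*/nullptr, status);
if (TF_GetCode(status) != TF_OK) {
  // Registration failed (bad name, duplicate device, ...); the device is not
  // usable, and delete_device has already been called on device_info.
}
TF_DeleteStatus(status);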
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,

View File

@ -27,7 +27,6 @@ limitations under the License.
namespace {
struct LoggingDevice {
TFE_Context* ctx;
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
@ -48,7 +47,7 @@ void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* ctx, const tensorflow::string& logging_device_name,
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
@ -58,23 +57,25 @@ TFE_TensorHandle* MakeLoggedTensorHandle(
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, dev->ctx, dev->underlying_device.c_str(), status);
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
@ -83,13 +84,13 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
return nullptr;
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name,
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
@ -117,7 +118,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
@ -128,19 +129,19 @@ void DeleteLoggingDevice(void* device_info) {
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag) {
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->ctx = context;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device);
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
@ -153,7 +154,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed);
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
@ -189,7 +191,9 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
@ -217,7 +221,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
@ -291,4 +296,103 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace

View File

@ -27,14 +27,10 @@ namespace {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
gtl::InlinedVector<TensorValue, 4> inputs;

View File

@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
TEST(TestKernel, TestInputAndOutputCount) {
@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr, false);
DummyDevice dummy_device(nullptr);
p.device = &dummy_device;
p.step_id = 43;

View File

@ -58,9 +58,9 @@ extern "C" {
// start_offset: array[uint64]
// data: byte[...]
//
// The string length (as a varint), followed by the contents of the string
// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode
// facilitate this encoding.
// The string length (as a varint, start_offset[i + 1] - start_offset[i]),
// followed by the contents of the string is encoded at data[start_offset[i]].
// TF_StringEncode and TF_StringDecode facilitate this encoding.
typedef struct TF_Tensor TF_Tensor;
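
To make the layout described above concrete, a hedged sketch (assuming a TF_STRING tensor `t` with `n` elements, an index `i`, and a TF_Status* `status`) of decoding the i-th string with the helpers mentioned in the comment:

// Hypothetical example of walking the legacy TF_STRING encoding.
const char* data = static_cast<const char*>(TF_TensorData(t));
const uint64_t* start_offset = reinterpret_cast<const uint64_t*>(data);  // n offsets
const char* strings = data + n * sizeof(uint64_t);  // encoded string region
const char* dst = nullptr;
size_t dst_len = 0;
size_t src_len = TF_TensorByteSize(t) - n * sizeof(uint64_t) - start_offset[i];
TF_StringDecode(strings + start_offset[i], src_len, &dst, &dst_len, status);
// On TF_OK, `dst` points at the i-th string's bytes and `dst_len` is its length.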

View File

@ -41,7 +41,7 @@ class ClientSession::Impl {
std::shared_ptr<Graph> graph_;
mutable mutex mu_;
mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
};
ClientSession::ClientSession(const Scope& scope, const string& target)

View File

@ -114,14 +114,14 @@ class Coordinator {
condition_variable wait_for_stop_;
mutex mu_;
bool should_stop_ GUARDED_BY(mu_);
bool should_stop_ TF_GUARDED_BY(mu_);
mutex status_lock_;
Status status_ GUARDED_BY(status_lock_);
Status status_ TF_GUARDED_BY(status_lock_);
mutable mutex runners_lock_;
std::vector<std::unique_ptr<RunnerInterface>> runners_
GUARDED_BY(runners_lock_);
TF_GUARDED_BY(runners_lock_);
TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
};

View File

@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface {
std::unique_ptr<thread::ThreadPool> thread_pool_;
mutex mu_;
int runs_ = 0;
Status status_ GUARDED_BY(mu_);
Status enqueue_status_ GUARDED_BY(mu_);
Status status_ TF_GUARDED_BY(mu_);
Status enqueue_status_ TF_GUARDED_BY(mu_);
std::unique_ptr<BlockingCounter> counter_;
Coordinator* coord_;
@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface {
std::vector<std::function<void(Status)>> callbacks_;
mutable std::unique_ptr<mutex> cg_mu_;
std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
RunOptions run_options_;
};

View File

@ -37,6 +37,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -64,6 +65,7 @@ cc_library(
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
}
// Generates code implementing {Arg,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
// literal in the array, with nullptr terminating the array.
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
// string literal in the array, with nullptr terminating the array.
template <typename T>
string GenNameToIndexCode(const T& entries, bool generate) {
// No need for a static array if we're not supposed to generate the data.
@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// Generate metadata.
const string arg_names_code =
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
auto variable_copy = config.variable();
for (auto& var : variable_copy) {
if (var.name().empty()) {
var.set_name(var.node_name());
}
}
const string variable_names_code =
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
const string result_names_code =
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto =
@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
// Array of names of each positional variable, terminated by nullptr.
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
// Array of names of each positional result, terminated by nullptr.
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},

View File

@ -156,17 +156,14 @@ static void CompareWithGoldenFile(
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
const bool update_golden = false;
string golden_file_name;
string golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
if (update_golden) {
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
tensorflow_relative_golden_file_name);
TF_EXPECT_OK(
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
}
golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
string golden_file_contents;
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
&golden_file_contents));
@ -220,10 +217,16 @@ TEST(CodegenTest, Golden) {
{},
{BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
BufferInfo::MakeTempBuffer(2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
11, {}));
compile_result.program_shape =
xla::ShapeUtil::MakeProgramShape(
{

View File

@ -55,14 +55,17 @@ namespace bar {
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
//
// Memory stats:
// arg bytes total: 104
// arg bytes aligned: 192
// arg bytes total: 392
// arg bytes aligned: 576
// temp bytes total: 126
// temp bytes aligned: 320
// temp bytes aligned: 512
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = 2;
static constexpr size_t kNumArgs = 5;
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = 3;
// Byte size of each argument buffer. There are kNumArgs entries.
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
private:
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = 6;
static constexpr size_t kNumBuffers = 12;
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::xla::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
};
return kBufferInfos;
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
1, 3
1, 3, 5, 7, 9
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 5;
static constexpr size_t kResultIndex = 11;
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return kNames;
}
// Array of names of each positional variable, terminated by nullptr.
static const char** StaticVariableNames() {
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
return kNames;
}
// Array of names of each positional result, terminated by nullptr.
static const char** StaticResultNames() {
static const char* kNames[] = {"myfetch", nullptr};

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -105,14 +107,18 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
.ValueOrDie();
xla::XlaComputation computation;
if (flags.mlir_components == "Bridge") {
TF_RETURN_IF_ERROR(
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
graph_def, config, &computation, flags.debug_info,
flags.debug_info_path_begin_marker));
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
if (flags.quantize) {
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
@ -166,6 +172,23 @@ static void InitializeTargets() {
LLVMInitializeX86AsmPrinter();
}
// Replaces {{tag.type tag.name}} in the error message with tag_name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}},
// to allow a Python debugger to insert source information about the graph node.
// For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
static std::string InterpolateErrorMessage(std::string message) {
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
return message;
}
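
For illustration, a hypothetical before/after pair for this helper (the message text is made up):

// in:  "Dimensions must be equal for '{{node x_y_sum}} = Add(x, y)'"
// out: "Dimensions must be equal for 'x_y_sum = Add(x, y)'"
std::string message = InterpolateErrorMessage(
    "Dimensions must be equal for '{{node x_y_sum}} = Add(x, y)'");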
Status Main(const MainFlags& flags) {
absl::call_once(targets_init, &InitializeTargets);
@ -192,8 +215,13 @@ Status Main(const MainFlags& flags) {
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
Status status =
CompileGraph(std::move(graph_def), config, flags, &compile_result);
if (!status.ok()) {
return Status(status.code(),
InterpolateErrorMessage(status.error_message()));
}
// Write output files.
Env* env = Env::Default();

View File

@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info", &flags->debug_info,
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
"If not none, only keep the file path in the debug information after the"
" marker. The default value is empty"},
{"config", &flags->config,
"Input file containing Config proto. If the file ends in '.pbtxt' it "
"is expected to be in the human-readable proto text format, otherwise "
@ -70,6 +77,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Output session module proto."},
{"mlir_components", &flags->mlir_components,
"The MLIR components to enable. Currently only Bridge is supported."},
{"quantize", &flags->quantize,
"If set, quantization will be applied before HLO code generation."},
{"gen_name_to_index", &flags->gen_name_to_index,
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
{"gen_program_shape", &flags->gen_program_shape,

View File

@ -28,6 +28,8 @@ namespace tfcompile {
struct MainFlags {
string graph;
string debug_info;
string debug_info_path_begin_marker;
string config;
bool dump_fetch_nodes = false;
string target_triple;
@ -40,6 +42,7 @@ struct MainFlags {
string out_header;
string out_session_module;
string mlir_components;
bool quantize = false;
// C++ codegen options
bool gen_name_to_index = false;

View File

@ -1,11 +1,37 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
glob_lit_tests(
data = [":filecheck_test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
},
test_file_exts = ["lit.pbtxt"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "filecheck_test_utilities",
testonly = True,
srcs = [
"test_error_message.lit.pbtxt.config.pbtxt",
"test_error_message.lit.pbtxt.debug.pbtxt",
"test_error_message.lit.pbtxt.fake_py.debug",
],
data = [
"//tensorflow/compiler/aot:tfcompile",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# We disable some tfcompile tests in the open source build with the
# "manual" tag to avoid making our OSS users build LLVM twice
# (once for host and once for target).
@ -60,6 +86,7 @@ genrule(
testonly = 1,
outs = [
"test_graph_tfadd.pb",
"test_debuginfo_tfadd.pb",
"test_graph_tfadd_with_ckpt.ckpt",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt_saver.ckpt",
@ -317,6 +344,7 @@ tf_library(
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
debug_info = "test_debuginfo_tfadd.pb",
graph = "test_graph_tfadd.pb",
include_standard_runtime_deps = False,
mlir_components = "Bridge",

View File

@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
array_ops.identity(updates, name='result')
def write_graph(build_graph, out_dir):
def export_debug_info(exported_graph):
"""Exports debug information from a graph.
Args:
exported_graph: A Graph that has been created by tracing a saveable view.
Returns:
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
"""
exported_operations = []
for op in exported_graph.get_operations():
exported_operations.append(('', op))
return error_interpolation.create_graph_debug_info_def(exported_operations)
def write_graph(build_graph, out_dir, debug_info=False):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
with open(filename, 'wb') as f:
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
if debug_info:
filename_debuginfo = os.path.join(
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
test_debuginfo = export_debug_info(g)
with open(filename_debuginfo, 'wb') as f:
f.write(
six.ensure_binary(
test_debuginfo.SerializeToString(deterministic=True)))
def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)

View File

@ -0,0 +1,69 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# Checks the error message produced by tfcompile with mlir_component
# Checks that source debug information is used in the output error message and
# the node x_y_sum = Add
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
# CHECK: math_ops.add(x, y, name='x_y_sum')
# CHECK: build_graph(out_dir)
# Checks the error message produced by tfcompile without mlir_component
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
# OLD: x_y_sum
node: {
name: "x"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "y"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "x_y_sum"
op: "Add"
input: "x"
input: "y"
attr: {
key: "T"
value: {
type: DT_INT32
}
}
}
versions: {
producer: 321
}

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {
dim { size: 2 }
}
}
feed {
id { node_name: "y" }
shape {
dim { size: 3 }
}
}
fetch {
id { node_name: "x_y_sum" }
}

View File

@ -0,0 +1,28 @@
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
traces: {
key: "x@"
value: {
file_line_cols: {
line: 1
}
}
}
traces: {
key: "x_y_sum@"
value: {
file_line_cols: {
line: 3
}
file_line_cols: {
line: 4
}
}
}
traces: {
key: "y@"
value: {
file_line_cols: {
line: 2
}
}
}

View File

@ -0,0 +1,4 @@
x = value
y = value
math_ops.add(x, y, name='x_y_sum')
build_graph(out_dir)

View File

@ -26,6 +26,7 @@ def tf_library(
name,
graph,
config,
debug_info = None,
freeze_checkpoint = None,
freeze_saver = None,
cpp_class = None,
@ -191,12 +192,15 @@ def tf_library(
mlir_flag = "--mlir_components=" + mlir_components
srcs = [tfcompile_graph, config]
debug_info_flag = ""
if debug_info:
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
native.genrule(
name = ("gen_" + name),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
header_file,
metadata_object_file,
@ -206,6 +210,7 @@ def tf_library(
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
@ -237,10 +242,7 @@ def tf_library(
session_module_pb = name + "_session_module.pb"
native.genrule(
name = (name + "_session_module"),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
session_module_pb,
],
@ -248,6 +250,7 @@ def tf_library(
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +

View File

@ -65,6 +65,7 @@ int main(int argc, char** argv) {
flags.out_metadata_object = "out_helper.o";
flags.out_header = "out.h";
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
@ -81,12 +82,10 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
"other than flags\n\n"
<< usage;
"other than flags. See --help.\n\n";
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
<< usage;
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n";
return 1;
} else {
TF_QCHECK_OK(status);

View File

@ -184,6 +184,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",

View File

@ -18,6 +18,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -368,14 +368,20 @@ bool GraphCycles::CanContractEdge(int32 a, int32 b) {
return !reachable;
}
bool GraphCycles::ContractEdge(int32 a, int32 b) {
absl::optional<int32> GraphCycles::ContractEdge(int32 a, int32 b) {
CHECK(HasEdge(a, b));
RemoveEdge(a, b);
if (IsReachableNonConst(a, b)) {
// Restore the graph to its original state.
InsertEdge(a, b);
return false;
return absl::nullopt;
}
if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() >
rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) {
// Swap "a" and "b" to minimize copying.
std::swap(a, b);
}
Node* nb = rep_->nodes_[b];
@ -399,7 +405,8 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
InsertEdge(y, a);
}
return true;
// Note: if the swap happened, the returned node may be what was originally called "b".
return a;
}
absl::Span<const int32> GraphCycles::Successors(int32 node) const {

View File

@ -40,6 +40,7 @@ limitations under the License.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -80,11 +81,11 @@ class GraphCycles {
// Return whether there is an edge directly from source_node to dest_node.
bool HasEdge(int32 source_node, int32 dest_node) const;
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
// removed from the graph, and edges to/from 'b' are replaced with edges
// to/from 'a'. If contracting the edge would create a cycle, does nothing
// and returns false.
bool ContractEdge(int32 a, int32 b);
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
// the nodes is removed from the graph, and edges to/from it are added to
// the remaining one, which is returned. If contracting the edge would create
// a cycle, does nothing and returns no value.
absl::optional<int32> ContractEdge(int32 a, int32 b);
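
A typical call site under the new signature (hedged sketch; `cycles`, `a` and `b` are assumed to exist) checks the returned optional instead of a bool:

// Hypothetical usage: contract the edge and keep the surviving node's ID.
if (absl::optional<int32> merged = cycles.ContractEdge(a, b)) {
  a = *merged;  // Whichever of a/b now represents the combined node.
} else {
  // Contraction would have created a cycle; the graph is unchanged.
}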
// Return true if can contract edge, otherwise return false.
bool CanContractEdge(int32 a, int32 b);

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include <optional>
#include <random>
#include <unordered_set>
#include <vector>
@ -479,19 +480,21 @@ TEST_F(GraphCyclesTest, ContractEdge) {
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(3, 4));
EXPECT_FALSE(g_.ContractEdge(1, 3));
EXPECT_FALSE(g_.ContractEdge(1, 3).has_value());
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 3));
EXPECT_TRUE(g_.ContractEdge(1, 2));
// Node (2) has more edges.
EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2);
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 3));
EXPECT_TRUE(g_.HasEdge(1, 4));
EXPECT_TRUE(g_.HasEdge(2, 3));
EXPECT_TRUE(g_.HasEdge(2, 4));
EXPECT_TRUE(g_.HasEdge(3, 4));
EXPECT_TRUE(g_.ContractEdge(1, 3));
// Node (2) has more edges.
EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2);
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 4));
EXPECT_TRUE(g_.HasEdge(2, 4));
}
TEST_F(GraphCyclesTest, CanContractEdge) {
@ -527,3 +530,26 @@ static void BM_StressTest(int iters, int num_nodes) {
}
}
BENCHMARK(BM_StressTest)->Range(2048, 1048576);
static void BM_ContractEdge(int iters, int num_nodes) {
while (iters-- > 0) {
tensorflow::testing::StopTiming();
tensorflow::GraphCycles g;
std::vector<int32> nodes;
nodes.reserve(num_nodes);
for (int i = 0; i < num_nodes; i++) {
nodes.push_back(g.NewNode());
}
// All edges point toward the last one.
for (int i = 0; i < num_nodes - 1; ++i) {
g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
}
tensorflow::testing::StartTiming();
int node = num_nodes - 1;
for (int i = 0; i < num_nodes - 1; ++i) {
node = g.ContractEdge(nodes[i], node).value();
}
}
}
BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000);

View File

@ -172,8 +172,9 @@ class XlaExecutableClosureStore {
private:
mutex mutex_;
int64 key_counter_ GUARDED_BY(mutex_);
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
int64 key_counter_ TF_GUARDED_BY(mutex_);
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
TF_GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

View File

@ -165,7 +165,8 @@ class XlaCompileOp : public OpKernel {
// error when compiling the cluster this _XlaCompile is supposed to compile.
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
// on any future calls to _XlaCompile.
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) =
false;
mutex cannot_compile_cluster_mu_;
};

View File

@ -161,6 +161,11 @@ class MarkForCompilationPassImpl {
// The ID of the cluster as represented in `cycles_graph_`.
int cycles_graph_node_id() const { return cycles_graph_node_id_; }
// Sets the ID of the cluster as represented in `cycles_graph_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
cycles_graph_node_id_ = cycles_graph_node_id;
}
// The size of the cluster excluding constant and identity nodes.
int effective_cluster_size() const { return effective_cluster_size_; }
@ -381,14 +386,16 @@ class MarkForCompilationPassImpl {
// R, B} cluster.
string DescribePotentialCycle(int from, int to);
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
// larger combined cluster is represented by `cluster_from`'s ID in
// `cycles_graph_`.
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
// larger combined cluster is represented by `cluster_from`, but can have
// `cycles_graph_`'s ID of either `cluster_from` or `cluster_to`, depending on
// which direction requires fewer operations.
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
int from = cluster_from->cycles_graph_node_id();
int to = cluster_to->cycles_graph_node_id();
if (!cycles_graph_.ContractEdge(from, to)) {
auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
if (!optional_merged_node.has_value()) {
VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
<< " -> " << cluster_to->DebugString(*graph_)
<< " because contracting the edge would create a cycle via "
@ -398,6 +405,8 @@ class MarkForCompilationPassImpl {
// Merge the clusters.
cluster_from->Merge(cluster_to);
// Update `cycle_graph_`'s ID.
cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
// Merge the UnionFind<Cluster*>.
cluster_for_node_[from].Merge(&cluster_for_node_[to]);
@ -1911,6 +1920,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"LinSpace",
"ListDiff",
"LogMatrixDeterminant",
"LowerBound",
"MatMul",
"MatrixBandPart",
"MatrixDiag",
@ -2037,6 +2047,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorScatterUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"UpperBound",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",

View File

@ -18,13 +18,15 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace {
// The list of all registered `XlaActivityListener`s.
struct XlaActivityListenerList {
absl::Mutex mutex;
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
std::vector<std::unique_ptr<XlaActivityListener>> listeners
TF_GUARDED_BY(mutex);
};
void FlushAllListeners();

View File

@ -50,7 +50,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
@ -69,7 +69,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/util.h"
@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
execution_count < kMinExecutionsPerCompile * compile_count;
}
// Creates a simple graph using the specified op as the only op apart from the
// arg and retval nodes.
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
absl::Span<const DataType> result_types) {
// TODO(b/74182462): We implement this by creating a new dummy Graph including
// _Arg nodes, and let CompileGraph walk it. This could be optimized.
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Status status;
// First create the actual node we care about computing.
Node* main_node = graph->AddNode(node_def, &status);
TF_RETURN_IF_ERROR(status);
// Create dummy _Arg nodes. Link these to `node` and also via a control
// dependency edge to the _SOURCE node.
for (int64 i = 0; i < args.size(); ++i) {
Node* node;
string arg_name = absl::StrCat("_arg", i);
Status status =
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
.ControlInput(graph->source_node())
.Attr("T", args[i].kind == XlaCompiler::Argument::kResource
? DT_RESOURCE
: args[i].type)
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
graph->AddEdge(node, 0, main_node, i);
}
// Similarly with return values, create dummy _Retval nodes fed by `node`.
for (int64 i = 0; i < result_types.size(); ++i) {
Node* node;
string retval_name = absl::StrCat("_retval", i);
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
.Input(main_node, i)
.Attr("T", result_types[i])
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
}
FixupSourceAndSinkEdges(graph.get());
return graph;
}
Status XlaCompilationCache::CompileSingleOp(
const XlaCompiler::Options& options,
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp(
for (int i = 0; i < result_dtypes.size(); ++i) {
result_dtypes[i] = ctx->expected_output_dtype(i);
}
return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
args, result_dtypes, result);
const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,

View File

@ -151,19 +151,19 @@ class XlaCompilationCache : public ResourceBase {
int64 request_count = 0;
// Did compilation succeed?
Status compilation_status GUARDED_BY(mu);
Status compilation_status TF_GUARDED_BY(mu);
// Output of the XlaCompiler.
XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);
// The XLA executable compiled from <computation>. May be null if no
// executable has been built.
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
};
mutex compile_cache_mu_;
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);
TF_GUARDED_BY(compile_cache_mu_);
struct ClusterCompileStats {
// Number of times the cluster has been (re-)compiled.
@ -185,7 +185,7 @@ class XlaCompilationCache : public ResourceBase {
// Maps cluster names to compilation statistics for said cluster.
absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
GUARDED_BY(cluster_compile_stats_mu_);
TF_GUARDED_BY(cluster_compile_stats_mu_);
// The number of times a lazy compilation must be requested for a specific
// signature before we attempt to compile it.

View File

@ -83,7 +83,7 @@ class XlaDeviceAllocatorState {
std::unordered_map<std::pair<const xla::Backend*, int>,
std::unique_ptr<XlaDeviceAllocator>,
hash<std::pair<const xla::Backend*, int>>>
allocators_ GUARDED_BY(allocator_mutex_);
allocators_ TF_GUARDED_BY(allocator_mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

View File

@ -137,7 +137,7 @@ class XlaDevice : public LocalDevice {
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override
LOCKS_EXCLUDED(mu_);
TF_LOCKS_EXCLUDED(mu_);
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
@ -145,18 +145,18 @@ class XlaDevice : public LocalDevice {
void Sync(const DoneCallback& done) override;
Status TryGetDeviceContext(DeviceContext** out_context) override
LOCKS_EXCLUDED(mu_);
TF_LOCKS_EXCLUDED(mu_);
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override LOCKS_EXCLUDED(mu_);
Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);
// Allocate tensor on fast memory space. This is only applied to the new TPU
// hardware which has faster read/write memory. If the hardware doesn't
// have such memory space, we fallback to the ordinary memory space.
Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) LOCKS_EXCLUDED(mu_);
Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);
const Metadata& metadata() { return xla_metadata_; }
@ -166,34 +166,35 @@ class XlaDevice : public LocalDevice {
//
// TODO(b/111859745): The Eager context needs to call this method to recover
// from failures.
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to return 'sync_on_completion' for
// AllowsSyncOnCompletion().
void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
void SetAllowsSyncOnCompletion(bool sync_on_completion)
TF_LOCKS_EXCLUDED(mu_);
bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);
// Installs an error handling callback when RefreshStatus sees !status.ok().
void SetHandleDeviceErrorCallback(std::function<Status()> callback);
Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
private:
xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Return a pair of device context, the second one is fast_mem device context.
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
@ -218,13 +219,13 @@ class XlaDevice : public LocalDevice {
// Intra-op threads to spawn (from SessionOptions).
const int intra_op_parallelism_threads_;
// Memory allocator associated with this device.
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_device stream or borrowing a stream
@ -232,36 +233,36 @@ class XlaDevice : public LocalDevice {
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
// If use_multiple_streams_, transfers between different devices are performed
// using these streams.
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
GUARDED_BY(mu_);
TF_GUARDED_BY(mu_);
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// The device context accessed by all users of the XlaDevice, set by calls to
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;
// The device context will allocate memory on fast memory space on TPU.
// XlaDeviceContext is a ref-counted object.
XlaDeviceContext* fast_mem_device_context_ GUARDED_BY(mu_) = nullptr;
XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
// True if the device allows XlaDevice::Sync to be called on completion
// regardless of status.
bool sync_on_completion_ GUARDED_BY(mu_) = true;
bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;
// A callback that will be invoked when RefreshStatus sees a status error.
std::function<Status()> device_error_callback_ GUARDED_BY(mu_);
std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);
// Set of devices to use. This controls which of the devices on the given
// platform will have resources allocated. For GPUs this will be

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

View File

@ -117,7 +117,7 @@ class XlaDeviceContext : public DeviceContext {
bool use_fast_mem_;
absl::Mutex mu_;
int next_stream_ GUARDED_BY(mu_) = 0;
int next_stream_ TF_GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow

View File

@ -18,7 +18,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
#include "absl/base/thread_annotations.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -30,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace tensorflow {
@ -102,7 +102,7 @@ class VariableInfo {
// `variables` is allowed to contain instances that don't track a resource
// variable (i.e. variables[i].var() can be null for some i).
Status LockVariables(absl::Span<VariableInfo> variables)
EXCLUSIVE_LOCK_FUNCTION();
TF_EXCLUSIVE_LOCK_FUNCTION();
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.

View File

@ -122,7 +122,7 @@ class XlaTensor {
std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ TF_GUARDED_BY(mu_);
mutex mu_;
};

View File

@ -74,12 +74,15 @@ cc_library(
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
@ -100,6 +103,38 @@ cc_library(
],
)
cc_library(
name = "mlir_graph_optimization_pass",
srcs = ["mlir_graph_optimization_pass.cc"],
hdrs = ["mlir_graph_optimization_pass.h"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)
cc_library(
name = "mlir_graph_optimization_pass_registration",
srcs = [
"mlir_graph_optimization_pass_registration.cc",
],
deps = [
":mlir_graph_optimization_pass",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,
)
tf_cc_binary(
name = "tf-opt",
deps = [

View File

@ -30,7 +30,8 @@ filegroup(
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
],
)
@ -221,19 +222,21 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
@ -325,8 +328,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
@ -436,7 +439,7 @@ genrule(
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
outs = [
@ -516,6 +519,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@ -580,7 +584,7 @@ cc_library(
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning:op_version",
"//tensorflow/lite/tools/versioning",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
@ -697,11 +701,12 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -725,6 +730,7 @@ cc_library(
":tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
@ -734,11 +740,10 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],

View File

@ -1,9 +1,9 @@
# Experimental code for the new TF-Lite convertor, and MLIR dialects and utilities for TensorFlow Lite.
# The new [MLIR](https://github.com/llvm/llvm-project/tree/master/mlir) based
TensorFlow to TensorFlow Lite converter
This directory contains:
1. Experimental code for the new TF-Lite convertor.
2. Code for the TF-lite dialect [MLIR](https://github.com/tensorflow/mlir).
1. MLIR dialects, transformation passes and utilities for TensorFlow Lite.
## API:
@ -11,7 +11,8 @@ The API for converting TensorFlow models to TensorFlow Lite will be through
`tf.lite.TFLiteConverter`. All the conversion code is open sourced, and
the API will be integrated soon.
### The conversion process from TensorFlow to TensorFlow Lite includes the following major passes:
### The conversion process from TensorFlow to TensorFlow Lite includes the
following major passes:
- Import from GraphDef, in .pb or .pbtxt format, into MLIR.
- Raise to Control-flow-graph. Converts TF Control Flow dialect to TF dialect.
@ -28,3 +29,6 @@ TensorFlow Lite models).
- The Export pass writes out TensorFlow Lite FlatBuffer format. This pass
operates on MLIR TensorFlow Lite dialect and is simple/direct translation.
See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
for the full list of MLIR passes for conversion from TensorFlow to
TensorFlow Lite.
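For reference, a minimal sketch of driving this conversion through the Python API mentioned above; the SavedModel path and output file name are placeholders, not part of this repository.
import tensorflow as tf

# Load a SavedModel (placeholder path) and convert it to a TensorFlow Lite FlatBuffer.
converter = tf.lite.TFLiteConverter.from_saved_model("./saved_model")
tflite_model = converter.convert()

# Write the serialized FlatBuffer to disk.
with open("model.tflite", "wb") as f:
    f.write(tflite_model)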

View File

@ -34,9 +34,9 @@ struct PassConfig {
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
inline_functions(true),
unfold_batch_matmul(true),
legalize_tf_while(true) {}
legalize_tf_while(true),
shape_inference(false) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -56,9 +56,6 @@ struct PassConfig {
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
bool form_clusters;
// Inline function calls within the main function in the MLIR module, prior
// to legalization to TFLite.
bool inline_functions;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
bool unfold_batch_matmul;
@ -66,6 +63,8 @@ struct PassConfig {
// Note: This is staging step and will be removed.
// TODO(b/137395003): Remove post switching legalization.
bool legalize_tf_while;
// Whether to do shape inference.
bool shape_inference;
};
} // namespace TFL

View File

@ -119,6 +119,12 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
// conversion generation and so the simplicity was chosen over the
// flexibility.
StringRef arg_name = arg_values->getArgNameStr(i);
// Skip any "intermiadiateXXX" attribute as they are specially handled
// in the exporter. They are special because though they are attributes
// in the MLIR they are expressed as tensors in the flatbuffer instead
// of option.
if (op_name == "LSTMOp" && arg_name.take_back(12) == "intermediate")
continue;
os << formatv(
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName());
@ -164,17 +170,24 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
for (const auto *def : defs) {
StringRef op_name = def->getName().drop_front(4);
const bool has_intermediates = op_name == "LSTMOp";
// Signature
os << "static flatbuffers::Offset<tflite::Operator> "
<< GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name
<< " tflOp, uint32_t opcode_index, "
<< "const std::vector<int32_t>& operands,"
<< "const std::vector<int32_t>& results,"
<< (has_intermediates ? "const std::vector<int32_t>& intermediate_index,"
: "")
<< "flatbuffers::FlatBufferBuilder *fbb) {\n";
// Inputs & outputs
os << " auto inputs = fbb->CreateVector(operands);\n"
" auto outputs = fbb->CreateVector(results);\n\n";
// Intermediates for LSTM.
if (has_intermediates) {
os << " auto intermediates = fbb->CreateVector(intermediate_index);\n";
}
// Build the FlatBuffer operator
os << " return tflite::CreateOperator(\n"
@ -191,9 +204,9 @@ static void EmitOperatorBuilders(const std::vector<Record *> &defs,
// Only builtin ops' builders are auto-generated. custom_options are only
// used by custom or flex ops and those ops are handled manually.
os << " /*custom_options=*/0, "
"tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
" /*mutating_variable_inputs=*/0);\n"
"}\n\n";
<< "tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
<< " /*mutating_variable_inputs=*/0"
<< (has_intermediates ? ", intermediates" : "") << ");\n}\n\n";
}
}
@ -244,6 +257,7 @@ static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
// uint32_t opcode_index,
// const std::vector<int32_t>& operands,
// const std::vector<int32_t>& results,
// const std::vector<int32_t>& intermediates,
// flatbuffers::FlatBufferBuilder *fbb);
static void EmitBuildOperator(const std::vector<Record *> &defs,
raw_ostream *ostream) {
@ -255,6 +269,7 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
"uint32_t opcode_index, "
"const std::vector<int32_t>& operands,"
"const std::vector<int32_t>& results,"
"const std::vector<int32_t>& intermediates,"
"flatbuffers::FlatBufferBuilder *fbb) {\n";
for (const auto *def : defs) {
@ -264,7 +279,8 @@ static void EmitBuildOperator(const std::vector<Record *> &defs,
os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_name
<< ">(op))\n"
<< " return " << GetOperatorBuilderName(def->getName())
<< "(tflOp, opcode_index, operands, results, fbb);\n";
<< "(tflOp, opcode_index, operands, results, "
<< (op_name == "LSTMOp" ? "intermediates, " : "") << "fbb);\n";
}
os << " return llvm::None;\n"
@ -307,6 +323,10 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper,
if (!arg_def) continue;
if (arg_def->getDef()->isSubClassOf(attr_type)) {
StringRef arg_name = arg_values->getArgNameStr(i);
// Already handle this case in flatbuffer_import.cc.
if (option_name == "LSTMOptions" &&
arg_name.take_back(12) == "intermediate")
continue;
StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName();
os << formatv(
" attributes.emplace_back(builder.getNamedAttr(\"{0}\","

View File

@ -547,6 +547,7 @@ bool IsCustomOp(const std::string& op_name) {
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
const std::vector<mlir::TensorType>& intermediate_types,
Value optional_arg_marker, const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
@ -608,6 +609,28 @@ StatusOr<Operation*> ConvertOp(
if (op_name == "tfl.lstm") {
// TODO(b/147587779): add the right region if region is empty.
op_state.addRegion();
if (!op.intermediates.empty()) {
if (op.intermediates.size() != 5) {
auto err = errors::InvalidArgument(
"operator has intermediate tensors but the number of them is not "
"five.");
return emitError(loc, err.ToString()), err;
}
// Create intermediate value
const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
for (auto type_and_name :
llvm::zip(intermediate_types, kIntermediateNames)) {
mlir::TypeAttr type_attr =
mlir::TypeAttr::get(std::get<0>(type_and_name));
auto named_attr =
builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
op_state.addAttribute(named_attr.first, named_attr.second);
}
}
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
@ -893,6 +916,18 @@ StatusOr<FuncOp> ConvertSubgraph(
}
}
// Intermediate tensors for tfl.lstm are used to carry quantization ranges
// in their types, so we only need to extract their types.
std::vector<mlir::TensorType> intermediate_types;
intermediate_types.reserve(5);
for (auto intermediate : op->intermediates) {
TF_ASSIGN_OR_RETURN(
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
intermediate_types.emplace_back(type);
}
// The NameLoc corresponding to the name of the first output tensor
auto op_loc =
op->outputs.empty()
@ -902,8 +937,8 @@ StatusOr<FuncOp> ConvertSubgraph(
// to a valid Value
TF_ASSIGN_OR_RETURN(
auto* mlir_op,
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
func_names, subgraph.tensors, op_loc, op_builder));
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
op_names, func_names, subgraph.tensors, op_loc, op_builder));
// Add the results to the value maps. There are two cases: 1. the result
// tensor does not have min/max values, the original op result is used

View File

@ -44,6 +44,7 @@ llvm::Optional<tflite::BuiltinOperator> GetBuiltinOpCode(Operation *mlir_op);
llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
Operation *mlir_op, uint32_t opcode_index,
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
const std::vector<int32_t> &intermediates,
flatbuffers::FlatBufferBuilder *fbb);
// Populates the array of mlir::NamedAttributes corresponding to the given

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
@ -75,6 +76,7 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
#include "tensorflow/lite/tools/versioning/runtime_version.h"
#include "tensorflow/lite/version.h"
using llvm::dyn_cast;
@ -179,8 +181,6 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return tflite::TensorType_FLOAT16;
case mlir::TF::TensorFlowTypes::STRING:
return tflite::TensorType_STRING;
case mlir::TF::TensorFlowTypes::UINT8:
return tflite::TensorType_UINT8;
case mlir::TF::TensorFlowTypes::QUINT8:
return tflite::TensorType_UINT8;
case mlir::StandardTypes::Complex: {
@ -196,7 +196,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
case 1:
return tflite::TensorType_BOOL;
case 8:
return tflite::TensorType_INT8;
return itype.isUnsigned() ? tflite::TensorType_UINT8
: tflite::TensorType_INT8;
case 16:
return tflite::TensorType_INT16;
case 32:
@ -404,6 +405,11 @@ class Translator {
// and returns llvm::None on failure.
Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
// Build TFLite tensor from the given type. This function is for tfl.lstm
// intermediates, which should have UniformQuantizedType.
Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
mlir::Type type, const std::string& name);
// Builds TFLite tensor from the given value. `buffer_idx` is index of the
// corresponding buffer. Emits error and returns llvm::None on failure.
Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value value,
@ -469,7 +475,8 @@ class Translator {
// tensor indices. Emits an error and returns llvm::None on failure.
Optional<BufferOffset<tflite::Operator>> BuildOperator(
Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
const std::vector<int32_t>& results,
const std::vector<int32_t>& intermediates);
// Build a subgraph with a given name out of the region either corresponding
// to a function's body or while op.
@ -581,6 +588,34 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
return tflite::CreateBuffer(builder_, buffer_data);
}
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
mlir::Type type, const std::string& name) {
auto tensor_type = type.cast<TensorType>();
if (!tensor_type.hasStaticShape()) {
return llvm::None;
}
llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
auto element_type = tensor_type.getElementType();
tflite::TensorType tflite_element_type =
GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
BufferOffset<tflite::QuantizationParameters> q_params;
auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>();
if (!qtype) {
return llvm::None;
}
q_params = tflite::CreateQuantizationParameters(
builder_, /*min=*/0, /*max=*/0,
builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
/*buffer=*/0, builder_.CreateString(name), q_params,
/*is_variable=*/false);
}
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
Value value, const std::string& name, unsigned buffer_idx) {
auto type = value.getType().cast<TensorType>();
@ -933,7 +968,8 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
const std::vector<int32_t>& results,
const std::vector<int32_t>& intermediates) {
const auto* dialect = inst->getDialect();
if (!dialect) {
inst->emitOpError("dialect is not registered");
@ -986,7 +1022,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
std::string op_name = inst->getName().getStringRef().str();
uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);
auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
results, &builder_);
results, intermediates, &builder_);
if (!offset) {
inst->emitOpError("is not a supported TFLite op");
}
@ -1171,6 +1207,29 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
bool failed_once = false;
for (auto& inst : bb) {
if (inst.isKnownTerminator()) break;
std::vector<int32_t> intermediates;
// Build intermediate tensors for tfl.lstm and insert these tensors into
// flatbuffer.
if (llvm::isa<mlir::TFL::LSTMOp>(inst)) {
std::vector<std::string> intermediate_names = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
for (const std::string& intermediate : intermediate_names) {
auto intermediate_attr = inst.getAttr(intermediate);
if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
Type qtype = attr.getValue();
auto tensor_or = BuildTensorFromType(
qtype, name_mapper_.GetUniqueName(intermediate).str());
if (!tensor_or.hasValue()) {
continue;
} else {
intermediates.push_back(tensors.size());
tensors.push_back(tensor_or.getValue());
}
}
}
}
for (auto val : inst.getResults()) {
std::string name = UniqueName(val);
@ -1195,7 +1254,8 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
results.push_back(tensor_index_map.lookup(result));
}
if (auto tfl_operator = BuildOperator(&inst, operands, results))
if (auto tfl_operator =
BuildOperator(&inst, operands, results, intermediates))
operators.push_back(*tfl_operator);
else
failed_once = true;
@ -1230,27 +1290,58 @@ BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
Translator::CreateMetadataVector() {
auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
if (!dict_attr) return VectorBufferOffset<BufferOffset<tflite::Metadata>>();
std::vector<BufferOffset<tflite::Metadata>> metadata;
for (const auto& named_attr : dict_attr) {
StringRef name = named_attr.first;
mlir::Attribute attr = named_attr.second;
if (auto content = attr.dyn_cast<StringAttr>()) {
metadata.push_back(BuildMetadata(name, content.getValue()));
} else {
module_.emitError(
"all values in tfl.metadata's dictionary key-value pairs should be "
"string attributes");
return llvm::None;
if (dict_attr) {
for (const auto& named_attr : dict_attr) {
StringRef name = named_attr.first;
mlir::Attribute attr = named_attr.second;
if (auto content = attr.dyn_cast<StringAttr>()) {
metadata.push_back(BuildMetadata(name, content.getValue()));
} else {
module_.emitError(
"all values in tfl.metadata's dictionary key-value pairs should be "
"string attributes");
return llvm::None;
}
}
}
// The runtime version string is generated after we update the op
// versions. Here we put a 16-byte dummy string as a placeholder. We choose
// 16 bytes because that is the buffer alignment in the flatbuffer, so it
// won't waste any space if the actual string is shorter than 16 bytes.
metadata.push_back(
BuildMetadata("min_runtime_version", std::string(16, '\0')));
return builder_.CreateVector(metadata);
}
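As a rough illustration of the fixed-size placeholder described above: because the placeholder occupies exactly 16 bytes, the real runtime-version string can later be written over it in place without shifting any other offsets in the serialized FlatBuffer. This is a hedged Python sketch with hypothetical offsets and version strings, not the exporter's actual patching code.
PLACEHOLDER_LEN = 16  # FlatBuffer buffers are 16-byte aligned

def patch_min_runtime_version(model, placeholder_offset, version):
    # Overwrite the all-zero placeholder in place; null-padding keeps the
    # length fixed, so no other offsets in the buffer need to change.
    encoded = version.encode("utf-8")
    assert len(encoded) <= PLACEHOLDER_LEN
    model[placeholder_offset:placeholder_offset + PLACEHOLDER_LEN] = (
        encoded.ljust(PLACEHOLDER_LEN, b"\0"))

serialized = bytearray(b"\0" * 64)  # stand-in for a serialized model buffer
patch_min_runtime_version(serialized, placeholder_offset=32, version="1.14.0")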
bool UpdateEntryFunction(ModuleOp module) {
if (module.lookupSymbol<FuncOp>("main") != nullptr) {
// We already have an entry function.
return true;
}
int entry_func_count = 0;
FuncOp entry_func = nullptr;
for (auto fn : module.getOps<FuncOp>()) {
auto attrs = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
if (attrs && !attrs.empty()) {
entry_func_count++;
entry_func = fn;
}
}
// We should have one & only have one entry function.
if (entry_func_count != 1) return false;
// Update the entry func to main.
entry_func.setName("main");
return true;
}
Optional<std::string> Translator::Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) {
if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, op_or_arg_name_mapper);
@ -1334,6 +1425,7 @@ Optional<std::string> Translator::TranslateInternal() {
builder_.CreateVector(buffers_), metadata_buffer, *metadata);
tflite::FinishModelBuffer(builder_, model);
tflite::UpdateOpVersion(builder_.GetBufferPointer());
tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
// Return serialized string for the built FlatBuffer.
return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@ -66,13 +67,29 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
}
};
struct TensorFlowLiteOpFolderDialectInterface
: public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
// Registered hook to check if the given region, which is attached to an
// operation that is *not* isolated from above (i.e. no internal regions
// reference values defined in an enclosing region), should be used when
// materializing constants.
// In the TFLite dialect we materialize inside a while region, as it is
// slightly more efficient computationally.
bool shouldMaterializeInto(Region *region) const final {
return isa<WhileOp>(region->getParentOp());
}
};
TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
: Dialect(/*name=*/"tfl", context) {
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
>();
addInterfaces<TensorFlowLiteInlinerInterface>();
addInterfaces<TensorFlowLiteInlinerInterface,
TensorFlowLiteOpFolderDialectInterface>();
}
//===----------------------------------------------------------------------===//
@ -1269,6 +1286,20 @@ static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
"UnidirectionalSequenceLSTMOp expected to have two stateful operands");
}
//===----------------------------------------------------------------------===//
// BidirectionalSequenceLSTMOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(BidirectionalSequenceLSTMOp op) {
auto operands = op.GetStatefulOperands();
if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
operands[2] == 37 && operands[3] == 38) {
return success();
}
return op.emitError(
"BidirectionalSequenceLSTMOp expected to have four stateful operands");
}
//===----------------------------------------------------------------------===//
// UnidirectionalSequenceRNNOp
//===----------------------------------------------------------------------===//

View File

@ -25,9 +25,11 @@ limitations under the License.
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // TF:llvm-project
#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -39,6 +41,8 @@ class TensorFlowLiteDialect : public Dialect {
public:
explicit TensorFlowLiteDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "tfl"; }
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,

View File

@ -19,7 +19,8 @@ limitations under the License.
#define TFL_OPS
include "mlir/IR/OpBase.td"
include "mlir/Transforms/LoopLikeInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
@ -47,13 +48,6 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
"TFLite string type">,
BuildableType<"getType<mlir::TF::StringType>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect uint8 type - uses the TF uint8 type as implementation
//===----------------------------------------------------------------------===//
def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
"TFLite uint8 type">,
BuildableType<"getType<mlir::TF::Uint8Type>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect quint8 type - uses the TF quint8 type as implementation
//===----------------------------------------------------------------------===//
@ -141,7 +135,8 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
Variadic<TensorOf<allowedOpTypes>>,
TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
def TFL_Int32Or64 : IntOfWidths<[32, 64]>;
def TFL_Uint8 : UI<8>;
def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;
def TFL_BoolTensor : TFL_TensorOf<[I1]>;
def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>;
@ -223,9 +218,9 @@ class TFL_Operand0DOr1ElementTensor<int x> :
class TFL_TFTypesWithSameBits<int i, int j, int num> :
And<[
Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Uint" # num # "Type>()">]>,
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>,
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Uint" # num # "Type>()">]>]>;
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
class TFL_OperandHasRankLessThan<int n, int m> :
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
@ -602,7 +597,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
let verifier = [{ return Verify(*this); }];
}
def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [ConstantLike, NoSideEffect,
FirstAttrDerivedResultType]> {
let summary = "Constant pseudo op.";
@ -1863,11 +1858,11 @@ def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutat
}];
let arguments = (
ins AnyTensor:$lhs,
AnyTensor:$rhs,
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;
@ -1887,9 +1882,9 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
Computes element-wise negation of input
}];
let arguments = (ins AnyTensor:$x);
let arguments = (ins TFL_TensorOf<[F32, I32, I64]>:$x);
let results = (outs AnyTensor:$y);
let results = (outs TFL_TensorOf<[F32, I32, I64]>:$y);
let hasOptions = 0b1;
@ -2039,10 +2034,10 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
}];
let arguments = (
ins AnyTensor:$lhs,
AnyTensor:$rhs);
ins TFL_TensorOf<[F32, I32]>:$lhs,
TFL_TensorOf<[F32, I32]>:$rhs);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
@ -2716,7 +2711,7 @@ def TFL_SplitOp : TFL_Op<"split", [
let arguments = (ins
TFL_TensorOf<[I32]>:$split_dim,
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
PositiveI32Attr:$num_splits
Confined<I32Attr, [IntPositive]>:$num_splits
);
let results = (outs
@ -2741,7 +2736,7 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
TFL_1DTensorOf<[I32], [I32]>:$size_splits,
TFL_0DTensorOf<[I32], [I32]>:$split_dim,
PositiveI32Attr:$num_splits
Confined<I32Attr, [IntPositive]>:$num_splits
);
let results = (outs
@ -3246,7 +3241,15 @@ Ba et al. 'Layer Normalization'
// Since this op is the FULL kernel only, constrain it.
Confined<
DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">,
[TFL_LSTM_KT_FULL]>:$kernel_type
[TFL_LSTM_KT_FULL]>:$kernel_type,
// Types of the optional intermediate tensors, which exist for fully
// quantized LSTM op and hold the ranges of the intermediate tensors.
OptionalAttr<TypeAttr>:$input_to_input_intermediate,
OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
OptionalAttr<TypeAttr>:$input_to_output_intermediate,
OptionalAttr<TypeAttr>:$effective_hidden_scale_intermediate
);
let results = (outs AnyTensor:$output);
@ -3350,6 +3353,156 @@ def TFL_UnidirectionalSequenceLSTMOp :
}];
}
def BidiLstmMandatoryInputsConstraint : PredOpTrait<
"mandatory operands element types should match",
// TODO(ashwinm): Replace the indices with input tensor names when that
// support is available.
Or<[
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 19, 20, 21, 23, 24, 25,
30, 31, 32, 35, 36, 37, 38]>,
Neg<TypeIsPred<"input", F32>>]>>;
def BidiLstmOptionalPeepholeWeightConstraint : PredOpTrait<
"the optional peephole weights should all be specified or none",
TCopVTEtAreSameAt<[9, 10, 11, 26, 27, 28]>>;
def BidiLstmProjectionWeightBiasConstraint : PredOpTrait<
"either projection weight must be specified or both projection weight and "
"projection bias must not be specified",
Or<[
And<[TypeIsPred<"fw_projection_weights", NoneType>,
TypeIsPred<"fw_projection_bias", NoneType>,
TypeIsPred<"bw_projection_weights", NoneType>,
TypeIsPred<"bw_projection_bias", NoneType>]>,
And<[
Neg<TypeIsPred<"fw_projection_weights", NoneType>>,
Neg<TypeIsPred<"bw_projection_weights", NoneType>>,
]>
]>>;
// BidirectionalSequenceLstm op.
// TODO(ashwinm): Add constraint to validate the combination of operands
// that are valid for hybrid vs fully quantized vs float only semantics
def TFL_BidirectionalSequenceLSTMOp :
TFL_Op<"bidirectional_sequence_lstm",
[BidiLstmMandatoryInputsConstraint,
BidiLstmOptionalPeepholeWeightConstraint,
BidiLstmProjectionWeightBiasConstraint,
LstmResultConstraint,
TFL_StatefulOp]> {
let summary = "Bidirectional sequence lstm operator";
let description = [{
Bidirectional lstm is essentially two lstms, one running forward and the
other running backward. The output is the concatenation of the two lstms.
}];
let arguments = (
ins TFL_TensorOf<[F32, I8]>:$input,
// Forward LSTM Weights
TFL_TensorOfOrNone<[F32, I8]>:$fw_input_to_input_weights,
TFL_TensorOf<[F32, I8]>:$fw_input_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$fw_input_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$fw_input_to_output_weights,
// Forward Recurrent weights
TFL_TensorOfOrNone<[F32, I8]>:$fw_recurrent_to_input_weights,
TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_output_weights,
// Forward Cell weights
TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_input_weights,
// Optional Forward cell weights
TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_forget_weights,
// Optional Forward cell weights
TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_output_weights,
// Forward Bias
TFL_TensorOfOrNone<[F32]>:$fw_input_gate_bias,
TFL_TensorOf<[F32]>:$fw_forget_gate_bias,
TFL_TensorOf<[F32]>:$fw_cell_bias,
TFL_TensorOf<[F32]>:$fw_output_gate_bias,
// Forward Projection weight and bias
TFL_TensorOfOrNone<[F32, I8]>:$fw_projection_weights,
// Forward Optional input
TFL_TensorOfOrNone<[F32]>:$fw_projection_bias,
// Backward LSTM Weights
TFL_TensorOfOrNone<[F32, I8]>:$bw_input_to_input_weights,
TFL_TensorOf<[F32, I8]>:$bw_input_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$bw_input_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$bw_input_to_output_weights,
// Backward Recurrent weights
TFL_TensorOfOrNone<[F32, I8]>:$bw_recurrent_to_input_weights,
TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_output_weights,
// Backward Cell weights
TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_input_weights,
// Optional Backward cell weights
TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_forget_weights,
// Optional Backward cell weights
TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_output_weights,
// Backward Bias
TFL_TensorOfOrNone<[F32]>:$bw_input_gate_bias,
TFL_TensorOf<[F32]>:$bw_forget_gate_bias,
TFL_TensorOf<[F32]>:$bw_cell_bias,
TFL_TensorOf<[F32]>:$bw_output_gate_bias,
// Backward Projection weight and bias
TFL_TensorOfOrNone<[F32, I8]>:$bw_projection_weights,
// Backward Optional input
TFL_TensorOfOrNone<[F32]>:$bw_projection_bias,
// Stateful activation and cell states.
TFL_StatefulTensor:$fw_input_activation_state,
TFL_StatefulTensor:$fw_input_cell_state,
TFL_StatefulTensor:$bw_input_activation_state,
TFL_StatefulTensor:$bw_input_cell_state,
// Auxiliary input & weights.
TFL_TensorOfOrNone<[F32, I8]>:$aux_input,
// Auxiliary fw weights.
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_input_weights,
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_forget_weights,
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_cell_weights,
TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_output_weights,
// Auxiliary bw weights.
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_input_weights,
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_forget_weights,
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_cell_weights,
TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_output_weights,
// Attributes
TFL_AFAttr:$fused_activation_function,
DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
BoolAttr:$merge_outputs,
BoolAttr:$time_major
);
let results = (outs
AnyTensor:$fw_output,
AnyTensor:$bw_output
);
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {35, 36, 37, 38}; }
}];
}
def RnnResultConstraint : PredOpTrait<
"the input and result tensor elemental types must be same",
TCresVTEtIsSameAsOp<0, 0>>;

View File

@ -10,11 +10,9 @@ package_group(
)
cc_library(
name = "graphdef_to_tfl_flatbuffer",
srcs = ["graphdef_to_tfl_flatbuffer.cc"],
hdrs = [
"graphdef_to_tfl_flatbuffer.h",
],
name = "tf_tfl_flatbuffer_helpers",
srcs = ["tf_tfl_flatbuffer_helpers.cc"],
hdrs = ["tf_tfl_flatbuffer_helpers.h"],
deps = [
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
@ -36,3 +34,61 @@ cc_library(
"@llvm-project//mlir:Transforms",
],
)
cc_library(
name = "graphdef_to_tfl_flatbuffer",
srcs = ["graphdef_to_tfl_flatbuffer.cc"],
hdrs = [
"graphdef_to_tfl_flatbuffer.h",
],
deps = [
":tf_tfl_flatbuffer_helpers",
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
cc_library(
name = "saved_model_to_tfl_flatbuffer",
srcs = ["saved_model_to_tfl_flatbuffer.cc"],
hdrs = [
"saved_model_to_tfl_flatbuffer.h",
],
deps = [
":tf_tfl_flatbuffer_helpers",
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -40,288 +41,7 @@ limitations under the License.
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using stream_executor::port::StatusOr;
namespace tensorflow {
namespace {
// Op def string for TFLite_Detection_PostProcess Op.
const char kDetectionPostProcessOp[] =
"name: 'TFLite_Detection_PostProcess' input_arg: { name: "
"'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
"'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
"'anchors' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
"'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
"type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
"name: 'nms_iou_threshold' type: 'float'} attr : { name: "
"'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
"'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
"type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
const char kUnidirectionalSequenceLstmOp[] =
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
"name: 'InputCellStateTensor' type: DT_FLOAT } "
"output_arg: { name: 'Concat' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
const char kUnidirectionalSequenceRnnOp[] =
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
"name: 'Bias' type: DT_FLOAT} "
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
"DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in the TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
switch (dtype) {
case toco::IODataType::FLOAT:
return DT_FLOAT;
case toco::IODataType::QUANTIZED_UINT8:
return DT_QUINT8;
case toco::IODataType::INT8:
return DT_QINT8;
case toco::IODataType::INT32:
return DT_INT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::STRING:
return DT_STRING;
case toco::IODataType::BOOL:
return DT_BOOL;
default:
return DT_INVALID;
}
}
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
DataType type) {
// Only qint8 and quint8 are considered here.
double qmin, qmax;
if (type == DT_QUINT8) {
qmin = 0.0;
qmax = 255.0;
} else if (type == DT_QINT8) {
qmin = -128.0;
qmax = 127.0;
} else {
return errors::InvalidArgument("Only int8 and uint8 are considered.");
}
return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}
// Give a warning for any unused flags that have been specified.
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags) {
if (toco_flags.output_format()) {
LOG(WARNING) << "Ignored output_format.";
}
if (toco_flags.drop_control_dependency()) {
LOG(WARNING) << "Ignored drop_control_dependency.";
}
if (toco_flags.reorder_across_fake_quant()) {
LOG(WARNING) << "Ignored reorder_across_fake_quant.";
}
if (model_flags.change_concat_input_ranges()) {
LOG(WARNING) << "Ignored change_concat_input_ranges.";
}
if (toco_flags.dump_graphviz_include_video()) {
LOG(WARNING) << "Ignored dump_graphviz_video.";
}
if (model_flags.allow_nonexistent_arrays()) {
LOG(WARNING) << "Allow allow_nonexistent_arrays.";
}
}
// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
if (failed(pm.run(module))) {
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
}
output->keep();
return Status::OK();
}
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
tensorflow::OpDef opdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
&opdef)) {
return errors::InvalidArgument("fail to parse extra OpDef");
}
    // Make sure the op is not already registered. If it is, continue.
const OpRegistrationData* op_reg =
tensorflow::OpRegistry::Global()->LookUp(opdef.name());
if (op_reg) continue;
tensorflow::OpRegistry::Global()->Register(
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
*op_reg_data = tensorflow::OpRegistrationData(opdef);
return Status::OK();
});
}
return Status::OK();
}
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
// Register any custom OpDefs.
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
return RegisterCustomBuiltinOps(extra_tf_opdefs);
}
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
ConvertIODataTypeToDataType(toco_flags.inference_type());
  // Use the non-float `inference_input_type` flag to override `inference_type`
  // because we have to apply quantization to satisfy that type.
if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
inference_type = quant_specs->inference_input_type;
}
for (auto& flag : model_flags.input_arrays()) {
node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
// If it's not filled, make it an empty string so the importer will use
// the data type in the NodeDef.
auto toco_data_type = flag.data_type();
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
node_dtypes->push_back("");
} else {
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
    // Currently, only UINT8 and INT8 require input stats.
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
}
}
if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
inference_type, quant_specs)) {
return errors::InvalidArgument("Failed to get input quant spec.");
}
  // Some extra flags related to post-training quantization. If post-training
// quantization is enabled, `inference_type` and `inference_input_type` are
// not used by MLIR passes.
if (toco_flags.post_training_quantize()) {
quant_specs->weight_quantization = true;
if (toco_flags.quantize_to_float16()) {
quant_specs->inference_type = tensorflow::DT_HALF;
quant_specs->inference_input_type = tensorflow::DT_HALF;
} else {
quant_specs->inference_type = tensorflow::DT_QINT8;
quant_specs->inference_input_type = tensorflow::DT_QINT8;
}
}
// Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs->default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs->default_ranges.second = toco_flags.default_ranges_max();
}
return ::tensorflow::Status::OK();
}
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
mlir::TFL::QuantizationSpecs quant_specs,
string* result) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
// rename once we enable the new converter feature flag.
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}
mlir::PassManager pm(module->getContext());
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
"/toco_AFTER_TRANSFORMATIONS.dot")));
}
return status;
}
} // namespace
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const GraphDebugInfo& debug_info,
@ -339,7 +59,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<double> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(PopulateQuantizationSpecs(
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
&node_shapes, &node_mins, &node_maxs));
@ -356,16 +76,16 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
specs.convert_legacy_fed_inputs = true;
specs.graph_as_function = false;
specs.upgrade_legacy = true;
WarningUnusedFlags(model_flags, toco_flags);
internal::WarningUnusedFlags(model_flags, toco_flags);
// Register all custom ops, including user-specified custom ops.
TF_RETURN_IF_ERROR(RegisterAllCustomOps(toco_flags));
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
return ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
quant_specs, result);
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
quant_specs, result);
}
} // namespace tensorflow

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
const string& saved_model_dir, bool saved_model_v1,
const string& saved_model_tags, const string& saved_model_exported_names,
string* result) {
mlir::MLIRContext context;
mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
&node_shapes, &node_mins, &node_maxs));
internal::WarningUnusedFlags(model_flags, toco_flags);
// Register all custom ops, including user-specified custom ops.
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
const bool import_saved_model = !saved_model_v1;
TF_ASSIGN_OR_RETURN(
auto module,
ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
saved_model_tags, saved_model_exported_names, &context));
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
quant_specs, result);
}
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
namespace tensorflow {
// Converts the given SavedModel (either v1 or v2) to a TFLite FlatBuffer
// string according to the given model flags, toco flags and tags. Returns an
// error status if it fails to convert the input.
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
const string& saved_model_dir, bool saved_model_v1,
const string& saved_model_tags, const string& saved_model_exported_names,
string* result);
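// A minimal usage sketch (illustrative only; the directory, tags and exported
// names below are hypothetical — in practice the Python converter populates
// the ModelFlags/TocoFlags and SavedModel arguments):
//
//   toco::ModelFlags model_flags;
//   toco::TocoFlags toco_flags;
//   string flatbuffer;
//   tensorflow::Status s = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
//       model_flags, toco_flags, /*saved_model_dir=*/"/tmp/my_model",
//       /*saved_model_v1=*/false, /*saved_model_tags=*/"serve",
//       /*saved_model_exported_names=*/"serving_default", &flatbuffer);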
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_

View File

@ -0,0 +1,325 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include <ostream>
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using stream_executor::port::StatusOr;
namespace tensorflow {
namespace internal {
namespace {
// Op def string for TFLite_Detection_PostProcess Op.
const char kDetectionPostProcessOp[] =
"name: 'TFLite_Detection_PostProcess' input_arg: { name: "
"'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
"'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
"'anchors' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
"'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
"type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
"name: 'nms_iou_threshold' type: 'float'} attr : { name: "
"'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
"'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
"type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
const char kUnidirectionalSequenceLstmOp[] =
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
"name: 'InputCellStateTensor' type: DT_FLOAT } "
"output_arg: { name: 'Concat' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
const char kUnidirectionalSequenceRnnOp[] =
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
"name: 'Bias' type: DT_FLOAT} "
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
"DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in the TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
switch (dtype) {
case toco::IODataType::FLOAT:
return DT_FLOAT;
case toco::IODataType::QUANTIZED_UINT8:
return DT_QUINT8;
case toco::IODataType::INT8:
return DT_QINT8;
case toco::IODataType::INT32:
return DT_INT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::STRING:
return DT_STRING;
case toco::IODataType::BOOL:
return DT_BOOL;
default:
return DT_INVALID;
}
}
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
DataType type) {
// Only qint8 and quint8 are considered here.
double qmin, qmax;
if (type == DT_QUINT8) {
qmin = 0.0;
qmax = 255.0;
} else if (type == DT_QINT8) {
qmin = -128.0;
qmax = 127.0;
} else {
return errors::InvalidArgument("Only int8 and uint8 are considered.");
}
return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}
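// Worked example (illustrative): with quint8 inference, mean_value = 127.5 and
// std_value = 127.5 map the quantized range [0, 255] back to the float range
// [(0 - 127.5) / 127.5, (255 - 127.5) / 127.5] = [-1.0, 1.0], which is what
// gets recorded as the input's (min, max) pair further below.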
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
tensorflow::OpDef opdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
&opdef)) {
return errors::InvalidArgument("fail to parse extra OpDef");
}
    // Make sure the op is not already registered. If it is, continue.
const OpRegistrationData* op_reg =
tensorflow::OpRegistry::Global()->LookUp(opdef.name());
if (op_reg) continue;
tensorflow::OpRegistry::Global()->Register(
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
*op_reg_data = tensorflow::OpRegistrationData(opdef);
return Status::OK();
});
}
return Status::OK();
}
} // namespace
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
// Register any custom OpDefs.
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
return RegisterCustomBuiltinOps(extra_tf_opdefs);
}
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
ConvertIODataTypeToDataType(toco_flags.inference_type());
  // Use the non-float `inference_input_type` flag to override `inference_type`
  // because we have to apply quantization to satisfy that type.
if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
inference_type = quant_specs->inference_input_type;
}
for (auto& flag : model_flags.input_arrays()) {
node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
// If it's not filled, make it an empty string so the importer will use
// the data type in the NodeDef.
auto toco_data_type = flag.data_type();
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
node_dtypes->push_back("");
} else {
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
    // Currently, only UINT8 and INT8 require input stats.
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
}
}
if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
inference_type, quant_specs)) {
return errors::InvalidArgument("Failed to get input quant spec.");
}
  // Some extra flags related to post-training quantization. If post-training
// quantization is enabled, `inference_type` and `inference_input_type` are
// not used by MLIR passes.
if (toco_flags.post_training_quantize()) {
quant_specs->weight_quantization = true;
if (toco_flags.quantize_to_float16()) {
quant_specs->inference_type = tensorflow::DT_HALF;
quant_specs->inference_input_type = tensorflow::DT_HALF;
} else {
quant_specs->inference_type = tensorflow::DT_QINT8;
quant_specs->inference_input_type = tensorflow::DT_QINT8;
}
}
// Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs->default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs->default_ranges.second = toco_flags.default_ranges_max();
}
return ::tensorflow::Status::OK();
}
// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
if (failed(pm.run(module))) {
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
}
output->keep();
return Status::OK();
}
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
mlir::TFL::QuantizationSpecs quant_specs,
string* result) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
// rename once we enable the new converter feature flag.
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}
mlir::PassManager pm(module->getContext());
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
"/toco_AFTER_TRANSFORMATIONS.dot")));
}
return status;
}
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags) {
if (toco_flags.output_format()) {
LOG(WARNING) << "Ignored output_format.";
}
if (toco_flags.drop_control_dependency()) {
LOG(WARNING) << "Ignored drop_control_dependency.";
}
if (toco_flags.reorder_across_fake_quant()) {
LOG(WARNING) << "Ignored reorder_across_fake_quant.";
}
if (model_flags.change_concat_input_ranges()) {
LOG(WARNING) << "Ignored change_concat_input_ranges.";
}
if (toco_flags.dump_graphviz_include_video()) {
LOG(WARNING) << "Ignored dump_graphviz_video.";
}
if (model_flags.allow_nonexistent_arrays()) {
LOG(WARNING) << "Allow allow_nonexistent_arrays.";
}
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
#include <ostream>
#include <utility>
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace internal {
// Register all custom ops, including user-specified custom ops.
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
// Populate quantization specs (or not) given user-specified ranges for each
// input array.
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs);
// Convert the imported MLIR module to a TFLite FlatBuffer.
// This also runs the relevant conversion passes.
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
mlir::TFL::QuantizationSpecs quant_specs,
string* result);
// Give a warning for any unused flags that have been specified.
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_

View File

@ -20,7 +20,7 @@ limitations under the License.
#define TF_Quantization
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
//===----------------------------------------------------------------------===//
// QuantizedType definitions.

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include <algorithm>
#include <cstdint>
#include <limits>
#include <numeric>
@ -147,29 +148,45 @@ static bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
// Changes the axis of the input per-channel quantized type to match the
// dimension of the target type. Returns nullptr if it fails.
static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
quant::UniformQuantizedPerAxisType qtype, Type target, int quant_dim) {
ArrayRef<int64_t> shape, quant::UniformQuantizedPerAxisType qtype,
Type target, int quant_dim) {
auto shaped = target.dyn_cast<RankedTensorType>();
if (!shaped) return {};
ArrayRef<int64_t> new_shape = shaped.getShape();
SmallVector<double, 4> scales(qtype.getScales().begin(),
qtype.getScales().end());
SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
qtype.getZeroPoints().end());
  // Broadcast the scales and zero points to match the target size, which is
  // usually the axis-th dimension of the target type. Currently, it covers two
  // cases:
  // - for Transpose, the data layout is changed, so `dim[axis]` still equals
  // the `scales_size` and the broadcast is skipped;
  // - for Reshape, the data layout isn't changed but the innermost dimension is
  // expanded to cover the last two original dimensions. Thus we just need to
  // repeat the `scales` dim[2] times to cover the new dim length.
  //
  // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
  // have to repeat each element in the `scales` locally dim[3] times.
if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
if (new_shape.size() == shape.size()) { // same rank
    // Broadcast the scales and zero points to match the target size, which is
    // usually the axis-th dimension of the target type. Currently, it covers
    // two cases:
    // - for Transpose, the data layout is changed, so the `dim[axis]` still
    // equals the `scales_size` and the broadcast is skipped;
    // - for Reshape, the data layout isn't changed but the innermost dimension
    // is expanded to cover the last two original dimensions. Thus we just need
    // to repeat the `scales` dim[2] times to cover the new dim length.
    //
    // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
    // have to repeat each element in the `scales` locally dim[3] times.
if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
return {};
}
} else if ((new_shape.size() == shape.size() + 1) && new_shape.back() == 1) {
    // This is a trivial left shift, so we shift the quant_dim as well.
if (std::equal(shape.begin(), shape.end(), new_shape.begin()) &&
quant_dim == -1) {
quant_dim = shape.size() + quant_dim;
} else {
return {};
}
} else {
return {};
}
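  // Example (illustrative): a call with quant_dim == -1 for a tensor<2x3xf32>
  // source that is reshaped to tensor<2x3x1xf32> takes the trivial-shift
  // branch above and continues with quant_dim == 1, keeping the scales and
  // zero points unchanged.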
return quant::UniformQuantizedPerAxisType::get(
qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
@ -179,20 +196,21 @@ static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
TypeAttr source, Type target,
int axis) {
if (auto source_type = source.getValue().dyn_cast_or_null<ShapedType>()) {
auto src_ele_type = source_type.getElementType();
if (auto quantized_type = src_ele_type.dyn_cast<quant::QuantizedType>()) {
if (auto qtype =
quantized_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
quantized_type = ResetAxisAndBroadcast(qtype, target, axis);
if (!src_ele_type) return {};
}
Type final_type = quantized_type.castFromExpressedType(target);
if (!final_type) return {};
return TypeAttr::get(final_type);
}
auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
if (!source_type) return {};
auto src_ele_type = source_type.getElementType();
auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();
// Reset the quantization dimensions if it is per-axis.
if (auto per_axis =
qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
qtype =
ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
}
return {};
if (!qtype) return {};
Type final_type = qtype.castFromExpressedType(target);
if (!final_type) return {};
return TypeAttr::get(final_type);
}
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,

View File

@ -9,6 +9,7 @@ package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/aot/...",
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
@ -38,3 +39,29 @@ cc_library(
],
alwayslink = 1,
)
cc_library(
name = "quantize",
srcs = [
"quantize.cc",
],
hdrs = [
"quantize.h",
],
deps = [
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core/platform:status",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
namespace mlir {
namespace xla_hlo {
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
computation->Snapshot());
MLIRContext context;
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
auto status = xla::ConvertHloToMlirHlo(
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
if (!status.ok()) {
LOG(ERROR) << "Hlo module import failed: " << status;
return status;
}
PassManager pm(&context);
pm.addPass(createCanonicalizerPass());
pm.addPass(createInlinerPass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(createCSEPass());
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
LogicalResult result = pm.run(module.get());
(void)result;
module->dump();
return tensorflow::Status::OK();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace xla_hlo {
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation);
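// Illustrative call site (hypothetical): given the tf2xla::Config that was
// used to build an xla::XlaComputation and the computation itself,
//
//   TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
//
// As of this change the pipeline imports the HLO into MLIR, runs generic
// cleanup passes (canonicalize, inline, symbol-dce, cse) and dumps the module;
// the quantization-specific rewrites appear to be follow-up work (see the
// TODO in the fadd_quant test below).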
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_

View File

@ -3,8 +3,14 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
data = [
":graph_config_files",
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
},
test_file_exts = ["mlir"],
)
@ -13,7 +19,17 @@ filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/aot:tfcompile",
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# Bundle together all the graph files that are used by the tests.
filegroup(
name = "graph_config_files",
srcs = glob(
["**/*.pbtxt"],
),
)

View File

@ -0,0 +1,15 @@
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
# TODO(fengliuai): update this file with the progress of the implementation
// CHECK: func @main
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
// CHECK: return %4 : tuple<tensor<2x4xf32>>
// CHECK: }

View File

@ -0,0 +1,26 @@
feed {
id { node_name: "input0" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
feed {
id { node_name: "input1" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
fetch {
id { node_name: "Add/FakeQuantWithMinMaxVars" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
conversion_options {
custom_fake_quant_op_calls: true
}

View File

@ -0,0 +1,218 @@
node: {
name: "Add/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "Add"
input: "Add/FakeQuantWithMinMaxVars/min"
input: "Add/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "Add/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "Add/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "Add"
op: "Add"
input: "input0/FakeQuantWithMinMaxVars"
input: "input1/FakeQuantWithMinMaxVars"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "input0"
input: "input0/FakeQuantWithMinMaxVars/min"
input: "input0/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "input1"
input: "input1/FakeQuantWithMinMaxVars/min"
input: "input1/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 27
}

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
# Add two tensor<4xi32> inputs and return the result
@ -90,5 +90,11 @@ versions {
# CHECK-EMPTY:
# CHECK-NEXT: }, {
# CHECK-EMPTY:
# CHECK-NEXT: }, {
# CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
# CHECK-NEXT: } ],
# CHECK-NEXT: metadata: [ {
# CHECK-NEXT: name: "min_runtime_version",
# CHECK-NEXT: buffer: 4
# CHECK-NEXT: } ]
# CHECK-NEXT: }

View File

@ -61,11 +61,11 @@ func @i64() -> tensor<4xi64> {
// the same sort of opaque round-trip we get for complex64, but it might be good
// to check
func @uint8() -> tensor<4x!tf.uint8> {
func @uint8() -> tensor<4xui8> {
// CHECK-LABEL: @uint8
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8>
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> } : () -> tensor<4x!tf.uint8>
return %0 : tensor<4x!tf.uint8>
// CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8>
%0 = "tfl.pseudo_const"() {value = dense<[222, 173, 190, 239]> : tensor<4xui8>} : () -> tensor<4xui8>
return %0 : tensor<4xui8>
}
func @qi32_per_axis() -> tensor<3x3x!quant.uniform<i32:f32:1, {1.0, 0.5:1, 0.25:1}>> {

View File

@ -13,3 +13,16 @@ func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32
// CHECK: return %[[RES0]]
}
// -----
func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK-LABEL: testFullyQuantizedLSTM
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
}

View File

@ -58,15 +58,15 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t
// CANON-SAME: (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>)
// CANON: [[VAL_1:%.*]] = constant dense<1.000000e+00> : tensor<256x256xf32>
// CANON: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor<i32>
// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor<i32>
// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
// CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) ( {
// CANON: ^bb0([[VAL_7:%.*]]: tensor<*xi32>, [[VAL_8:%.*]]: tensor<*xi32>, [[VAL_9:%.*]]: tensor<*xf32>):
// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor<i32>
// CANON: [[VAL_10:%.*]] = "tf.Less"([[VAL_8]], [[VAL_3]])
// CANON: "tfl.yield"([[VAL_10]]) : (tensor<*xi1>) -> ()
// CANON: }, {
// CANON: ^bb0([[VAL_11:%.*]]: tensor<*xi32>, [[VAL_12:%.*]]: tensor<*xi32>, [[VAL_13:%.*]]: tensor<*xf32>):
// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor<i32>
// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
// CANON: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_4]])
// CANON: [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]])
// CANON: [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]])

View File

@ -1,22 +1,11 @@
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
%4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32>
return %7: tensor<1xf32>
return %0: tensor<1xf32>
// CHECK-LABEL: addRelu
// CHECK-LABEL: add
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32>
// CHECK: return
}
@ -30,13 +19,10 @@ func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> {
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
%1 = "tf.BiasAdd"(%0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
%2 = "tf.Relu6"(%1) : (tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32>
return %2 : tensor<1x10x10x32xf32>
return %0 : tensor<1x10x10x32xf32>
// CHECK-LABEL: biasAdd
// CHECK: "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
// CHECK: %1 = "tfl.add"(%0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
}
func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> {
@ -55,9 +41,9 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
// CHECK-LABEL: squeezeAndReshape
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: return
@@ -88,7 +74,7 @@ func @dynamicReshapeI64Fold(%arg0: tensor<*xf32>) -> tensor<1x2xf32> {
return %0 : tensor<1x2xf32>
// CHECK-LABEL: dynamicReshapeI64Fold
// CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32>
// CHECK: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32>
// CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%arg0, %[[cst]]) : (tensor<*xf32>, tensor<2xi32>) -> tensor<1x2xf32>
// CHECK-NEXT: return %[[reshape]] : tensor<1x2xf32>
}
@@ -128,10 +114,10 @@ func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
return %0 : tensor<8x16xf32>
// CHECK-LABEL: softplus
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
// CHECK: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
@@ -255,20 +241,12 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
func @div(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
%4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
return %5: tensor<1xf32>
return %0: tensor<1xf32>
// CHECK-LABEL: divRelu
// CHECK-LABEL: div
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
// CHECK: return
}
@@ -698,8 +676,9 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
}
@@ -731,8 +710,9 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
}
@@ -1295,7 +1275,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>
@@ -1340,8 +1321,8 @@ func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> {
return %0: tensor<8xf16>
// CHECK-LABEL: reciprocal_f16
// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf16>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf16>, tensor<8xf16>) -> tensor<8xf16>
// CHECK: %cst = constant dense<1.000000e+00> : tensor<f16>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f16>, tensor<8xf16>) -> tensor<8xf16>
// CHECK: return
}
@@ -1350,8 +1331,8 @@ func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
return %0: tensor<8xf32>
// CHECK-LABEL: reciprocal_f32
// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf32>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<8xf32>) -> tensor<8xf32>
// CHECK: %cst = constant dense<1.000000e+00> : tensor<f32>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<8xf32>) -> tensor<8xf32>
// CHECK: return
}
@@ -1360,8 +1341,8 @@ func @reciprocal_complex_f32(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<
return %0: tensor<8xcomplex<f32>>
// CHECK-LABEL: reciprocal_complex_f32
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<1xcomplex<f32>>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xcomplex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<complex<f32>>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<complex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
// CHECK: return
}
@@ -1370,8 +1351,8 @@ func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
return %0: tensor<8xi32>
// CHECK-LABEL: reciprocal_i32
// CHECK: %cst = constant dense<1> : tensor<1xi32>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<8xi32>) -> tensor<8xi32>
// CHECK: %cst = constant dense<1> : tensor<i32>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<8xi32>) -> tensor<8xi32>
// CHECK: return
}
@@ -1380,8 +1361,8 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
return %0: tensor<8xi64>
// CHECK-LABEL: reciprocal_i64
// CHECK: %cst = constant dense<1> : tensor<1xi64>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
// CHECK: %cst = constant dense<1> : tensor<i64>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<i64>, tensor<8xi64>) -> tensor<8xi64>
// CHECK: return
}
@@ -1436,7 +1417,7 @@ func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>)
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
// CHECK: [[VAL_5:%.*]] = constant unit
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
// CHECK: }
@@ -1461,7 +1442,7 @@ func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32>
// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32>
// CHECK: [[VAL_14:%.*]] = constant unit
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
// CHECK: }
@@ -1480,3 +1461,25 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
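For reference, a minimal NumPy sketch of the lowering pattern checked above (an illustration only, not converter code; the input values are made up): tf.BroadcastTo is rewritten as a fill of ones followed by a broadcasting multiply, which yields the same result.

import numpy as np

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)   # plays the role of %arg0 : tensor<3xf32>
shape = (3, 3)                                     # plays the role of %arg1 : tensor<2xi32>

broadcast = np.broadcast_to(x, shape)              # what tf.BroadcastTo computes
fill_ones = np.full(shape, 1.0, dtype=x.dtype)     # tfl.fill(%arg1, dense<1.0>)
lowered = x * fill_ones                            # tfl.mul with implicit broadcasting

assert np.array_equal(broadcast, lowered)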

View File

@@ -12,11 +12,9 @@
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 4 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 5 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 6 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 7 tfl.while:3 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 5 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// Verify while was not folded away:
// ------------------------------------

View File

@@ -108,6 +108,12 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 10
// CHECK-NEXT: } ]
// CHECK-NEXT:}
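The min_runtime_version buffers checked above are 16-byte, NUL-padded ASCII version strings. A small Python sketch decoding the first one (the byte list is copied from the CHECK lines, not read from a real model):

data = [49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
version = bytes(data).decode("ascii").rstrip("\x00")
assert version == "1.10.0"  # the later buffers decode the same way, e.g. 49, 46, 55, 46, 48 is "1.7.0"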

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
@@ -61,6 +61,12 @@ func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2:
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
@@ -90,6 +90,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -82,6 +82,12 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -84,6 +84,12 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
@@ -88,6 +88,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -tflite-flatbuffer-to-mlir - -o - | FileCheck --check-prefix=IMPORT %s
func @main(tensor<4xf32>) -> tensor<4xf32> {
@@ -46,6 +46,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ]
// CHECK-NEXT: }

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK: {
@@ -39,6 +39,12 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ]
// CHECK-NEXT: }

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
@@ -89,6 +89,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -61,6 +61,12 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -61,6 +61,12 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -156,6 +156,12 @@
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 11
// CHECK-NEXT: } ]
// CHECK-NEXT: }

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4xi1>) -> tensor<4xi1> {
^bb0(%arg0: tensor<4xi1>):
@@ -78,6 +78,12 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 49, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
// CHECK: {
@@ -192,7 +192,8 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: builtin_options_type: LSTMOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: intermediates: [ ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
@@ -249,6 +250,12 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 26
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@@ -0,0 +1,323 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LSTM,
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 528 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.037248 ],
// CHECK-NEXT: zero_point: [ -19 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 528 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.059802 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 528 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.031926 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 528 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "arg3",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.056272 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 528 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg4",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.063764 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 640 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg5",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.013359 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 640 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "arg6",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.02283 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 640 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "arg7",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.032276 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048, 640 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg8",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.035427 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg9",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.0 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "arg10",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.0 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "arg11",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.0 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 13,
// CHECK-NEXT: name: "arg12",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.0 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 640, 2048 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 14,
// CHECK-NEXT: name: "arg13",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.021174 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 640 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 15,
// CHECK-NEXT: name: "arg14",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.00016 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: buffer: 16,
// CHECK-NEXT: name: "arg15",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.000437 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: buffer: 17,
// CHECK-NEXT: name: "arg16",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.00011 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.000168 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 2048 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: buffer: 19,
// CHECK-NEXT: name: "arg18",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.000156 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 640 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: name: "arg19",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.096711 ],
// CHECK-NEXT: zero_point: [ 10 ]
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 2048 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: name: "arg20",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.000488 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 0 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: name: "input_to_input_intermediate",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.004989 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 0 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: name: "input_to_forget_intermediate",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.007885 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 0 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: name: "input_to_cell_intermediate",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.008763 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 0 ],
// CHECK-NEXT: type: INT16,
// CHECK-NEXT: name: "input_to_output_intermediate",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.005753 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 0 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: name: "effective_hidden_scale_intermediate",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.007563 ],
// CHECK-NEXT: zero_point: [ 2 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 640 ],
// CHECK-NEXT: type: INT8,
// CHECK-NEXT: buffer: 22,
// CHECK-NEXT: name: "tfl.lstm",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.096711 ],
// CHECK-NEXT: zero_point: [ 10 ]
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ],
// CHECK-NEXT: outputs: [ 26 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, 9, 10, 11, 12, 13, 14, 19, 20, 15, 16, 17, 18 ],
// CHECK-NEXT: outputs: [ 26 ],
// CHECK-NEXT: builtin_options_type: LSTMOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: fused_activation_function: TANH,
// CHECK-NEXT: cell_clip: 10.0,
// CHECK-NEXT: proj_clip: 0.01
// CHECK-NEXT: },
// CHECK-NEXT: intermediates: [ 21, 22, 23, 24, 25 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 23
// CHECK-NEXT: } ]
// CHECK-NEXT: }
}
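As a side note, the tensor names, shapes, and quantization parameters asserted above can also be inspected programmatically once such a model has been serialized; a hedged Python sketch using the TFLite interpreter (the file name is hypothetical, and this is only one way to cross-check the flatbuffer dump):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="quantized_lstm.tflite")  # hypothetical model file
for t in interpreter.get_tensor_details():
    scale, zero_point = t["quantization"]
    print(t["index"], t["name"], t["shape"], scale, zero_point)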

View File

@@ -128,6 +128,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 8
// CHECK-NEXT: } ]
// CHECK-NEXT: }

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
@@ -50,6 +50,12 @@ func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x3
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
@@ -50,6 +50,12 @@ func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}

View File

@@ -20,6 +20,8 @@ module attributes {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "key1",
@@ -27,4 +29,8 @@
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "key2",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: }

Some files were not shown because too many files have changed in this diff.