Merge branch 'official_master' into no_mem_opt_if_jit_on

Trent Lo 2019-09-13 13:28:45 -07:00
commit 67c50b0914
1643 changed files with 49510 additions and 17337 deletions

View File

@ -49,34 +49,34 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"],
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"],
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.5.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
strip_prefix = "swift-protobuf-1.6.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.

View File

@ -3,56 +3,56 @@ package(default_visibility = ["//visibility:public"])
filegroup(
name = "gcc",
srcs = [
"bin/arm-linux-gnueabihf-gcc",
"bin/arm-rpi-linux-gnueabihf-gcc",
],
)
filegroup(
name = "ar",
srcs = [
"bin/arm-linux-gnueabihf-ar",
"bin/arm-rpi-linux-gnueabihf-ar",
],
)
filegroup(
name = "ld",
srcs = [
"bin/arm-linux-gnueabihf-ld",
"bin/arm-rpi-linux-gnueabihf-ld",
],
)
filegroup(
name = "nm",
srcs = [
"bin/arm-linux-gnueabihf-nm",
"bin/arm-rpi-linux-gnueabihf-nm",
],
)
filegroup(
name = "objcopy",
srcs = [
"bin/arm-linux-gnueabihf-objcopy",
"bin/arm-rpi-linux-gnueabihf-objcopy",
],
)
filegroup(
name = "objdump",
srcs = [
"bin/arm-linux-gnueabihf-objdump",
"bin/arm-rpi-linux-gnueabihf-objdump",
],
)
filegroup(
name = "strip",
srcs = [
"bin/arm-linux-gnueabihf-strip",
"bin/arm-rpi-linux-gnueabihf-strip",
],
)
filegroup(
name = "as",
srcs = [
"bin/arm-linux-gnueabihf-as",
"bin/arm-rpi-linux-gnueabihf-as",
],
)

View File

@ -1388,11 +1388,18 @@ def main():
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
tensor_rt_question = (
'Do you wish to build TensorFlow with TensorRT support? NB! There ' +
'are known ODR violations between TensorRT and cuDNN that may result ' +
'in application crashes and/or data corruption. Please see ' +
'https://github.com/tensorflow/tensorflow/issues/32480 for details.')
set_action_env_var(
environ_cp,
'TF_NEED_TENSORRT',
'TensorRT',
False,
question=tensor_rt_question,
bazel_config_name='tensorrt')
environ_save = dict(environ_cp)

View File

@ -159,7 +159,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
return new TF_DeprecatedSession({session});
} else {
DCHECK_EQ(nullptr, session);
@ -332,7 +332,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
// TODO(nolivia): check this on a subset of the graph instead of all of
// it.
status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
session->graph->mu.unlock();
return false;
}
@ -352,7 +352,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
*graph_def.mutable_library() = graph.flib_def().ToProto();
session->graph->mu.unlock();
status->status = session->session->Extend(std::move(graph_def));
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
// Contract is we always delete input_values[i].
return false;
}
@ -382,7 +382,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
const int ninputs = input_pairs->size();
for (int i = 0; i < ninputs; ++i) {
status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
if (TF_GetCode(status) != TF_OK) return false;
if (!status->status.ok()) return false;
}
return true;
}
@ -439,7 +439,7 @@ static void TF_Run_Helper(
// Serialize back to upstream client, who now owns the new buffer
if (run_metadata != nullptr) {
status->status = MessageToBuffer(run_metadata_proto, run_metadata);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
}
} else {
// NOTE(zongheng): PRun does not support RunOptions yet.
@ -459,7 +459,7 @@ static void TF_Run_Helper(
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
}
}
@ -516,7 +516,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
string new_handle;
status->status = s->session->PRunSetup(input_names, output_names,
target_oper_names, &new_handle);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
@ -555,7 +555,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
status->status = tensorflow::LoadLibrary(
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
&lib_handle->op_list.length);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
delete lib_handle;
return nullptr;
}
@ -983,7 +983,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
TF_Tensor* value, TF_Status* status) {
Tensor t;
status->status = TF_TensorToTensor(value, &t);
if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
@ -993,13 +993,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
std::vector<Tensor> t;
t.reserve(num_values);
for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
for (int i = 0; i < num_values && status->status.ok(); ++i) {
Tensor v;
status->status = TF_TensorToTensor(values[i], &v);
t.emplace_back(v);
}
if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
@ -1048,11 +1048,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
/*consume=*/true);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
// Run shape inference function for newly added node.
status->status = desc->graph->refiner.AddNode(ret);
}
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
// Add the node to the name-to-node mapping.
desc->graph->name_map[ret->name()] = ret;
} else if (ret != nullptr) {
@ -1101,7 +1101,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
if (TF_GetCode(status) != TF_OK) return -1;
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Output arg '", arg_name, "' not found");
@ -1123,7 +1123,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
if (TF_GetCode(status) != TF_OK) return -1;
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Input arg '", arg_name, "' not found");
@ -1142,6 +1142,16 @@ TF_Output TF_OperationInput(TF_Input oper_in) {
return {ToOperation(edge->src()), edge->src_output()};
}
void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
int max_inputs) {
for (auto* edge : oper->node.in_edges()) {
if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
inputs[edge->dst_input()] = {ToOperation(edge->src()),
edge->src_output()};
}
}
}
int TF_OperationOutputNumConsumers(TF_Output oper_out) {
int count = 0;
for (const auto* edge : oper_out.oper->node.out_edges()) {
@ -1221,7 +1231,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
TF_Status* status) {
TF_AttrMetadata metadata;
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return metadata;
if (!status->status.ok()) return metadata;
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
@ -1328,7 +1338,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
void* value, size_t max_length,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kS) {
status->status =
InvalidArgument("Attribute '", attr_name, "' is not a string");
@ -1346,7 +1356,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
int max_values, void* storage,
size_t storage_size, TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kList) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a list");
@ -1379,7 +1389,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (TF_GetCode(status) != TF_OK) return; \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
@ -1401,7 +1411,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
PartialTensorShape shape;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
auto len = std::min(shape.dims(), num_dims);
for (int i = 0; i < len; ++i) {
value[i] = shape.dim_size(i);
@ -1415,7 +1425,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
std::vector<PartialTensorShape> shapes;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
int64_t* p = storage;
int storage_left = storage_size;
@ -1443,7 +1453,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
const char* attr_name,
TF_Buffer* value, TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kShape) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a shape.");
@ -1457,7 +1467,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
TF_Buffer** values, int max_values,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kList) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a list");
@ -1467,7 +1477,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
for (int i = 0; i < len; ++i) {
values[i] = TF_NewBuffer();
status->status = MessageToBuffer(attr->list().shape(i), values[i]);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
// Delete everything allocated so far, the operation has failed.
for (int j = 0; j <= i; ++j) {
TF_DeleteBuffer(values[j]);
@ -1482,7 +1492,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
*value = nullptr;
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
*value = TF_TensorFromTensor(t, status);
}
@ -1491,7 +1501,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
TF_Status* status) {
std::vector<Tensor> ts;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
values[i] = TF_TensorFromTensor(ts[i], status);
@ -1502,7 +1512,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
TF_Buffer* output_attr_value,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
status->status = MessageToBuffer(*attr, output_attr_value);
}
@ -1583,7 +1593,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
{
mutex_lock l(graph->mu);
status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
}
status->status = MessageToBuffer(*op_def, output_op_def);
}
@ -1701,7 +1711,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner, &results);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
// Add new nodes to name_map
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
@ -1755,7 +1765,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
auto results = new TF_ImportGraphDefResults();
mutex_lock l(graph->mu);
GraphImportGraphDefLocked(graph, def, options, results, status);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
delete results;
return nullptr;
}
@ -1813,7 +1823,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
// TODO(skyewm): set placeholder shape
TF_Operation* oper = TF_FinishOperation(desc, status);
if (TF_GetCode(status) != TF_OK) return false;
if (!status->status.ok()) return false;
*input = {oper, 0};
return true;
}
@ -1958,7 +1968,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
body_graph, body_inputs, body_outputs, name};
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
FreeWhileResources(&params);
return EmptyWhileParams();
}
@ -2160,7 +2170,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
@ -2208,7 +2218,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
status->status =
tensorflow::LoadSavedModel(session_options->options, run_options_proto,
export_dir, tag_set, &bundle);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
// Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
// extends using GraphDefs. The Graph instance is different, but equivalent
@ -2221,11 +2231,11 @@ TF_Session* TF_LoadSessionFromSavedModel(
GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
import_opts, &results, status);
TF_DeleteImportGraphDefOptions(import_opts);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
if (meta_graph_def != nullptr) {
status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
}
TF_Session* session = new TF_Session(bundle.session.release(), graph);
@ -2325,7 +2335,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
string new_handle;
status->status = session->session->PRunSetup(input_names, output_names,
target_names, &new_handle);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
@ -2387,9 +2397,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
tensor, graph->refiner, *graph->graph.op_registry(),
graph->graph.versions().producer(), &evaluated, &result_tensor);
if (evaluated) {
DCHECK(TF_GetCode(status) == TF_OK);
DCHECK(status->status.ok());
*result = TF_TensorFromTensor(result_tensor, status);
if (TF_GetCode(status) != TF_OK) evaluated = false;
if (!status->status.ok()) evaluated = false;
}
return evaluated;
}
@ -2444,7 +2454,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(*api_def, ret);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2456,7 +2466,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(kernel_list, ret);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2468,7 +2478,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
tensorflow::GetRegisteredKernelsForOp(name);
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(kernel_list, ret);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2498,7 +2508,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len,
std::unique_ptr<tensorflow::ServerInterface> out_server;
status->status = tensorflow::NewServer(server_def, &out_server);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
return new TF_Server(std::move(out_server));
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
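
Throughout this file the checks migrate from TF_GetCode(status) != TF_OK to !status->status.ok(). A minimal sketch of why that works, assuming (as in the C API internals) that TF_Status is a thin wrapper around tensorflow::Status; the helper and DoWork() are hypothetical:

// Sketch only: implementation code can inspect the wrapped Status directly
// instead of round-tripping through the public TF_GetCode() accessor.
struct TF_Status {
  tensorflow::Status status;
};

bool SomeInternalHelper(TF_Status* status) {
  status->status = DoWork();               // DoWork(): hypothetical Status-returning call
  if (!status->status.ok()) return false;  // direct check, no TF_GetCode()
  return true;
}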

View File

@ -435,6 +435,15 @@ TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper,
// producer.index) to consumer.oper's input (given by consumer.index).
TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in);
// Get list of all inputs of a specific operation. `inputs` must point to
// an array of length at least `max_inputs` (ideally set to
// TF_OperationNumInputs(oper)). Beware that a concurrent
// modification of the graph can increase the number of inputs of
// an operation.
TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper,
TF_Output* inputs,
int max_inputs);
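
For illustration, a usage sketch for the new call (not part of this commit); it assumes `op` is a valid TF_Operation* and a C++ caller with <cstdio> and <vector>:

int n = TF_OperationNumInputs(op);
std::vector<TF_Output> inputs(n);
TF_OperationAllInputs(op, inputs.data(), n);
for (int i = 0; i < n; ++i) {
  // inputs[i] names the producer feeding op's i-th input.
  std::printf("input %d <- %s:%d\n", i, TF_OperationName(inputs[i].oper),
              inputs[i].index);
}
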
// Get the number of current consumers of a specific output of an
// operation. Note that this number can change when new operations
// are added to the graph.

View File

@ -510,10 +510,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
static void CheckOk(TF_Status* status) {
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
auto* status = TF_NewStatus();
if (!TFE_TensorHandleIsConcrete(handle)) {

View File

@ -9,7 +9,6 @@ load(
)
load(
"//tensorflow/core/platform:default/build_config.bzl",
"tf_additional_device_tracer_test_flags",
"tf_kernel_tests_linkstatic",
)
load(
@ -27,6 +26,7 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
],
hdrs = ["c_api.h"],
@ -237,8 +237,7 @@ tf_cuda_cc_test(
srcs = [
"c_api_experimental_test.cc",
],
args =
["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@ -100,32 +101,34 @@ string DeviceName(const tensorflow::Device* d) {
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
tensorflow::Notification n;
tensorflow::mutex remote_devices_mu;
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, remote_worker,
[&status, &n, &remote_devices](
tensorflow::Env::Default(), worker_cache, remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
status = s;
statuses[i] = s;
if (s.ok()) {
tensorflow::mutex_lock l(remote_devices_mu);
for (tensorflow::Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
n.Notify();
counter.DecrementCount();
});
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::StaticDeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
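
Both this function and CreateRemoteContexts below replace a per-request Notification with one BlockingCounter shared across all requests. A self-contained sketch of that fan-out pattern, with StartAsyncCall as a hypothetical stand-in for NewRemoteDevices/CreateContextAsync:

#include <functional>
#include <vector>
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical asynchronous RPC; invokes `done` when the call completes.
void StartAsyncCall(int i,
                    std::function<void(const tensorflow::Status&)> done) {
  done(tensorflow::Status::OK());
}

tensorflow::Status CallAllWorkersInParallel(int num_workers) {
  tensorflow::BlockingCounter counter(num_workers);
  std::vector<tensorflow::Status> statuses(num_workers);
  for (int i = 0; i < num_workers; i++) {
    StartAsyncCall(i, [i, &statuses, &counter](const tensorflow::Status& s) {
      statuses[i] = s;           // each callback writes only its own slot
      counter.DecrementCount();  // the last decrement unblocks Wait()
    });
  }
  counter.Wait();  // all callbacks have run
  for (const tensorflow::Status& s : statuses) {
    TF_RETURN_IF_ERROR(s);  // surface the first failure, if any
  }
  return tensorflow::Status::OK();
}

Note that containers shared across callbacks, like remote_devices above, still need the mutex: several callbacks may append concurrently.
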
@ -135,11 +138,15 @@ tensorflow::Status CreateRemoteContexts(
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request) {
for (int i = 0; i < remote_workers.size(); i++) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse response;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
@ -159,16 +166,17 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
tensorflow::Notification n;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
eager_client->CreateContextAsync(
&request, &response, [&status, &n](const tensorflow::Status& s) {
status = s;
n.Notify();
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(status);
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
@ -215,7 +223,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
remote_workers.end());
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
std::unique_ptr<tensorflow::DynamicDeviceMgr> remote_device_mgr;
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
@ -247,7 +255,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(
CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor()->Async(), base_request));
ctx->context->Executor().Async(), base_request));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
@ -564,7 +572,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), handle->Context()->Executor(),
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
return nullptr;
@ -596,33 +604,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_Op(ctx, name, false, types,
new TFE_OpInferenceContext(op_def));
}
if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
return new TFE_Op(ctx, name, true, types, nullptr);
return NewOrResetOp(ctx, op_or_function_name, status,
/* op_to_reset= */ nullptr);
}
void TFE_DeleteOp(TFE_Op* op) { delete op; }
@ -916,7 +899,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
ctx->context->Executor(),
&ctx->context->Executor(),
device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
@ -967,7 +950,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
status->status = ctx->context->Executor()->WaitForAllPendingNodes();
status->status = ctx->context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
@ -979,9 +962,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
for (const auto& attr : func.attr()) {
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
}
return func_op;
}
@ -1029,7 +1012,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} break;
case tensorflow::AttrValue::kFunc: {
const auto func_op = GetFunc(ctx, default_value.func(), status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
// TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
// require TFE_Op* and just convert it internally to a NameAttrValue, so
// consider adding an overload to the C API to make this case easier.

View File

@ -28,6 +28,16 @@ limitations under the License.
using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset) {
if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
}
@ -597,5 +607,5 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(ctx->context->Executor());
return new TFE_Executor(&ctx->context->Executor());
}

View File

@ -22,6 +22,10 @@ limitations under the License.
extern "C" {
#endif
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset);
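
A hedged usage sketch for TFE_OpReset (op name and loop bounds are illustrative): reuse one TFE_Op allocation across steps instead of paying for a TFE_NewOp/TFE_DeleteOp pair per step.

TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
for (int step = 0; step < num_steps; ++step) {
  // ... TFE_OpAddInput(op, ...), TFE_Execute(op, ...) ...
  // Re-initialize the same allocation in place for the next step.
  TFE_OpReset(ctx, "MatMul", status, op);
}
TFE_DeleteOp(op);
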
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);

View File

@ -84,11 +84,6 @@ void ExecuteWithProfiling(bool async) {
string profile_proto_str = profile_proto.DebugString();
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
#ifndef TENSORFLOW_USE_ROCM
// ROCm platform does not yet support stream level tracing
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
#endif
}
// "/host:CPU" is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));

View File

@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
bool is_function,
TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
if (op_to_reset) {
op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
return op_to_reset;
} else {
return new TFE_Op(ctx, name, is_function, types, inference_ctx);
}
};
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
return create_or_reset(false, new TFE_OpInferenceContext(op_def));
}
if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
return create_or_reset(true, nullptr);
}

View File

@ -133,10 +133,25 @@ struct TFE_Op {
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
void Reset(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* infer_ctx) {
operation.Reset(ctx->context, op, is_function, t, nullptr);
inference_ctx.reset(infer_ctx);
}
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

View File

@ -1069,10 +1069,13 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
// still fail.
TF_SetStatus(status, TF_OK, "");
TFE_DeleteTensorHandle(retvals[0]);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
EXPECT_NE(TF_OK, TF_GetCode(status));
TF_SetStatus(status, TF_OK, "");
retvals[0] = nullptr;
TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status));
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorClearError(executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
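
Condensed, the async error protocol this test now exercises (sketch; ctx and status are assumed valid):

TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);  // reports the async failure
TFE_ExecutorClearError(executor);                      // re-arms the executor
TFE_ExecutorWaitForAllPendingNodes(executor, status);  // now succeeds
TFE_DeleteExecutor(executor);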

View File

@ -292,7 +292,9 @@ string ToCamelCase(const string& str) {
bool cap = true;
while (i < str.size()) {
const char c = str[i++];
if (c == joiner) {
if (c == '>') {
cap = true;
} else if (c == joiner) {
cap = true;
} else if (cap) {
result += toupper(c);
@ -304,6 +306,21 @@ string ToCamelCase(const string& str) {
return result;
}
string SeparateNamespaces(const string& str) {
string result;
const char joiner = '_';
size_t i = 0;
while (i < str.size()) {
const char c = str[i++];
if (c == '>') {
result += joiner;
} else {
result += c;
}
}
return result;
}
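
Illustration only, with a hypothetical endpoint name and '>' as the namespace separator:

string op_name = SeparateNamespaces("MyNamespace>MyOp");
// op_name == "MyNamespace_MyOp", a valid C++ identifier for the generated
// class; ToCamelCase's new '>' branch likewise restarts capitalization
// after a separator.
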
// Returns a <string, bool> pair. The string is the C++ type name to be used for
// attr_type when defining an object of that type. The bool is a flag to
// indicate whether to treat the type as const when accepting the C++ type as an
@ -549,7 +566,7 @@ struct OpInfo {
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases)
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
op_name = api_def.endpoint(0).name();
op_name = SeparateNamespaces(api_def.endpoint(0).name());
InferOpAttributes(graph_op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");

View File

@ -127,8 +127,29 @@ cc_library(
)
tf_cc_test(
name = "loader_test",
srcs = ["loader_test.cc"],
name = "saved_model_bundle_test",
srcs = ["saved_model_bundle_test.cc"],
data = [
":saved_model_half_plus_two",
],
linkstatic = 1,
deps = [
":constants",
":loader",
":signature_constants",
":tag_constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "saved_model_bundle_lite_test",
srcs = ["saved_model_bundle_lite_test.cc"],
data = [
":saved_model_half_plus_two",
],

View File

@ -299,6 +299,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
} // namespace
SavedModelBundleInterface::~SavedModelBundleInterface() {}
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
@ -323,6 +325,133 @@ Status LoadSavedModel(const SessionOptions& session_options,
return status;
}
namespace {
// Session wrapper that prevents calls to Session::Create(), Session::Extend(),
// and the deprecated partial-run methods.
//
// Limiting the available methods on a returned Session gives us the option
// to replace the Session with a cut-down implementation, without breaking any
// users.
class LiteSessionWrapper : public Session {
public:
explicit LiteSessionWrapper(std::unique_ptr<Session> wrapped)
: wrapped_(std::move(wrapped)) {}
Status Create(const GraphDef& graph) override {
return errors::Unimplemented("Session::Create()");
}
Status Create(GraphDef&& graph) override {
return errors::Unimplemented("Session::Create()");
}
Status Extend(const GraphDef& graph) override {
return errors::Unimplemented("Session::Extend()");
}
Status Extend(GraphDef&& graph) override {
return errors::Unimplemented("Session::Extend()");
}
Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) override {
return wrapped_->Run(inputs, output_tensor_names, target_node_names,
outputs);
}
Status Create(const RunOptions& run_options, const GraphDef& graph) override {
return errors::Unimplemented("Session::Create()");
}
Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
return errors::Unimplemented("Session::Extend()");
}
Status Create(const RunOptions& run_options, GraphDef&& graph) override {
return errors::Unimplemented("Session::Create()");
}
Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
return errors::Unimplemented("Session::Extend()");
}
Status Close(const RunOptions& run_options) override {
return wrapped_->Close(run_options);
}
Status Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
return wrapped_->Run(run_options, inputs, output_tensor_names,
target_node_names, outputs, run_metadata);
}
Status PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) override {
return errors::Unimplemented("Session::PRunSetup()");
}
Status PRun(const string& handle,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override {
return errors::Unimplemented("Session::PRun()");
}
Status ListDevices(std::vector<DeviceAttributes>* response) override {
return wrapped_->ListDevices(response);
}
Status Close() override { return wrapped_->Close(); }
Status MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) override {
return wrapped_->MakeCallable(callable_options, out_handle);
}
Status RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata) override {
return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
run_metadata);
}
Status RunCallable(
CallableHandle handle, const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) override {
return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
run_metadata, threadpool_options);
}
Status ReleaseCallable(CallableHandle handle) override {
return wrapped_->ReleaseCallable(handle);
}
private:
const std::unique_ptr<Session> wrapped_;
};
} // namespace
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundleLite* const bundle) {
SavedModelBundle legacy_bundle;
SessionOptions rewritten_options(session_options);
rewritten_options.config.mutable_experimental()
->set_optimize_for_static_graph(true);
// TODO(mrry): Consider specializing the session creation to reduce peak
// RAM consumption by using `Session::Create(GraphDef&&)`.
TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
tags, &legacy_bundle));
*bundle = SavedModelBundleLite(
absl::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
return Status::OK();
}
bool MaybeSavedModelDirectory(const string& export_dir) {
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);

View File

@ -27,31 +27,96 @@ limitations under the License.
namespace tensorflow {
/// SavedModel representation once the SavedModel is loaded from storage.
struct SavedModelBundle {
std::unique_ptr<Session> session;
MetaGraphDef meta_graph_def;
/// Represents a SavedModel that is loaded from storage.
class SavedModelBundleInterface {
public:
virtual ~SavedModelBundleInterface();
/// Returns the TensorFlow Session that can be used to interact with the
/// SavedModel.
virtual Session* GetSession() const = 0;
/// Returns a map from signature name to SignatureDef for all signatures in
/// the SavedModel.
virtual const protobuf::Map<string, SignatureDef>& GetSignatures() const = 0;
};
/// SavedModel representation once the SavedModel is loaded from storage.
///
/// NOTE: Prefer to use SavedModelBundleLite in new code, as it consumes less
/// RAM.
struct SavedModelBundle : public SavedModelBundleInterface {
/// A TensorFlow Session does not Close itself on destruction. To avoid
/// resource leaks, we explicitly call Close on Sessions that we create.
~SavedModelBundle() {
~SavedModelBundle() override {
if (session) {
session->Close().IgnoreError();
}
}
SavedModelBundle() = default;
Session* GetSession() const override { return session.get(); }
const protobuf::Map<string, SignatureDef>& GetSignatures() const override {
return meta_graph_def.signature_def();
}
std::unique_ptr<Session> session;
MetaGraphDef meta_graph_def;
};
/// Loads a SavedModel from the specified export directory. The meta graph def
// A version of SavedModelBundle that avoids storing a potentially large
// MetaGraphDef. Prefer to use SavedModelBundleLite in new code.
class SavedModelBundleLite : public SavedModelBundleInterface {
public:
SavedModelBundleLite() = default;
SavedModelBundleLite& operator=(SavedModelBundleLite&& other) = default;
SavedModelBundleLite(std::unique_ptr<Session> session,
protobuf::Map<string, SignatureDef> signatures)
: session_(std::move(session)), signatures_(std::move(signatures)) {}
/// A TensorFlow Session does not Close itself on destruction. To avoid
/// resource leaks, we explicitly call Close on Sessions that we create.
~SavedModelBundleLite() override {
if (session_) {
session_->Close().IgnoreError();
}
}
Session* GetSession() const override { return session_.get(); }
const protobuf::Map<string, SignatureDef>& GetSignatures() const override {
return signatures_;
}
private:
std::unique_ptr<Session> session_;
protobuf::Map<string, SignatureDef> signatures_;
};
/// Loads a SavedModel from the specified export directory. The MetaGraphDef
/// to be loaded is identified by the supplied tags, corresponding exactly to
/// the set of tags used at SavedModel build time. Returns a SavedModel bundle
/// with a session and the requested meta graph def, if found.
/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in
/// *bundle with a session and the requested MetaGraphDef, if found.
///
/// NOTE: Prefer the overload that takes a SavedModelBundleLite* in new code.
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle);
/// Loads a SavedModel from the specified export directory. The MetaGraphDef
/// to be loaded is identified by the supplied tags, corresponding exactly to
/// the set of tags used at SavedModel build time. Stores a SavedModel bundle
/// in *bundle with a session created from the requested MetaGraphDef if found.
///
/// This overload creates a SavedModelBundleLite, which consumes less RAM than
/// an equivalent SavedModelBundle.
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundleLite* const bundle);
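
A minimal usage sketch for this Lite overload; the export path and the signature/tensor keys are placeholders:

tensorflow::SavedModelBundleLite bundle;
TF_RETURN_IF_ERROR(tensorflow::LoadSavedModel(
    tensorflow::SessionOptions(), tensorflow::RunOptions(),
    "/path/to/export_dir", {tensorflow::kSavedModelTagServe}, &bundle));

// Signatures come from the bundle itself; no MetaGraphDef is retained.
const auto& sig = bundle.GetSignatures().at("serving_default");
const std::string input_name = sig.inputs().at("inputs").name();
const std::string output_name = sig.outputs().at("outputs").name();

tensorflow::Tensor input(tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
std::vector<tensorflow::Tensor> outputs;
TF_RETURN_IF_ERROR(bundle.GetSession()->Run(
    {{input_name, input}}, {output_name}, {}, &outputs));
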
/// Checks whether the provided directory could contain a SavedModel. Note that
/// the method does not load any data by itself. If the method returns `false`,
/// the export directory definitely does not contain a SavedModel. If the method

View File

@ -0,0 +1,244 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataMainOp[] =
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
constexpr char kTestDataInitOpV2[] =
"cc/saved_model/testdata/half_plus_two_v2/00000123";
class LoaderTest : public ::testing::Test {
protected:
LoaderTest() {}
string MakeSerializedExample(float x) {
tensorflow::Example example;
auto* feature_map = example.mutable_features()->mutable_feature();
(*feature_map)["x"].mutable_float_list()->add_value(x);
return example.SerializeAsString();
}
void ValidateAssets(const string& export_dir,
const SavedModelBundleLite& bundle) {
const string asset_directory =
io::JoinPath(export_dir, kSavedModelAssetsDirectory);
const string asset_filename = "foo.txt";
const string asset_filepath = io::JoinPath(asset_directory, asset_filename);
TF_EXPECT_OK(Env::Default()->FileExists(asset_filepath));
std::vector<Tensor> path_outputs;
TF_ASSERT_OK(
bundle.GetSession()->Run({}, {"filename_tensor:0"}, {}, &path_outputs));
ASSERT_EQ(1, path_outputs.size());
test::ExpectTensorEqual<tstring>(
test::AsTensor<tstring>({"foo.txt"}, TensorShape({})), path_outputs[0]);
}
void CheckSavedModelBundle(const string& export_dir,
const SavedModelBundleLite& bundle) {
ValidateAssets(export_dir, bundle);
// Retrieve the regression signature from the bundle.
const auto& signature_def = bundle.GetSignatures().at("regress_x_to_y");
const string input_name = signature_def.inputs().at(kRegressInputs).name();
const string output_name =
signature_def.outputs().at(kRegressOutputs).name();
std::vector<tstring> serialized_examples;
for (float x : {0, 1, 2, 3}) {
serialized_examples.push_back(MakeSerializedExample(x));
}
// Validate the half plus two behavior.
Tensor input =
test::AsTensor<tstring>(serialized_examples, TensorShape({4}));
std::vector<Tensor> outputs;
TF_ASSERT_OK(bundle.GetSession()->Run({{input_name, input}}, {output_name},
{}, &outputs));
ASSERT_EQ(outputs.size(), 1);
test::ExpectTensorEqual<float>(
outputs[0],
test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
}
};
// Test for resource leaks related to TensorFlow session closing requirements
// when loading and unloading large numbers of SavedModelBundles.
// TODO(sukritiramesh): Increase run iterations and move outside of the test
// suite.
TEST_F(LoaderTest, ResourceLeakTest) {
SavedModelBundleLite bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
for (int i = 0; i < 100; ++i) {
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
}
TEST_F(LoaderTest, TagMatch) {
SavedModelBundleLite bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, NoTagMatch) {
SavedModelBundleLite bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{"missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(absl::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message();
}
TEST_F(LoaderTest, NoTagMatchMultiple) {
SavedModelBundleLite bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe, "missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(absl::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: "))
<< st.error_message();
}
TEST_F(LoaderTest, SessionCreationFailure) {
SavedModelBundleLite bundle;
// Use invalid SessionOptions to cause session creation to fail. Default
// options work, so provide an invalid value for the target field.
SessionOptions session_options;
constexpr char kInvalidTarget[] = "invalid target";
session_options.target = kInvalidTarget;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
<< st.error_message();
}
TEST_F(LoaderTest, PbtxtFormat) {
SavedModelBundleLite bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, MainOpFormat) {
SavedModelBundleLite bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, InvalidExportPath) {
SavedModelBundleLite bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
}
TEST_F(LoaderTest, MaybeSavedModelDirectory) {
// Valid SavedModel directory.
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
EXPECT_TRUE(MaybeSavedModelDirectory(export_dir));
// Directory that does not exist.
const string missing_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
EXPECT_FALSE(MaybeSavedModelDirectory(missing_export_dir));
// Directory that exists but is an invalid SavedModel location.
const string invalid_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model");
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}
TEST_F(LoaderTest, SavedModelInitOpV2Format) {
SavedModelBundleLite bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
} // namespace
} // namespace tensorflow

View File

@ -71,8 +71,7 @@ class LoaderTest : public ::testing::Test {
const SavedModelBundle& bundle) {
ValidateAssets(export_dir, bundle);
// Retrieve the regression signature from meta graph def.
const auto signature_def_map = bundle.meta_graph_def.signature_def();
const auto signature_def = signature_def_map.at("regress_x_to_y");
const auto& signature_def = bundle.GetSignatures().at("regress_x_to_y");
const string input_name = signature_def.inputs().at(kRegressInputs).name();
const string output_name =

View File

@ -1,5 +1,5 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
@ -38,7 +38,7 @@ cc_library(
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/compiler/plugin",
] + if_cuda([
] + if_cuda_or_rocm([
":xla_gpu_device",
":xla_gpu_jit",
]),
@ -61,7 +61,7 @@ cc_library(
cc_library(
name = "xla_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda([
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",

View File

@ -130,17 +130,24 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
return uncompilable_nodes;
}
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
bool RecursiveCompilabilityChecker::HasXLAKernel(
const Node& node, string* uncompilable_reason) const {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
if (node.type_string() == "SymbolicGradient") {
*uncompilable_reason =
"SymbolicGradient should be handled by IsCompilableCall().";
return false;
}
if (node.type_string() == "Const") {
// Skip Const ops with type DT_STRING: XLA doesn't support it, even though the
// registered Const KernelDef claims it does (to support no-op Assert for
// tfcompile).
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
*uncompilable_reason =
"Const op with type DT_STRING is not supported by XLA.";
return false;
}
}
@ -150,10 +157,16 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
*uncompilable_reason = "Identity with unsafe cast.";
return false;
}
return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
if (!s.ok()) {
*uncompilable_reason = s.error_message();
return false;
}
return true;
}
// Tests whether 'if_node' is compilable. Every operator in the then_branch and
@ -336,16 +349,17 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
return false;
}
string uncompilable_reason;
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
encapsulating_function, uncompilable_nodes)) {
LogNotCompilable(node, "unsupported function");
return false;
}
} else if (!HasXLAKernel(node)) {
absl::string_view uncompilable_reason = "unsupported op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
encapsulating_function, uncompilable_nodes);
} else if (!HasXLAKernel(node, &uncompilable_reason)) {
MaybeMarkUncompilableNode(
absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
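The change above threads a human-readable reason out of each failing check instead of returning a bare false. A self-contained sketch of that out-parameter pattern; the names are invented for illustration and are not TensorFlow API:
#include <iostream>
#include <string>
// Illustrative stand-in for a HasXLAKernel-style query: return false and
// populate *reason instead of failing silently.
bool HasKernelFor(const std::string& op_type, std::string* reason) {
  if (op_type == "SymbolicGradient") {
    *reason = "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }
  if (op_type != "Add" && op_type != "MatMul") {
    *reason = "no kernel registered for " + op_type;
    return false;
  }
  return true;
}
int main() {
  std::string reason;
  if (!HasKernelFor("SomeCustomOp", &reason)) {
    // Callers prepend context, as StrCat("unsupported op: ", reason) does.
    std::cout << "unsupported op: " << reason << "\n";
  }
  return 0;
}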

View File

@ -247,7 +247,8 @@ class RecursiveCompilabilityChecker {
absl::c_any_of(node.output_types(), is_variant);
}
bool HasXLAKernel(const Node& node) const;
bool HasXLAKernel(const Node& node,
string* uncompilable_reason) const;
static void MaybeMarkUncompilableNode(
const absl::string_view reason,

View File

@ -125,7 +125,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
const auto& uncompilable_nodes_inside_function = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_nodes_inside_function.size());
const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0);
EXPECT_EQ("unsupported op", uncompilable_node_info.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason,
"unsupported op"));
ASSERT_EQ(1, uncompilable_node_info.stack_trace.size());
ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name);
}
@ -167,7 +168,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
EXPECT_EQ("D", node_stack.at(0).name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
EXPECT_TRUE(
absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
@ -246,7 +248,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
stacktrace_second_node_info.function_name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
EXPECT_TRUE(
absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
@ -322,7 +325,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
stacktrace_second_node_info.function_name);
EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
"unsupported op"));
NameAttrList function_two;
function_two.set_name(kUncompilableFunctionTwoName);
@ -345,7 +349,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
node_two_stacktrace_second_node.function_name);
EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
EXPECT_EQ("unsupported op", uncompilable_node_two.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
"unsupported op"));
}
} // namespace

View File

@ -45,7 +45,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:lhlo",

View File

@ -0,0 +1,3 @@
# TensorFlow MLIR
These are the docs for: https://www.tensorflow.org/mlir

View File

@ -0,0 +1,24 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
path: /resources
is_default: true
menu:
- include: /resources/_menu_toc.yaml
lower_tabs:
# Subsite tabs
other:
- name: Guide & Tutorials
contents:
- title: Overview
path: /mlir/overview
- heading: Dialects
- title: TensorFlow
path: /mlir/tf_ops
- title: TensorFlow Lite
path: /mlir/tfl_ops
- include: /_upper_tabs_right.yaml

View File

@ -0,0 +1,48 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
custom_css_path: /site-assets/css/style.css
rows:
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
items:
- description: >
The MLIR project defines a common intermediate representation (IR) that
unifies the infrastructure required to execute high performance machine
learning models in TensorFlow and similar ML frameworks. This project
will include the application of HPC techniques, along with integration of
search algorithms like reinforcement learning. MLIR aims to reduce the
cost to bring up new hardware, and improve usability for existing
TensorFlow users.
- code_block: |
<pre class="prettyprint">
// Syntactically similar to LLVM:
func @testFunction(%arg0: i32) -> i32 {
%x = call @thingToCall(%arg0) : (i32) -> i32
br ^bb1
^bb1:
%y = addi %x, %x : i32
return %y : i32
}
</pre>
- classname: devsite-landing-row-cards
items:
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
youtube_id: qzljG6DKgic
buttons:
- label: Watch the video
path: https://www.youtube.com/watch?v=qzljG6DKgic
- heading: "A new intermediate representation and compiler framework"
image_path: /resources/images/tf-logo-card-16x9.png
path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
buttons:
- label: Read on TensorFlow blog
path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
- heading: TensorFlow MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/tensorflow/mlir
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/mlir

View File

@ -0,0 +1,11 @@
name: TensorFlow MLIR
breadcrumb_name: MLIR
home_url: /mlir/
parent_project_metadata_path: /_project.yaml
description: >
MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
use_site_branding: true
hide_from_products_list: true
content_license: cc-apache
buganizer_id: 443907
include: /_project_included.yaml

File diff suppressed because one or more lines are too long

[Image added: 148 KiB]

View File

@ -0,0 +1,5 @@
# MLIR overview
## Overview
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
load(
"@local_config_mlir//:tblgen.bzl",
"gentbl",
@ -185,6 +185,41 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "lstm_utils",
srcs = [
"utils/lstm_utils.cc",
],
hdrs = [
"utils/lstm_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
tf_cc_test(
name = "lstm_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
@ -198,9 +233,11 @@ cc_library(
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
],
hdrs = [
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
deps = [
":common",
@ -249,6 +286,7 @@ cc_library(
name = "tensorflow_lite_quantize",
srcs = [
"transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc",
"transforms/post_quantize.cc",
"transforms/prepare_quantize.cc",
"transforms/quantize.cc",
@ -521,7 +559,7 @@ cc_library(
":tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",

View File

@ -99,7 +99,10 @@ using xla::StatusOr;
template <typename T>
using BufferOffset = flatbuffers::Offset<T>;
using CustomOptionsOffset = BufferOffset<flatbuffers::Vector<uint8_t>>;
template <typename T>
using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
namespace error = tensorflow::error;
namespace tfl = mlir::TFL;
@ -415,6 +418,15 @@ class Translator {
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
// Builds Metadata with the given `name` and buffer `content`.
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
StringRef content);
// Encodes the `tfl.metadata` dictionary attribute of the module to the
// metadata section in the final model.
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector();
// Uses the tf.entry_function attribute (if set) to initialize the op to name
// mapping.
void InitializeNamesFromAttribute(FuncOp fn);
@ -977,6 +989,36 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
/*name=*/builder_.CreateString(fn.getName().str()));
}
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
StringRef content) {
auto buffer_index = buffers_.size();
auto buffer_data = builder_.CreateVector(
reinterpret_cast<const uint8_t*>(content.data()), content.size());
buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
}
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
Translator::CreateMetadataVector() {
auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
if (!dict_attr) return VectorBufferOffset<BufferOffset<tflite::Metadata>>();
std::vector<BufferOffset<tflite::Metadata>> metadata;
for (const auto& named_attr : dict_attr) {
StringRef name = named_attr.first;
mlir::Attribute attr = named_attr.second;
if (auto content = attr.dyn_cast<StringAttr>()) {
metadata.push_back(BuildMetadata(name, content.getValue()));
} else {
module_.emitError(
"all values in tfl.metadata's dictionary key-value pairs should be "
"string attributes");
return llvm::None;
}
}
return builder_.CreateVector(metadata);
}
Optional<std::string> Translator::Translate(ModuleOp module,
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
@ -1024,12 +1066,17 @@ Optional<std::string> Translator::TranslateInternal() {
} else {
model_description = "MLIR Converted.";
}
// Build the model and finish the model building process.
auto description = builder_.CreateString(model_description.data());
VectorBufferOffset<int32_t> metadata_buffer = 0; // Deprecated
auto metadata = CreateMetadataVector();
if (!metadata) return llvm::None;
auto model = tflite::CreateModel(
builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
builder_.CreateVector(subgraphs), description,
builder_.CreateVector(buffers_));
builder_.CreateVector(buffers_), metadata_buffer, *metadata);
tflite::FinishModelBuffer(builder_, model);
// Return serialized string for the built FlatBuffer.
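BuildMetadata above illustrates the buffer-plus-index scheme: the metadata payload is appended to the model's buffer table, and the Metadata entry stores only a name and the buffer index. A trimmed stand-alone sketch of the same scheme, assuming the FlatBuffers library and the generated TFLite schema header; the helper name is invented:
#include <cstdint>
#include <string>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Sketch: append `content` as a new buffer and return a Metadata entry
// pointing at it, mirroring Translator::BuildMetadata above.
flatbuffers::Offset<tflite::Metadata> AppendMetadata(
    flatbuffers::FlatBufferBuilder* builder,
    std::vector<flatbuffers::Offset<tflite::Buffer>>* buffers,
    const std::string& name, const std::string& content) {
  const auto buffer_index = static_cast<uint32_t>(buffers->size());
  auto data = builder->CreateVector(
      reinterpret_cast<const uint8_t*>(content.data()), content.size());
  buffers->push_back(tflite::CreateBuffer(*builder, data));
  return tflite::CreateMetadataDirect(*builder, name.c_str(), buffer_index);
}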

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include <algorithm>
#include <cstdint>
#include "llvm/ADT/APFloat.h"
@ -30,6 +31,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -1167,6 +1169,54 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
return DenseElementsAttr::get(result_type, new_values);
}
static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x()->getType().cast<ShapedType>();
auto perm_type = op.perm()->getType().cast<ShapedType>();
auto output_type = op.y()->getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(
"perm tensor elements size is not equal to input tensor rank");
}
}
DenseIntElementsAttr perm;
if (!matchPattern(op.perm(), m_Constant(&perm))) {
return success();
}
int index = 0;
llvm::SmallVector<int64_t, 4> axes;
for (auto axis_int : perm.getValues<APInt>()) {
const int64_t axis = axis_int.getSExtValue();
if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
return op.emitOpError(
llvm::formatv("perm[{0}] must be in [0, rank)", index));
}
if (std::count(axes.begin(), axes.end(), axis) > 0) {
return op.emitOpError(
llvm::formatv("perm[{0}] cannot have duplicated axis", index));
}
axes.push_back(axis);
index++;
}
if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
llvm::SmallVector<int64_t, 4> transposed_shape;
for (int64_t axis : axes) {
transposed_shape.push_back(input_type.getDimSize(axis));
}
auto expected_output_type =
RankedTensorType::get(transposed_shape, input_type.getElementType());
if (output_type != expected_output_type) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}
}
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
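The static-shape check at the end of Verify encodes the standard transpose shape rule, stated as a formula:

$$ y_{\text{shape}}[i] \;=\; x_{\text{shape}}[\mathrm{perm}[i]], \qquad 0 \le i < \mathrm{rank}(x), $$

with perm a duplicate-free permutation of [0, rank(x)); the range and duplicate checks above enforce exactly that.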

View File

@ -132,15 +132,35 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//
class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
// TODO: Some of these could be generalized and/or moved to a more general
// location.
// Returns true if the n-th operand has unknown rank or has rank m.
class TFL_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
// CPred version of TFL_OperandHasRank.
class TFL_OperandHasRankPred<int n, int m> :
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>;
// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
# dim>]>;
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;
// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
@ -155,6 +175,32 @@ class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[0]">>;
// True if x_shape[dim] == y_shape[dim].
class TFL_DimOfOperandEqualsDimOfOperandPred<int x, int y, int dim> : And<[
TFL_OperandIsRankedAndHasDimPred<x, dim>,
TFL_OperandIsRankedAndHasDimPred<y, dim>,
CPred<"$_op.getOperand(" # x #
")->getType().cast<ShapedType>().getShape()[" # dim # "] == "
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[" # dim # "]">]>;
// Select operands must satisfy one of the following constraints:
// All inputs are unranked/scalars
// OR
// All inputs are ranked AND have equal dim[0] AND X & Y have same rank.
def SelectShapeConstraints :
PredOpTrait<"Select operands meet shape criteria",
Or<[
And<[
TFL_OperandHasRankPred<0, 0>,
TFL_OperandHasRankPred<1, 0>,
TFL_OperandHasRankPred<2, 0>]>,
And<[
TFL_DimOfOperandEqualsDimOfOperandPred<0, 1, 0>,
TFL_DimOfOperandEqualsDimOfOperandPred<0, 2, 0>,
CPred<"$_op.getOperand(1)->getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(2)->getType().cast<ShapedType>().getRank()">]>]>>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
TCOpResIsShapedTypePred<i, j>,
@ -315,7 +361,7 @@ def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
// TODO(haoliang): Implement legalization pass after pattern rewrite generator
// supports variadic inputs.
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "add_n operator";
let description = [{
@ -323,11 +369,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I32]>>:$inputs
Variadic<TensorOf<[F32, I32, QI16, QUI16]>>:$inputs
);
let results = (outs
TensorOf<[F32, I32]>:$sum
TensorOf<[F32, I32, QI16, QUI16]>:$sum
);
}
@ -680,6 +726,117 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
let hasOptions = 0;
}
// These ops are named NonMaxSuppressionV4 & NonMaxSuppressionV5 to be
// consistent with TensorFlow's naming. They are NOT 'versions' of NMS in the
// sense that one is an incremental change over the other.
// In reality NonMaxSuppressionV5 implements Soft Non Max Suppression and
// NonMaxSuppressionV4 performs hard NMS.
def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [
NoSideEffect,
// Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
TFL_OperandHasRank<0, 2>,
PredOpTrait<"boxes should have dim[1] == 4",
TFL_OperandDimEquals<0, 1, 4>>,
// Operand 1 (scores) should be a 1-dim tensor
TFL_OperandHasRank<1, 1>,
// Other operands are scalar params.
TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
TFL_OperandHasRank<4, 0>]> {
let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
}];
let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
selected_indices = tf.image.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold, score_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
}];
let arguments = (ins
TFL_FpTensor:$boxes,
TFL_FpTensor:$scores,
I32Tensor:$max_output_size,
TFL_FpTensor:$iou_threshold,
TFL_FpTensor:$score_threshold
);
let results = (outs
I32Tensor:$selected_indices,
I32Tensor:$valid_outputs
);
}
def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [
NoSideEffect,
// Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
TFL_OperandHasRank<0, 2>,
PredOpTrait<"boxes should have dim[1] == 4",
TFL_OperandDimEquals<0, 1, 4>>,
// Operand 1 (scores) should be a 1-dim tensor
TFL_OperandHasRank<1, 1>,
// Other operands are scalar params.
TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> {
let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
}];
let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
selected_indices = tf.image.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold, score_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f.
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
}];
let arguments = (ins
TFL_FpTensor:$boxes,
TFL_FpTensor:$scores,
I32Tensor:$max_output_size,
TFL_FpTensor:$iou_threshold,
TFL_FpTensor:$score_threshold,
TFL_FpTensor:$soft_nms_sigma
);
let results = (outs
I32Tensor:$selected_indices,
TFL_FpTensor:$selected_scores,
I32Tensor:$valid_outputs
);
}
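For reference, the Soft-NMS update the description cites (Bodla et al., arXiv:1704.04503) decays the scores of overlapping boxes with a Gaussian penalty instead of pruning them outright. With M the last selected box, b_i a remaining candidate, and sigma the soft_nms_sigma operand:

$$ s_i \;\leftarrow\; s_i \cdot \exp\!\left(-\frac{\mathrm{iou}(M, b_i)^2}{\sigma}\right), \qquad \sigma > 0, $$

so hard NMS (the V4 behavior) corresponds to dropping any b_i whose IOU with M exceeds iou_threshold, while sigma > 0 lets such boxes survive with reduced scores.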
def TFL_NotEqualOp : TFL_Op<"not_equal", [
Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
let summary = "Not_equal operator";
@ -987,11 +1144,11 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect]> {
}];
let arguments = (ins
TensorOf<[F32, QUI8, QI8, I8]>:$input,
TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input,
TFL_AFAttr:$fused_activation_function
);
let results = (outs TensorOf<[F32, QUI8, QI8, I8]>:$output);
let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output);
let hasOptions = 1;
@ -1100,9 +1257,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
Computes element-wise Sigmoid of input
}];
let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x);
let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);
let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y);
let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
}
def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
@ -1441,7 +1598,7 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let hasOptions = 0b1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Packs a list of tensors along a dimension into one tensor";
let description = [{
@ -1472,14 +1629,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I8, I16, I32, I64]>>:$values,
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
I32Attr:$values_count,
I32Attr:$axis
);
let results = (outs
TensorOf<[F32, I8, I16, I32, I64]>:$output
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -1777,8 +1934,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
}
def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
// TODO(jpienaar): This is too restrictive; rank 1 input is also allowed.
SameOperandsAndResultShape,
SelectShapeConstraints,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TCresVTEtIsSameAsOp<0, 1>>]> {
@ -1836,7 +1992,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
let summary = "Softmax operator";
let description = [{
Computes element-wise softmax activiations with the following formula
Computes element-wise softmax activations with the following formula
exp(input * beta) / tf.reduce_sum(exp(input * beta), dim)
}];
@ -1942,9 +2098,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
Computes element-wise Hyperbolic tangent of input
}];
let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$x);
let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$y);
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@ -1999,8 +2155,6 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
let hasOptions = 1;
}
// TODO: Verify result shape a permutation of the first input shape's
// dimensions.
def TFL_TransposeOp : TFL_Op<"transpose",
[NoSideEffect,
TFL_OperandHasRank<1,1>,
@ -2025,6 +2179,8 @@ def TFL_TransposeOp : TFL_Op<"transpose",
AnyTensor:$y
);
let verifier = [{ return Verify(*this); }];
let hasFolder = 1;
}
@ -2342,7 +2498,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
let hasOptions = 1;
}
def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
def TFL_CastOp : TFL_Op<"cast", [
NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
let summary = "Cast operator";
let description = [{
@ -2629,6 +2786,10 @@ Ba et al. “Layer Normalization”
let results = (outs AnyTensor:$output);
// TODO(fengliuai): customize printer and parser to not display
// empty region.
let regions = (region AnyRegion:$internal);
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];

View File

@ -549,32 +549,42 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
void QuantizationDriver::PreprocessConstantOps() {
fn_.walk([&](ConstantOp cst) {
// Non-float tensors are neither weights or require quantization.
if (!cst.getType().cast<ShapedType>().getElementType().isa<FloatType>()) {
return;
}
// Non-float tensors are not weights and do not require quantization.
auto type = cst.getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return;
Value *value = cst.getResult();
SmallVector<std::pair<Operation *, int>, 4> bias_users;
bool used_as_weight = false;
for (auto &use : value->getUses()) {
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
int operand_num = use.getOperandNumber();
// The user doesn't use this value as a bias operand nor require same
// scale.
// If the user doesn't use this value as a bias operand and doesn't require
// the same scale, then this constant is considered to be a weight.
if (biases.find(operand_num) == biases.end() &&
!spec->requires_same_scale) {
weights_.insert(cst);
used_as_weight = true;
} else {
bias_users.push_back({user, operand_num});
}
}
builder_.setInsertionPoint(cst);
for (int i = 1; i < bias_users.size(); ++i) {
// If the constant is used as a weight, it is duplicated for each bias user,
// so it isn't shared between the weight usage and the bias usages. Otherwise,
// one bias user can keep the original constant and the rest use duplicates,
// so we pop one bias user from the list.
if (used_as_weight) {
weights_.insert(cst);
} else {
bias_users.pop_back();
builder_.setInsertionPoint(cst);
}
for (auto bias_user : bias_users) {
auto copied = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
bias_users[i].first->setOperand(bias_users[i].second, copied.getResult());
bias_user.first->setOperand(bias_user.second, copied.getResult());
}
});
}
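A self-contained sketch of the duplication policy the rewritten loop implements; the types are simplified stand-ins for ConstantOp and its uses, and the function name is invented:
#include <iostream>
#include <string>
#include <vector>
// Simplified model of PreprocessConstantOps' decision: if a float constant
// is also used as a weight, every bias user gets its own copy; otherwise
// one bias user keeps the original and only the rest get copies.
int NumCopiesNeeded(bool used_as_weight, std::vector<std::string> bias_users) {
  if (!used_as_weight && !bias_users.empty()) bias_users.pop_back();
  return static_cast<int>(bias_users.size());  // one duplicate per entry
}
int main() {
  std::cout << NumCopiesNeeded(true, {"conv_bias", "fc_bias"}) << "\n";   // 2
  std::cout << NumCopiesNeeded(false, {"conv_bias", "fc_bias"}) << "\n";  // 1
  return 0;
}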

View File

@ -32,7 +32,7 @@ static Type GetQuantizedType(Builder builder, Type input_type, double min,
double max, int storage_type_width,
bool narrow_range, bool is_signed) {
auto converter =
quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
quant::ExpressedToQuantizedConverter::forInputType(input_type);
quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,

View File

@ -13,6 +13,11 @@ func @extractSimpleOphint() {
return
}
// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation"}
// -----
// CHECK-LABEL: extractPackedInputOphint
func @extractPackedInputOphint() {
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
@ -30,6 +35,11 @@ func @extractPackedInputOphint() {
return
}
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack"}
// -----
// CHECK-LABEL: extractFirstInputOphint
func @extractFirstInputOphint() {
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -46,6 +56,11 @@ func @extractFirstInputOphint() {
return
}
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_first"}
// -----
// CHECK-LABEL: extractLastInputOphint
func @extractLastInputOphint() {
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -62,6 +77,11 @@ func @extractLastInputOphint() {
return
}
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_last"}
// -----
// CHECK-LABEL: extractPackOneInputOphint
func @extractPackOneInputOphint() {
// CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0) : (tensor<1x16x1xf32>) -> tensor<1x1x16x1xf32>
@ -75,13 +95,16 @@ func @extractPackOneInputOphint() {
return
}
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_pack_input_one"}
// -----
// CHECK-LABEL: extractStackInputOutputOphint
func @extractStackInputOutputOphint() {
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -98,11 +121,14 @@ func @extractStackInputOutputOphint() {
return
}
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack_input_output"}
// -----
// CHECK-LABEL: extractMultipleInputsOutputsOphint
func @extractMultipleInputsOutputsOphint() {
// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: %[[MULTI_INPUT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -119,21 +145,33 @@ func @extractMultipleInputsOutputsOphint() {
return
}
// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation"}
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_stack"}
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_first"}
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_last"}
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_pack_input_one"}
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_stack_input_output"}
// CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: attributes {_tflite_function_name = "cool_activation_multiple_input_output"}
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32], _tflite_function_name = "cool_activation_multiple_input_output"}
// -----
// CHECK-LABEL: inputsAfterOutputs
func @inputsAfterOutputs() {
// CHECK: %[[PLACE_HOLDER:[0-9]*]] = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
// CHECK: %[[INPUT_PROCESS:[0-9]*]] = "tf.Sigmoid"(%[[PLACE_HOLDER]]) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @d6266124d2dd11e9b52cdc4a3e957995(%0, %1, %[[INPUT_PROCESS]]) : (tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
%0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<f32>) -> tensor<f32>
%2 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
%3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%4 = "tf.Add"(%3, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "Add"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
%5 = "tf.Identity"(%4) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%6 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
%7 = "tf.Sigmoid"(%6) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%8 = "tf.Identity"(%7) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 2 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-2-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%9 = "tf.Add"(%5, %8) {T = "tfdtype$DT_FLOAT", device = "", name = "Add_1"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%10 = "tf.Identity"(%9) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return
}
// CHECK: func @d6266124d2dd11e9b52cdc4a3e957995(tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32], _tflite_function_name = "CustomOp"}
// -----

View File

@ -50,7 +50,7 @@ func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor
func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i32 {
%0 = "tf.Squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = constant dense<[2, 5]> : tensor<2xi32>
%2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
@ -119,8 +119,8 @@ func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
}
func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%arg1 = constant dense<-0.1> : tensor<f32>
%arg2 = constant dense<0.2> : tensor<f32>
%arg1 = "tf.Const"() { value = dense<-0.1> : tensor<f32> } : () -> tensor<f32>
%arg2 = "tf.Const"() { value = dense<0.2> : tensor<f32> } : () -> tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
@ -153,6 +153,14 @@ func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}
func @placeholder_int(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<i32>) -> tensor<i32>
return %0: tensor<i32>
// CHECK-LABEL: @placeholder_int
// CHECK-NEXT: "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
}
func @placeholder_min(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
@ -409,7 +417,7 @@ func @gatherNdHigherRankIndices(%arg0 : tensor<4x3x2xf32>, %arg1 : tensor<2x2xi3
}
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
return %1 : tensor<1x3x5x20xf32>
@ -418,7 +426,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
}
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>
@ -427,7 +435,7 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x
}
func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>
@ -509,6 +517,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) ->
// CHECK: return %0 : tensor<8xf32>
}
func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>
// CHECK-LABEL: select_multidim
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8x3xf32>
}
func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
@ -518,6 +535,15 @@ func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>)
// CHECK: return %0 : tensor<8xf32>
}
func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>
// CHECK-LABEL: select_v2_multidim
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8x3xf32>
}
func @sin(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
@ -536,7 +562,7 @@ func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?
}
func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor<i32>) -> (tensor<2xf32>, tensor<2xi32>)
return %1#0, %1#1: tensor<2xf32>, tensor<2xi32>
@ -546,7 +572,7 @@ func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
}
func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>
@ -556,7 +582,7 @@ func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
}
func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor<i32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>)
return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>
@ -566,7 +592,7 @@ func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2
}
func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>
@ -671,7 +697,7 @@ func @pow(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x1x1xf32>) -> tensor<2x1x3xf3
func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>):
%cst = constant dense<[1, 2]> : tensor<2xi32>
%cst = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> } : () -> tensor<2xi32>
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32>
return %0 : tensor<2x6xf32>
@ -682,7 +708,7 @@ func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
%cst = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
@ -858,8 +884,8 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
}
func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
%0 = constant dense<[1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
// CHECK-LABEL: concat2Tensors
@ -867,8 +893,8 @@ func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi
}
func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat3Tensors
@ -876,8 +902,8 @@ func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2
}
func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2With3Tensors
@ -1093,3 +1119,35 @@ func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> {
// CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32>
// CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
}
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v4
// CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}
func @non_max_suppression_v4_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v4_no_pad
// CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}
func @non_max_suppression_v5(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
%0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v5
// CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}
func @non_max_suppression_v5_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
%0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v5_no_pad
// CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}

View File

@ -0,0 +1,107 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?xf32>, %arg7: tensor<?xf32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
%0 = "tfl.lstm"(%arg0, // input
%arg1, %arg2, %arg3, %arg4, // weights
%arg5, %arg6, %arg7, %arg8, // recurrent weights
%arg9, %arg10, %arg11, // cell weights
%arg12, %arg13, %arg14, %arg15, // bias
%arg16, %arg17, // projection weight and bias
%arg18, %arg19, // stateful
%arg20, %arg21, %arg22, %arg23 // layer norm coefficients
) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<? xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: "tfl.lstm"
// CHECK-NEXT: %[[cst:.*]] = constant unit
// input gate
// CHECK-NEXT: %[[in1:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in2:.*]] = "tfl.fully_connected"(%arg18, %arg5, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in3:.*]] = "tfl.mul"(%arg19, %arg9)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in4:.*]] = "tfl.add_n"(%[[in1]], %[[in2]], %[[in3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in5:.*]] = "tfl.l2_normalization"(%[[in4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in6:.*]] = tfl.add %[[in4]], %[[in5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in7:.*]] = "tfl.fully_connected"(%[[in6]], %arg20, %arg12)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in8:.*]] = "tfl.logistic"(%[[in7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// forget gate
// CHECK-NEXT: %[[fo1:.*]] = "tfl.fully_connected"(%arg0, %arg2, %[[cst]])
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo2:.*]] = "tfl.fully_connected"(%arg18, %arg6, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo3:.*]] = "tfl.mul"(%arg19, %arg10)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo4:.*]] = "tfl.add_n"(%[[fo1]], %[[fo2]], %[[fo3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo5:.*]] = "tfl.l2_normalization"(%[[fo4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo6:.*]] = tfl.add %[[fo4]], %[[fo5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo7:.*]] = "tfl.fully_connected"(%[[fo6]], %arg21, %arg13)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo8:.*]] = "tfl.logistic"(%[[fo7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// cell gate
// CHECK-NEXT: %[[ce1:.*]] = "tfl.fully_connected"(%arg0, %arg3, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce2:.*]] = "tfl.fully_connected"(%arg18, %arg7, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce3:.*]] = "tfl.add_n"(%[[ce1]], %[[ce2]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce4:.*]] = "tfl.l2_normalization"(%[[ce3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce5:.*]] = tfl.add %[[ce3]], %[[ce4]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce6:.*]] = "tfl.fully_connected"(%[[ce5]], %arg22, %arg14)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce7:.*]] = "tfl.tanh"(%[[ce6]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac1:.*]] = "tfl.mul"(%[[fo8]], %arg19)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac2:.*]] = tfl.mul %[[in8]], %[[ce7]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac3:.*]] = tfl.add %[[ac1]], %[[ac2]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// output gate
// CHECK-NEXT: %[[ou1:.*]] = "tfl.fully_connected"(%arg0, %arg4, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou2:.*]] = "tfl.fully_connected"(%arg18, %arg8, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou3:.*]] = "tfl.mul"(%[[ac3]], %arg11)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou4:.*]] = "tfl.add_n"(%[[ou1]], %[[ou2]], %[[ou3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou5:.*]] = "tfl.l2_normalization"(%[[ou4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou6:.*]] = tfl.add %[[ou4]], %[[ou5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou7:.*]] = "tfl.fully_connected"(%[[ou6]], %arg23, %arg15)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou8:.*]] = "tfl.logistic"(%[[ou7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// output activation
// CHECK-NEXT: %[[ac4:.*]] = "tfl.tanh"(%[[ac3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac5:.*]] = tfl.mul %[[ac4]], %[[ou8]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac6:.*]] = "tfl.fully_connected"(%[[ac5]], %arg16, %arg17)
// CHECK-SAME: (tensor<?x!quant.any<i16:f32>>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x!quant.any<i8:f32>>
// CHECK-NEXT: %[[ac7:.*]] = "tf_quant.pseudo_return"(%[[ac6]]) : (tensor<?x!quant.any<i8:f32>>) -> tensor<?x!quant.any<i8:f32>>
// CHECK-NEXT: })
// CHECK-NEXT: return
return %0 : tensor<?xf32>
}
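For reference, the gate arithmetic these CHECK lines walk through mirrors the standard peephole LSTM cell with projection (a sketch in conventional textbook notation; the symbols below are not identifiers from the pass):

  i_t = \sigma(W_i x_t + R_i h_{t-1} + p_i \odot c_{t-1} + b_i)
  f_t = \sigma(W_f x_t + R_f h_{t-1} + p_f \odot c_{t-1} + b_f)
  g_t = \tanh(W_g x_t + R_g h_{t-1} + b_g)
  c_t = f_t \odot c_{t-1} + i_t \odot g_t
  o_t = \sigma(W_o x_t + R_o h_{t-1} + p_o \odot c_t + b_o)
  h_t = W_{proj} (o_t \odot \tanh(c_t)) + b_{proj}

The in*/fo*/ce*/ou* value groups correspond to i_t, f_t, g_t, and o_t; ac1 through ac3 form c_t, and ac4 through ac6 form the projected output. The extra l2_normalization/add/fully_connected steps per gate appear to be the normalization this quantization recipe inserts.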

View File

@ -143,6 +143,19 @@ func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: t
// CHECK: return [[RESULT]] : tensor<?x10xf32>
}
func @tensorlistLength(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>) -> (tensor<i32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListLength"(%0) : (tensor<!tf.variant<tensor<10xf32>>>) -> tensor<i32>
return %1: tensor<i32>
// CHECK-LABEL: tensorlistLength
// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>)
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) {{.*}} -> tensor<2xi32>
// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) {validate_indices = true} : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
// CHECK: return [[RESULT]] : tensor<i32>
}
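In effect, the lowering above reads the list length off the leading dimension of the underlying tensor: TensorListLength(t) = Shape(t)[0], which for the tensor<3x10xf32> input is 3.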
func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>

View File

@ -278,6 +278,6 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
%21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
%22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
%23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
%24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
}

View File

@ -0,0 +1,31 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
module attributes {
tfl.metadata = {key1 = "value1", key2 = "value2"}
} {
func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} {
^bb0(%arg0: tensor<3x2xi32>):
%0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%2 = "tfl.sub" (%0, %1) {fused_activation_function = "NONE"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %2 : tensor<3x2xi32>
}
}
// CHECK: buffers: [ {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "key1",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "key2",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
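The data arrays above are just the ASCII bytes of the metadata values. A minimal, self-contained check (illustrative only; not part of the test suite):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // 118, 97, 108, 117, 101, 49 are the ASCII codes for "value1".
  std::vector<unsigned char> buffer = {118, 97, 108, 117, 101, 49};
  assert(std::string(buffer.begin(), buffer.end()) == "value1");
  return 0;
}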

View File

@ -1,5 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
// | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
@ -9,7 +8,7 @@ func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf
return %2 : tensor<40x40xf32>
}
// CHECK-NEXT: operators: [ {
// CHECK: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, -1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,

View File

@ -103,7 +103,7 @@ func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x
// test invalid AddN
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}}
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
return %0 : tensor<? x f16>
}
@ -537,7 +537,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
// test invalid Logistic input
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
// expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
@ -591,8 +591,9 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -600,8 +601,9 @@ func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -610,7 +612,7 @@ func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %
// test invalid none type applied to a tensor type arg
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -619,7 +621,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test violation of projection weight and projection bias pred op trait
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -628,7 +630,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test invalid kernel type
func @testLstmWithInvalidKernelType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -652,6 +654,15 @@ func @testSelect(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi
// -----
// test select with multi-dim inputs
// CHECK-LABEL: testSelectMultiDim
func @testSelectMultiDim(%cond : tensor<?xi1>, %arg0 : tensor<?x4xi32>, %arg1 : tensor<?x4xi32>) -> tensor<?x4xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?x4xi32>,tensor<?x4xi32>) -> tensor<?x4xi32>
return %0 : tensor<?x4xi32>
}
// -----
func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi32>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
@ -660,6 +671,14 @@ func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>
// -----
func @testSelectWithUnsupportedShapes(%cond : tensor<2xi1>, %arg0 : tensor<3xi32>, %arg1 : tensor<3xi32>) -> tensor<3xi32> {
// expected-error @+1 {{failed to verify that Select operands meet shape criteria}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<2xi1>,tensor<3xi32>,tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
@ -762,6 +781,21 @@ func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> te
// -----
// CHECK-LABEL: testPadQuantizedU8
func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
// CHECK-LABEL: testPadQuantizedI8
func @testPadQuantizedI8(%arg0: tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// -----
// CHECK-LABEL: testPadV2
func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
@ -817,6 +851,20 @@ func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) ->
// -----
func @packQuantizedU8(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<u8:f32, 0.1>>, tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<u8:f32, 0.1>>
}
func @packQuantizedI8(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<i8:f32, 0.1>>, tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1>>
}
// -----
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
@ -1101,6 +1149,63 @@ func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) ->
}
// -----
func @transpose_perm_size(%arg0 : tensor<2x2xi32>, %arg1 : tensor<3xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{perm tensor elements size is not equal to input tensor rank}}
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<3xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @transpose_unranked_shape(%arg0 : tensor<*xi32>) -> tensor<2x2xi32> {
%cst = constant dense<[1, 0]> : tensor<2xi32>
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<*xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @transpose_dynamic_shape(%arg0 : tensor<2x?xi32>) -> tensor<?x2xi32> {
%cst = constant dense<[1, 0]> : tensor<2xi32>
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<?x2xi32>
return %0 : tensor<?x2xi32>
}
// -----
func @transpose_perm_axis_invalid(%arg0 : tensor<2x2xi32>) -> tensor<2x2xi32> {
%cst = constant dense<[1, -1]> : tensor<2xi32>
// expected-error @+1 {{perm[1] must be in [0, rank)}}
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @transpose_perm_axis_duplicated(%arg0 : tensor<2x2xi32>) -> tensor<2x2xi32> {
%cst = constant dense<[1, 1]> : tensor<2xi32>
// expected-error @+1 {{perm[1] cannot have duplicated axis}}
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @transpose_output_type_bad(%arg0 : tensor<3x4x5x6xi32>) -> tensor<3x4x5x6xi32> {
%cst = constant dense<[0, 3, 1, 2]> : tensor<4xi32>
// expected-error @+1 {{expect output type tensor<3x6x4x5xi32>, got tensor<3x4x5x6xi32>}}
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<3x4x5x6xi32>
return %0 : tensor<3x4x5x6xi32>
}
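Taken together, the transpose cases above pin down the verifier's contract. A self-contained sketch of those checks (a hypothetical helper assuming fully static shapes; the real verifier also accepts unranked and dynamic shapes, as the passing tests show):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true iff `perm` is a valid permutation of [0, rank) and applying it
// to `input_shape` yields `output_shape`.
bool VerifyTranspose(const std::vector<int64_t>& input_shape,
                     const std::vector<int32_t>& perm,
                     const std::vector<int64_t>& output_shape) {
  const size_t rank = input_shape.size();
  // "perm tensor elements size is not equal to input tensor rank"
  if (perm.size() != rank || output_shape.size() != rank) return false;
  std::vector<bool> seen(rank, false);
  for (size_t i = 0; i < rank; ++i) {
    const int32_t axis = perm[i];
    // "perm[i] must be in [0, rank)"
    if (axis < 0 || static_cast<size_t>(axis) >= rank) return false;
    // "perm[i] cannot have duplicated axis"
    if (seen[axis]) return false;
    seen[axis] = true;
  }
  for (size_t i = 0; i < rank; ++i) {
    // "expect output type ...": output dim i must equal input dim perm[i].
    if (output_shape[i] != input_shape[perm[i]]) return false;
  }
  return true;
}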
// -----
func @transpose_element_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> {
@ -1643,3 +1748,33 @@ func @testSplitVOpWithValidSizeSplitsNegative(%arg0 : tensor<16x4xf32>) -> (tens
return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>
}
// -----
func @testNonMaxSuppressionV4WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}
// -----
func @testNonMaxSuppressionV4WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v4' op failed to verify that boxes should have dim[1] == 4}}
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}
// -----
func @testNonMaxSuppressionV5WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}
// -----
func @testNonMaxSuppressionV5WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v5' op failed to verify that boxes should have dim[1] == 4}}
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}

View File

@ -292,16 +292,3 @@ func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> t
// CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
// CHECK: return %3
}
// CHECK-LABEL: @InvalidL2NormalizePatternMorethan1Dimension
// The input has a rank higher than 1; the pattern should be limited to 1D only.
func @InvalidL2NormalizePatternMorethan1Dimension(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.sqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
return %3: tensor<2x2xf32>
// CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
// CHECK: return %3
}

View File

@ -373,21 +373,12 @@ func @QuantizeConstant() -> tensor<2x3xf32> {
// CHECK: return %1 : tensor<2x3xf32>
}
// CHECK-LABEL: NotQuantizeNonZeroSplat
func @NotQuantizeNonZeroSplat() -> tensor<2x3xf32> {
%cst = constant dense<2.0> : tensor<2x3xf32>
return %cst : tensor<2x3xf32>
// CHECK-LABEL: NotQuantizeNoneType
func @NotQuantizeNoneType() -> none {
%cst = constant unit
return %cst : none
// CHECK-NEXT: %[[cst:.*]] = constant dense<2.000000e+00>
// CHECK-NEXT: return %[[cst]]
}
// CHECK-LABEL: NotQuantizeNonZeroScalar
func @NotQuantizeNonZeroScalar() -> tensor<f32> {
%cst = constant dense<2.0> : tensor<f32>
return %cst : tensor<f32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<2.000000e+00>
// CHECK-NEXT: %[[cst:.*]] = constant unit
// CHECK-NEXT: return %[[cst]]
}
@ -433,6 +424,32 @@ func @QuantizeSharedBiases(
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}
// CHECK-LABEL: QuantizeSharedBiases2
func @QuantizeSharedBiases2(
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>,
%arg2: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> (tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>) {
%cst = constant dense<1.0> : tensor<32xf32>
%1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
%add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform<u8:f32, 1.0>>
%5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> tensor<32x3x3x3xf32>
%conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
%7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
return %3, %7 : tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}

View File

@ -63,7 +63,7 @@ func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNorm
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<1.000000e-03>
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
@ -96,7 +96,7 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNormV3
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<1.000000e-03>
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
@ -155,7 +155,7 @@ func @fakeQuantFolded() -> (tensor<8xf32>) {
%rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %rst : tensor<8xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
@ -262,7 +262,7 @@ func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>)
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<16x3x3x3xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
@ -282,7 +282,7 @@ func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<1x3x3x48xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
@ -348,3 +348,11 @@ func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
// CHECK-LABEL: stop_gradient
// CHECK: return %arg0 : tensor<3xi32>
}
func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// Should be converted to Identity, and the Identity then folded to its value
// CHECK-LABEL: CheckNumerics
// CHECK: return %arg0 : tensor<3xf32>
}

View File

@ -0,0 +1,223 @@
// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2TwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
}
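The recipe in the CHECK lines above is mechanical: flatten the leading batch dims (2x3 -> 6), slice out each 4x5/5x6 operand pair, emit one plain MatMul per pair, then Pack the six 4x6 results and Reshape back to 2x3x4x6. A self-contained sketch of that unrolling on plain buffers (illustrative; not the pass's code):

#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // [rows][cols]

// Plain 2-D matmul standing in for tf.MatMul.
Matrix MatMul2D(const Matrix& a, const Matrix& b) {
  Matrix c(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t k = 0; k < b.size(); ++k)
      for (size_t j = 0; j < b[0].size(); ++j) c[i][j] += a[i][k] * b[k][j];
  return c;
}

// One MatMul per flattened batch entry, matching one tf.Slice pair above; the
// caller packs the results and reshapes back to the original batch shape.
std::vector<Matrix> UnrolledBatchMatMul(const std::vector<Matrix>& lhs,
                                        const std::vector<Matrix>& rhs) {
  std::vector<Matrix> out;
  out.reserve(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) out.push_back(MatMul2D(lhs[i], rhs[i]));
  return out;
}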
func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2FlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
}
func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %0 : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulV2Matrix
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}
func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulTwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
}
func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulFlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
}
func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %0 : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulMatrix
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}

View File

@ -149,7 +149,12 @@ int main(int argc, char **argv) {
lower_tensor_list_ops, &result, &pm);
if (!status.ok()) return kTrFailure;
auto output = mlir::openOutputFile(output_file_name);
std::string error_msg;
auto output = mlir::openOutputFile(output_file_name, &error_msg);
if (output == nullptr) {
llvm::errs() << error_msg << '\n';
return kTrFailure;
}
output->os() << result;
output->keep();

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
#include <queue>
#include <vector>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
@ -353,6 +355,127 @@ struct OphintCompositeOp {
std::map<int, AggregatedOperand> outputs;
};
// Preprocess the graph for topological sort (each operation is a node, and
// inputs/outputs indicate edges). Assume the graph is acyclic. The
// preprocessing does the following:
// - Compute each operation's in-degree (how many distinct input operations
//   it consumes).
// - Get all consumer operations of every operation (operation_to_ouputs).
// - Get the init_queue (these operations will be processed first).
void PreprocessTopoSortGraph(
Block* block, std::queue<Operation*>* init_queue,
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
for (auto& op : *block) {
if (&op == block->getTerminator()) continue;
if (op.getNumOperands() == 0) {
init_queue->push(&op);
} else {
      // The operands of an op do not directly indicate its "edges": e.g. a
      // pack op can follow an unpack op (sharing multiple operands), which
      // should only count as one edge.
llvm::DenseSet<Operation*> input_ops;
for (int i = 0; i < op.getNumOperands(); ++i) {
Operation* input_op = op.getOperand(i)->getDefiningOp();
if (input_op) input_ops.insert(input_op);
}
if (input_ops.empty()) {
init_queue->push(&op);
continue;
}
operation_to_in_degrees->try_emplace(&op, input_ops.size());
for (auto* input_op : input_ops) {
        auto preceding_op_it = operation_to_ouputs->find(input_op);
        if (preceding_op_it == operation_to_ouputs->end()) {
          auto result = operation_to_ouputs->try_emplace(
              input_op, llvm::DenseSet<Operation*>());
          preceding_op_it = result.first;
        }
        preceding_op_it->second.insert(&op);
}
}
}
}
bool IsSideEffectOp(Operation* op) {
if (op->hasNoSideEffect()) return false;
  // The identity op has no side effect.
  // Checking the OperationName might be more elegant here.
auto tf_identity_op = dyn_cast_or_null<TF::IdentityOp>(op);
if (tf_identity_op) return false;
return true;
}
// Other transformations could also benefit from this utility function, but
// since there are currently none, we limit it to the ophint extraction pass.
// We may refactor it to extend the usage in the future.
//
// Assume the graph is disconnected from outside.
// Also assume the block has no arguments.
LogicalResult TopoSortOperations(OpBuilder* builder) {
std::queue<Operation*> init_queue;
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
llvm::DenseMap<Operation*, int> operation_to_in_degrees;
std::vector<Operation*> sorted_ops;
PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
&operation_to_ouputs, &operation_to_in_degrees);
while (!init_queue.empty()) {
Operation* current_op = init_queue.front();
init_queue.pop();
sorted_ops.push_back(current_op);
auto current_op_to_output_it = operation_to_ouputs.find(current_op);
if (current_op_to_output_it == operation_to_ouputs.end()) {
continue;
}
for (Operation* output_op : current_op_to_output_it->second) {
auto output_op_it = operation_to_in_degrees.find(output_op);
if (output_op_it == operation_to_in_degrees.end()) return failure();
output_op_it->second -= 1;
if (output_op_it->second == 0) {
init_queue.push(output_op);
operation_to_in_degrees.erase(output_op_it);
}
}
operation_to_ouputs.erase(current_op_to_output_it);
}
  // Before we perform the sort, we need to make sure we didn't disturb the
  // ordering of the original side-effect operations.
  // It's possible those side-effect operations have no topological relation
  // at all!
std::vector<Operation*> original_side_effect_ops;
std::vector<Operation*> after_sort_side_effect_ops;
for (auto& op : *builder->getBlock()) {
if (IsSideEffectOp(&op) && (&op != builder->getBlock()->getTerminator()))
original_side_effect_ops.push_back(&op);
}
for (auto* op : sorted_ops) {
if (IsSideEffectOp(op)) after_sort_side_effect_ops.push_back(op);
}
if (original_side_effect_ops.size() != after_sort_side_effect_ops.size())
return failure();
for (int i = 0; i < original_side_effect_ops.size(); ++i) {
if (original_side_effect_ops[i] != after_sort_side_effect_ops[i])
return failure();
}
  // Perform the sort.
  // Ideally it would be nice to just clear the block and then write out the
  // sorted ops, but unfortunately that's hard to do.
for (int i = sorted_ops.size() - 1; i > 0; --i) {
Operation* current_op = sorted_ops[i];
for (int j = i - 1; j >= 0; --j) {
Operation* prev_op = sorted_ops[j];
prev_op->moveBefore(current_op);
}
}
return success();
}
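// For reference, the scheme implemented above is Kahn's algorithm. A minimal,
// MLIR-free sketch of the same idea on a plain adjacency list (node IDs and
// container choices are illustrative, not part of this pass; <queue> and
// <vector> are already included at the top of this file):
// edges[u] lists the consumers of node u. Returns a topological order, or an
// empty vector if a cycle prevents every node from being scheduled.
static std::vector<int> TopoSortSketch(
    const std::vector<std::vector<int>>& edges) {
  const int num_nodes = static_cast<int>(edges.size());
  std::vector<int> in_degree(num_nodes, 0);
  for (const auto& consumers : edges)
    for (int v : consumers) ++in_degree[v];
  std::queue<int> ready;  // the "init_queue": nodes with no pending inputs
  for (int u = 0; u < num_nodes; ++u)
    if (in_degree[u] == 0) ready.push(u);
  std::vector<int> order;
  while (!ready.empty()) {
    int u = ready.front();
    ready.pop();
    order.push_back(u);
    for (int v : edges[u])
      if (--in_degree[v] == 0) ready.push(v);
  }
  if (static_cast<int>(order.size()) != num_nodes) order.clear();  // cycle
  return order;
}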
Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
Operation* insert_before_op,
const std::map<int, Value*>& inputs,
@ -360,10 +483,12 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
OpBuilder* builder, ModuleOp* module_op) {
SmallVector<Type, 4> input_types;
SmallVector<Value*, 4> input_values;
SmallVector<int, 4> input_indexes;
for (const auto& kv : inputs) {
Value* input = kv.second;
input_types.push_back(input->getType());
input_values.push_back(input);
input_indexes.push_back(kv.first);
}
SmallVector<Type, 4> func_output_types;
@ -378,6 +503,8 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
SmallVector<NamedAttribute, 4> attrs;
attrs.push_back(builder->getNamedAttr(
kTfLiteFunctionName, builder->getStringAttr(fused_func_type)));
attrs.push_back(builder->getNamedAttr(
kTfLiteFunctionInputIndex, builder->getI32ArrayAttr(input_indexes)));
FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name,
function_type, llvm::makeArrayRef(attrs));
module_op->push_back(func_op);
@ -507,6 +634,10 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
};
builder->getBlock()->walk(removeRemovableOps);
// Step 8: Topo sort to fix any invalid temporary IRs.
if (failed(TopoSortOperations(builder))) return failure();
return success();
}

View File

@ -20,6 +20,10 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def NonOpaqueElementsAttr : ElementsAttrBase<
CPred<"!$_self.isa<OpaqueElementsAttr>()">,
"non-opaque constant tensor">;
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
@ -56,8 +60,13 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
[(AnyStaticShapeTensor $res)], (addBenefit 10)>;
//===----------------------------------------------------------------------===//
// Unary ops patterns.
//===----------------------------------------------------------------------===//
@ -157,7 +166,8 @@ def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
// The following two rules can both match a tf.Placeholder.input node with
// min/max/type attributes, so we increase the benefit of the first rule by
// one to ensure the tfl.quantize and tfl.dequantize ops are inserted when it
// matches.
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
def : Pat<(TF_PlaceholderInputOp TensorOf<[F16, F32, F64]>:$inputs,
$min, $max, $type),
(TFL_DequantizeOp
(TFL_QuantizeOp
(TFL_InputOp $inputs),
@ -191,7 +201,8 @@ def : Pat<(TF_GatherV2Op $params, $indices,
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue),
(TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
@ -252,7 +263,7 @@ def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),
def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
@ -308,3 +319,11 @@ def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>;
def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;
def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>;
def : Pat<
(TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>;
def : Pat<
(TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>;

View File

@ -0,0 +1,228 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares the TFLite fused ops for quantization.
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
//===----------------------------------------------------------------------===//
// The LoadQuantizationRecipe Pass.
//
namespace mlir {
namespace TFL {
namespace {
// This pass loads the quantization recipe for the TFLite ops to be quantized.
// Specifically, it extends the fused ops with their internal implementation as
// op regions. Each op in the region produces results with element type
// AnyQuantizedType, so bitwidth, narrow_range, etc. are included. The op also
// defines the op quantization traits, which are used to propagate the
// quantization parameters in the following passes.
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
void runOnFunction() override;
private:
void Initialize(LSTMOp lstm, OpBuilder* builder);
// Create LSTM gates with different weights for input, recurrent and
// cell state, and also the layer normalization parameters.
Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
Value* rec_w,
llvm::Optional<std::pair<Value*, Value*>> cell,
Value* ln_w, Value* ln_bias, OpBuilder* builder);
Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
Value* ln_bias, OpBuilder* builder);
// Add the internal implementation of the LSTM to its regions.
void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
StringAttr none_af;
StringAttr fc_format;
BoolAttr keep_dims;
Type int8;
Type int16;
ConstantOp none_cst;
};
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
Type expressed_type =
lstm.input()->getType().cast<ShapedType>().getElementType();
Type int8_storage_type = builder->getIntegerType(8);
Type int16_storage_type = builder->getIntegerType(16);
auto flag = quant::QuantizationFlags::FlagValue::Signed;
int64_t int8_min = quant::QuantizedType::getDefaultMininumForInteger(
flag, /*integralWidth=*/8);
int64_t int8_max = quant::QuantizedType::getDefaultMaxinumForInteger(
flag, /*integralWidth=*/8);
int64_t int16_min = quant::QuantizedType::getDefaultMininumForInteger(
flag, /*integralWidth=*/16);
int64_t int16_max = quant::QuantizedType::getDefaultMaxinumForInteger(
flag, /*integralWidth=*/16);
auto any_int8 = quant::AnyQuantizedType::get(
flag, int8_storage_type, expressed_type, int8_min, int8_max);
auto any_int16 = quant::AnyQuantizedType::get(
flag, int16_storage_type, expressed_type, int16_min, int16_max);
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
}
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
Value* ln_w, Value* ln_bias,
OpBuilder* builder) {
  // Note that the l2_normalization and add ops here are not the execution
  // kernel implementation of layer_normalization; we just use them to model
  // the quantization requirement.
auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
none_af, fc_format, keep_dims);
}
Operation* LoadQuantizationRecipe::CreateGate(
Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
OpBuilder* builder) {
auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
none_af, fc_format, keep_dims);
auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
none_af, fc_format, keep_dims);
AddNOp s4;
if (cell.hasValue()) {
auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
cell.getValue().second, none_af);
s4 = builder->create<AddNOp>(
loc, int16,
llvm::ArrayRef<Value*>(
{*s1.output().begin(), *s2.output().begin(), s3.output()}));
} else {
s4 = builder->create<AddNOp>(
loc, int16,
llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
}
auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);
if (cell.hasValue()) {
return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
} else {
return builder->create<TanhOp>(loc, int16, s5->getResult(0));
}
}
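// Summarizing the recipe built above: CreateGate models
//   gate = Logistic(LayerNorm(FC(in, in_w) + FC(rec, rec_w) + cell.1 * cell.2))
// when a cell-state pair is provided, and
//   gate = Tanh(LayerNorm(FC(in, in_w) + FC(rec, rec_w)))
// otherwise, with every intermediate typed as int16 quantized; later passes
// read these types and traits rather than the kernel math itself.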
void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
Initialize(lstm, builder);
Region region;
region.push_back(new Block);
builder->setInsertionPointToEnd(&region.front());
Location loc = lstm.getLoc();
Type int32_type = builder->getIntegerType(32);
Type int32_tensor = builder->getTensorType(int32_type);
none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
builder->getUnitAttr());
auto input_gate = CreateGate(
loc, lstm.input(), lstm.input_to_input_weights(),
lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
{lstm.input_cell_state(), lstm.cell_to_input_weights()}),
lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
auto forget_gate = CreateGate(
loc, lstm.input(), lstm.input_to_forget_weights(),
lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
{lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
lstm.input_activation_state(),
lstm.recurrent_to_cell_weights(), llvm::None,
lstm.cell_layer_norm_coefficients(),
lstm.cell_bias(), builder);
auto forget_cell_state = builder->create<MulOp>(
loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
auto input_cell_state = builder->create<MulOp>(
loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
input_cell_state.output(), none_af);
auto output_gate = CreateGate(
loc, lstm.input(), lstm.input_to_output_weights(),
lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
{new_cell, lstm.cell_to_output_weights()}),
lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);
auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
auto hidden_state = builder->create<MulOp>(
loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af);
auto act = builder->create<FullyConnectedOp>(
loc, int8, hidden_state.output(), lstm.projection_weights(),
lstm.projection_bias(), none_af, fc_format, keep_dims);
// TODO(fengliuai): define and register the op in the QuantOps Dialect.
OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
{int8}, {});
builder->createOperation(return_state);
lstm.internal().takeBody(region);
}
void LoadQuantizationRecipe::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func);
none_af = builder.getStringAttr("NONE");
fc_format = builder.getStringAttr("DEFAULT");
keep_dims = builder.getBoolAttr(false);
func.walk([&](Operation* op) {
if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
LoadForLSTMOp(lstm, &builder);
}
// Handles other ops.
});
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<FunctionPassBase> CreateLoadQuantizationRecipePass() {
return absl::make_unique<LoadQuantizationRecipe>();
}
static PassRegistration<LoadQuantizationRecipe> pass(
"tfl-load-recipe", "Load TFL op quantization recipe");
} // namespace TFL
} // namespace mlir

View File

@ -429,12 +429,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
tf_op.element_dtype().isF64() ||
tf_op.element_dtype().isInteger(1) ||
tf_op.element_dtype().isInteger(8) ||
tf_op.element_dtype().isInteger(16) ||
tf_op.element_dtype().isInteger(32) ||
tf_op.element_dtype().isInteger(64))) {
return tf_op.emitError(
"requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer "
"or 16-bit/32-bit/64-bit "
"float type during TF Lite transformation pass");
}
@ -461,6 +463,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
auto c = ConvertTFTensorListPushBack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListLengthOp>(op)) {
auto c = TFL::ConvertTFTensorListLength(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(tf_op);

View File

@ -122,6 +122,8 @@ class OperandHasRank<int n> : Constraint<
// Mul->Rsqrt->Sum->Square
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
def : Pat<(TFL_MulOp $operand1,
(TFL_RsqrtOp
(TFL_SumOp
@ -130,13 +132,14 @@ def : Pat<(TFL_MulOp $operand1,
$keep_dims)),
TFL_AF_None),
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
[(EqualOperands $operand1, $square_operand),
(OperandHasRank<1> $operand1)]>;
[(EqualOperands $operand1, $square_operand)]>;
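// In other words, the matched subgraph computes x * rsqrt(sum(square(x)))
// over the reduction axis, i.e. x / ||x||_2, which is exactly
// TFL_L2NormalizationOp, provided both occurrences of x are the same value
// (the EqualOperands constraint).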
// This pattern constructs L2NormalizationOp from
// Div->sqrt->Sum->Square
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
def : Pat<(TFL_DivOp $operand1,
(TFL_SqrtOp
(TFL_SumOp
@ -145,5 +148,4 @@ def : Pat<(TFL_DivOp $operand1,
$keep_dims)),
TFL_AF_None),
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
[(EqualOperands $operand1, $square_operand),
(OperandHasRank<1> $operand1)]>;
[(EqualOperands $operand1, $square_operand)]>;

View File

@ -18,6 +18,14 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def NonOpaqueElementsAttr : ElementsAttrBase<
CPred<"!$_self.isa<OpaqueElementsAttr>()">,
"non-opaque constant tensor">;
// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
[(AnyStaticShapeTensor $res)]>;
// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
@ -81,8 +89,8 @@ class TFi32<int v> : ConstantAttr<I32ElementsAttr, !cast<string>(v)>;
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
(TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp
/*start=*/(TF_RankOp $b),
/*limit=*/(ConstantOp TFi32<0>),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))),
/*limit=*/(TF_ConstOp TFi32<0>),
/*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))),
$at, ConstBoolAttrTrue)>;
// Matmul with transpose on a to matmul with explicit transpose op and a not
@ -90,10 +98,12 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
(TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp
/*start=*/(TF_RankOp $a),
/*limit=*/(ConstantOp TFi32<0>),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
/*limit=*/(TF_ConstOp TFi32<0>),
/*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $b,
ConstBoolAttrFalse, $bt)>;
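// In both rewrites above, the transpose permutation is computed as
// Range(Rank(x), 0, -1) - 1, i.e. [rank - 1, ..., 0]: the reversed dimension
// order, which for the rank-2 MatMul operands is simply [1, 0].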
// Partially supported in TFLite, treated as passthrough IdentityOp
def : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>;
def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>;
def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -246,7 +247,8 @@ struct ConvertTFConvOp : public RewritePattern {
filter_type.getShape());
auto bias_type = rewriter.getTensorType({bias_dim}, elem_type);
auto bias_attr = rewriter.getZeroAttr(bias_type);
auto bias = rewriter.create<ConstantOp>(op->getLoc(), bias_type, bias_attr);
auto bias =
rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
auto *conv_state = static_cast<ConvertTFConvOpMatchState *>(state.get());
auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
@ -297,7 +299,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
rewriter.getIntegerType(32));
auto perm_attr =
DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
// Create tensor type for the transpose result.
auto filter_type = filter->getType().cast<RankedTensorType>();
@ -366,7 +368,7 @@ class ConvertTFDepthwiseConv2dNative
auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64));
auto shape_attr =
DenseElementsAttr::get(shape_type, llvm::makeArrayRef(result_shape));
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
}
@ -377,6 +379,11 @@ class ConvertTFDepthwiseConv2dNative
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
applyPatternsGreedily(func, patterns);
  // This pattern was intended to use TFL QDQs to preserve the quantization
  // parameters from the TF Quant ops; thus this pattern should run with the
  // first `applyPatternsGreedily` method, which would otherwise remove the

View File

@ -14,9 +14,13 @@ limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def CreateTFShapeOp : NativeCodeCall<
"$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">;
//===----------------------------------------------------------------------===//
// TensorList transformation patterns.
// Note that the pattern below rewrites `TensorList` tensors (which have type DT_VARIANT)
@ -34,3 +38,11 @@ def ConvertTFTensorListStack : Pat<
def ConvertTFTensorListGetItem : Pat<
(TF_TensorListGetItemOp $input, $index, $element_shape),
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
// TensorListLength is equivalent to the size of the first dimension of the
// input tensorlist, so rewrite it to a combination of Gather and Shape ops.
def ConvertTFTensorListLength: Pat<
(TF_TensorListLengthOp:$old_value $input),
(TF_GatherOp
(CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue),
(ConstantOp ConstantAttr<I32ElementsAttr, "0">), ConstBoolAttrTrue)>;
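// That is, length = Shape(input)[0]: the Gather at the constant index 0
// extracts the leading (list-size) dimension from the 32-bit shape vector.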

View File

@ -0,0 +1,309 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include <climits>
#include <cstdint>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TFL {
namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
void runOnFunction() override;
};
void UnrollBatchMatMulPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
applyPatternsGreedily(func, patterns);
}
} // namespace
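// A scalar sketch of what the unrolling computes, with plain nested vectors
// standing in for tensors (illustrative only, not part of the pass): slice
// out each batch, run a plain 2-D matmul per batch, then stack the per-batch
// results back together, mirroring the Slice/MatMul/Pack sequence below.
using Matrix = std::vector<std::vector<float>>;

static Matrix MatMul2D(const Matrix& a, const Matrix& b) {
  Matrix out(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t k = 0; k < b.size(); ++k)
      for (size_t j = 0; j < b[0].size(); ++j)
        out[i][j] += a[i][k] * b[k][j];
  return out;
}

// lhs: [batch, rows, depth] x rhs: [batch, depth, cols] -> [batch, rows, cols]
static std::vector<Matrix> UnrolledBatchMatMul(const std::vector<Matrix>& lhs,
                                               const std::vector<Matrix>& rhs) {
  std::vector<Matrix> stacked;
  for (size_t batch = 0; batch < lhs.size(); ++batch)  // one MatMul per batch
    stacked.push_back(MatMul2D(lhs[batch], rhs[batch]));
  return stacked;  // the "Pack" step
}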
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
Value* value, ArrayRef<int64_t> shape, Type elementType, Location loc,
PatternRewriter& rewriter) {
int64_t shape_rank = shape.size();
auto shapeSpecType =
rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64));
Type resultType = rewriter.getTensorType(shape, elementType);
auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
auto shapeTensor =
rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
/*shape=*/shapeTensor);
}
template <typename BatchMatMulOpType>
std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
Type elementType = tensorType.getElementType();
int rank = tensorType.getShape().size();
int num_rows = tensorType.getShape()[rank - 2];
int num_cols = tensorType.getShape()[rank - 1];
// Reshape to rank-3 Tensor with first dimension as the batch size.
auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols},
elementType, loc, rewriter);
SmallVector<int64_t, 3> sliceSize = {1, num_rows, num_cols};
std::vector<Value*> sliced;
Type int64Type = rewriter.getIntegerType(64);
Type sliceResultType = rewriter.getTensorType(sliceSize, elementType);
// Slice along each batch index and remember the slice output for future
// use.
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
auto vector3Type = rewriter.getTensorType({3}, int64Type);
auto begin_attr =
DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
auto begin = rewriter.create<ConstantOp>(loc, vector3Type, begin_attr);
auto size = rewriter.create<ConstantOp>(loc, vector3Type, size_attr);
auto sliceOp =
rewriter.create<TF::SliceOp>(loc, sliceResultType,
/*input=*/reshapeOp.output(), begin, size);
// Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
// num_cols]
auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols},
elementType, loc, rewriter);
sliced.emplace_back(squeezeOp.output());
}
return sliced;
}
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
Value* value, Location loc, PatternRewriter& rewriter) {
auto valueType = value->getType().cast<RankedTensorType>();
auto shape = valueType.getShape();
int dims = shape.size();
std::vector<int32_t> perm(dims);
for (int i = 0; i < dims - 2; i++) {
perm[i] = i;
}
perm[dims - 2] = dims - 1;
perm[dims - 1] = dims - 2;
auto perm_type = rewriter.getTensorType({static_cast<int32_t>(perm.size())},
rewriter.getIntegerType(32));
auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
int64_t r = transposed_shape[dims - 1];
int64_t c = transposed_shape[dims - 2];
transposed_shape[dims - 1] = c;
transposed_shape[dims - 2] = r;
auto transposed_type =
rewriter.getTensorType(transposed_shape, valueType.getElementType());
return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}
template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
const std::vector<Value*>& sliced_lhs,
const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
int rows, int cols, Type elementType, Location loc,
PatternRewriter& rewriter) {
auto matmulType = rewriter.getTensorType({rows, cols}, elementType);
std::vector<Value*> matmuls;
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
int lhs_batch_idx, rhs_batch_idx;
if (bcast.IsBroadcastingRequired()) {
lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
} else {
lhs_batch_idx = batch_idx;
rhs_batch_idx = batch_idx;
}
auto false_attr = rewriter.getBoolAttr(false);
auto matmul = rewriter.create<TF::MatMulOp>(loc, matmulType,
/*a=*/sliced_lhs[lhs_batch_idx],
/*b=*/sliced_rhs[rhs_batch_idx],
/*transpose_a=*/false_attr,
/*transpose_b=*/false_attr);
matmuls.emplace_back(matmul.product());
}
// Combine the result of each individual MatMul into a rank-3 Tensor.
Type packedType = rewriter.getTensorType(
{bcast.output_batch_size(), rows, cols}, elementType);
auto N = rewriter.getI64IntegerAttr(matmuls.size());
auto axis = rewriter.getI64IntegerAttr(0);
return rewriter.create<TF::PackOp>(loc, packedType,
/*values=*/matmuls, N, axis);
}
template <typename BatchMatMulOpType>
PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
BatchMatMulOpType op, PatternRewriter& rewriter) const {
Value* input_lhs = op.x();
Value* input_rhs = op.y();
if (!input_lhs->getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
return this->matchFailure();
}
if (!input_rhs->getType().isa<RankedTensorType>()) {
// RHS must be a ranked tensor type
return this->matchFailure();
}
auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
auto rhs_type = input_rhs->getType().cast<RankedTensorType>();
auto elementType = lhs_type.getElementType();
if (elementType != rhs_type.getElementType()) {
    // The element type of LHS must be the same as that of RHS
return this->matchFailure();
}
auto lhs_shape = lhs_type.getShape();
auto rhs_shape = rhs_type.getShape();
Location loc = op.getLoc();
// Transpose LHS input if necessary.
if (op.adj_x()) {
input_lhs = createTransposeOp(input_lhs, loc, rewriter);
lhs_type = input_lhs->getType().cast<RankedTensorType>();
lhs_shape = lhs_type.getShape();
}
// Transpose RHS input if necessary.
if (op.adj_y()) {
input_rhs = createTransposeOp(input_rhs, loc, rewriter);
rhs_type = input_rhs->getType().cast<RankedTensorType>();
rhs_shape = rhs_type.getShape();
}
// Ensure that input ranks are at least 2 and batch shapes are
// broadcastable.
const int dims_a = lhs_shape.size();
const int dims_b = rhs_shape.size();
if (dims_a < 2 || dims_b < 2) {
// Both inputs must have rank >= 2
return this->matchFailure();
}
if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
return this->matchFailure();
}
if (dims_a == 2 && dims_b == 2) {
// When both inputs are matrices, just replace the op to a matmul op.
Type resultType =
rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
auto false_attr = rewriter.getBoolAttr(false);
rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, resultType,
/*a=*/input_lhs,
/*b=*/input_rhs,
/*transpose_a=*/false_attr,
/*transpose_b=*/false_attr);
return this->matchSuccess();
}
tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
lhs_shape.begin(), lhs_shape.end()),
absl::InlinedVector<tensorflow::int64, 4>(
rhs_shape.begin(), rhs_shape.end()));
if (!bcast.IsValid()) {
// Input batch dimensions must be broadcastable
return this->matchFailure();
}
// Compute slices for each batch in the LHS and RHS.
std::vector<Value*> sliced_lhs =
sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
std::vector<Value*> sliced_rhs =
sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);
// Compute (single batch) MatMul for each output batch. The MatMul outputs
// are then packed together into one output Tensor.
auto packOp =
createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
rhs_shape[dims_b - 1], elementType, loc, rewriter);
// Reshape the rank-3 Tensor into the correct output shape.
const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes();
std::vector<int64_t> resultShape(resultBatchShape.begin(),
resultBatchShape.end());
resultShape.push_back(lhs_shape[dims_a - 2]);
resultShape.push_back(rhs_shape[dims_b - 1]);
auto reshapeOp =
createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter);
rewriter.replaceOp(op, reshapeOp.output());
return this->matchSuccess();
}
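// A self-contained sketch (not the tensorflow::MatMulBCast API) of the batch
// bookkeeping relied on above, assuming standard NumPy-style broadcasting:
// broadcast the two batch shapes, then map each output batch index back to a
// flat batch index of each input, repeating size-1 dimensions. All names are
// illustrative; add <algorithm>, <cassert> and <vector> if compiled alone.
struct BatchIndices {
  std::vector<int64_t> x, y;  // one entry per output batch index
};

static BatchIndices BroadcastBatchIndices(std::vector<int64_t> x_shape,
                                          std::vector<int64_t> y_shape) {
  // Left-pad the shorter batch shape with 1s so the ranks match.
  while (x_shape.size() < y_shape.size()) x_shape.insert(x_shape.begin(), 1);
  while (y_shape.size() < x_shape.size()) y_shape.insert(y_shape.begin(), 1);
  const int rank = static_cast<int>(x_shape.size());
  std::vector<int64_t> out_shape(rank);
  int64_t out_size = 1;
  for (int i = 0; i < rank; ++i) {
    assert(x_shape[i] == y_shape[i] || x_shape[i] == 1 || y_shape[i] == 1);
    out_shape[i] = std::max(x_shape[i], y_shape[i]);
    out_size *= out_shape[i];
  }
  // Row-major strides; broadcast (size-1) dimensions get stride 0 so the same
  // input batch is reused along that output dimension.
  std::vector<int64_t> x_stride(rank), y_stride(rank);
  int64_t sx = 1, sy = 1;
  for (int i = rank - 1; i >= 0; --i) {
    x_stride[i] = x_shape[i] == 1 ? 0 : sx;
    y_stride[i] = y_shape[i] == 1 ? 0 : sy;
    sx *= x_shape[i];
    sy *= y_shape[i];
  }
  BatchIndices indices;
  for (int64_t idx = 0; idx < out_size; ++idx) {
    int64_t rem = idx, xi = 0, yi = 0;
    for (int i = rank - 1; i >= 0; --i) {
      xi += (rem % out_shape[i]) * x_stride[i];
      yi += (rem % out_shape[i]) * y_stride[i];
      rem /= out_shape[i];
    }
    indices.x.push_back(xi);
    indices.y.push_back(yi);
  }
  return indices;
}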
static PassRegistration<UnrollBatchMatMulPass> pass(
"tfl-unroll-batch-matmul",
"Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TFL {
// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not
// support the BatchMatMul operation, it unrolls a BatchMatMul op into
// tf.Reshape, tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops.
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;
static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
Type elementType, Location loc,
PatternRewriter& rewriter);
static std::vector<Value*> sliceInput(Value* value, int batch_size,
Location loc,
PatternRewriter& rewriter);
static TF::TransposeOp createTransposeOp(Value* value, Location loc,
PatternRewriter& rewriter);
static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
const std::vector<Value*>& sliced_rhs,
const tensorflow::MatMulBCast& bcast,
int rows, int cols, Type elementType,
Location loc, PatternRewriter& rewriter);
PatternMatchResult matchAndRewrite(BatchMatMulOpType op,
PatternRewriter& rewriter) const override;
};
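// Typical registration, as done by the prepare-tf and unroll-batch-matmul
// passes in this change:
//   patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
//                   ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(context);
//   applyPatternsGreedily(func, patterns);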
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

View File

@ -0,0 +1,456 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
int32_t val, mlir::Location location) {
auto type = builder->getTensorType(shape, builder->getIntegerType(32));
auto attr = DenseElementsAttr::get(type, val);
return builder->create<ConstantOp>(location, type, attr);
}
Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
float val, mlir::Location location) {
auto type = builder->getTensorType(shape, builder->getF32Type());
auto attr = DenseElementsAttr::get(type, val);
return builder->create<ConstantOp>(location, type, attr);
}
Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
ArrayRef<int64_t> values, mlir::Location location) {
auto type = builder->getTensorType(static_cast<int>(shape.size()),
builder->getIntegerType(64));
auto attr = DenseElementsAttr::get(type, values);
return builder->create<ConstantOp>(location, type, attr);
}
Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) {
return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
builder->getUnitAttr());
}
Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose,
RankedTensorType type, mlir::Location location) {
// Create a constant op for transpose permutation.
SmallVector<int64_t, 2> perm = {1, 0};
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
// Create tensor type for the transpose result.
auto transpose_type = type;
auto transpose_shape = functional::map(
[transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
perm);
auto elem_type = transpose_type.getElementType();
auto result_type = builder->getTensorType(transpose_shape, elem_type);
return builder->create<TF::TransposeOp>(location, result_type,
value_to_transpose, perm_op);
}
Value* SliceRankedTensor(OpBuilder* builder, Value* input,
ArrayRef<int64_t> begin_shape,
ArrayRef<int64_t> begin_values,
ArrayRef<int64_t> size_shape,
ArrayRef<int64_t> size_values,
mlir::Location location) {
// Create a dense constant op for slice's begin
auto slice_i2c_begin =
CreateI64DenseConst(builder, begin_shape, begin_values, location);
// Create a dense constant op for slice's size
auto slice_i2c_size =
CreateI64DenseConst(builder, size_shape, size_values, location);
return builder->create<TF::SliceOp>(
location,
builder->getTensorType(
size_values,
input->getType().cast<RankedTensorType>().getElementType()),
input, slice_i2c_begin, slice_i2c_size);
}
} // namespace
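// The begin rows used by the SetWeightFor*Gate and SetBias* methods below
// follow a fixed block layout: [cell | input | forget | output] blocks of
// n_cell rows each, where coupled input/forget gates (CIFG) drop the input
// block and shift the later blocks down by n_cell. A minimal sketch of that
// offset arithmetic (names here are illustrative):
struct GateRowOffsets {
  int cell;
  int input;  // unused when CIFG couples the input and forget gates
  int forget;
  int output;
};

static GateRowOffsets ComputeGateRowOffsets(int n_cell,
                                            bool couple_input_forget) {
  if (couple_input_forget)
    return {/*cell=*/0, /*input=*/0, /*forget=*/n_cell, /*output=*/2 * n_cell};
  return {/*cell=*/0, /*input=*/n_cell, /*forget=*/2 * n_cell,
          /*output=*/3 * n_cell};
}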
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
input2cell_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
input2input_ = couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, weight_transposed_,
weight_slice_shape_, begin_i2i_values,
weight_slice_shape_,
weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
input2forget_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
int input_output_start =
couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
input2output_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
rec2cell_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
rec2input_ = couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, weight_transposed_,
weight_slice_shape_, begin_rec2i_values,
weight_slice_shape_,
weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
rec2forget_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
rec2output_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
SmallVector<int64_t, 1> begin_bias2c_values = {0};
bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2c_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
bias2input_ =
couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2i_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2f_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
int bias_output_start =
couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2o_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
SmallVector<int64_t, 2> projection_slice_shape = {
1, num_cols_projection_transposed_};
SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
proj_weight_ =
!projection_
? none_
: SliceRankedTensor(
&builder_, projection_transposed_, projection_slice_shape,
projection_slice_begin_values, projection_slice_shape,
projection_slice_size_values, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
proj_bias_ = !projection_type_
? none_
: CreateF32SplatConst(&builder_, {n_output_}, 0,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
input_cell_state_ =
CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
cell_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
input_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
forget_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
output_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
// Transpose both weight and projection.
weight_transposed_ =
Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
fused_func_op_.getLoc());
none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
// Extract input to cifg gates via slicing the weight tensor
SetWeightForInputToCellGate();
SetWeightForInputToInputGate();
SetWeightForInputToForgetGate();
SetWeightForInputToOutputGate();
// Extract recurrent to cifg gates via slicing the weight tensor
SetWeightForRecurrentToCellGate();
SetWeightForRecurrentToInputGate();
SetWeightForRecurrentToForgetGate();
SetWeightForRecurrentToOutputGate();
// Extract bias to cifg gates via slicing the bias tensor
SetBiasToCellGate();
SetBiasToInputGate();
SetBiasToForgetGate();
SetBiasToOutputGate();
// Extract projection and set an empty projection bias
SetProjection();
SetProjectionBias();
// Set the variable tensors
SetInputActivationState();
SetInputCellState();
// Extract the layer norm coefficients
SetCellLayerNormCoefficients();
SetInputLayerNormCoefficients();
SetForgetLayerNormCoefficients();
SetOutputLayerNormCoefficients();
}
void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
// https://github.com/tensorflow/community/pull/113
  auto attr = fused_func_op_.getAttrOfType<StringAttr>("tf._implements");
if (!attr) {
fused_func_op_.setAttr("tf._implements",
builder_.getStringAttr(GetCompositeOpName()));
}
SmallVector<int64_t, 2> output_shape{1, n_output_};
auto input_types = fused_func_op_.getType().getInputs();
auto output_type = builder_.getTensorType(
output_shape,
input_->getType().cast<RankedTensorType>().getElementType());
fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
fused_func_op_.getContext()));
}
void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
// Update the func signature, based on output shape.
// The func will ultimately return the output of the fused
// LSTM op.
UpdateFuncSignature();
  // Transform the weights, projection, bias and layer norm coefficients
// to generate operands for the TFL fused LSTM op.
GenerateFusedOpOperands();
// Create the fused LSTM op.
SmallVector<int64_t, 2> output_shape = {1, n_output_};
auto result_type = builder_.getTensorType(
output_shape,
input_->getType().cast<RankedTensorType>().getElementType());
lstm_ = builder_.create<mlir::TFL::LSTMOp>(
fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
rec2output_, /*cell_to_input_weights*/ none_,
/*cell_to_forget_weights*/ none_,
/*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
bias2output_, proj_weight_, proj_bias_, input_activation_state_,
input_cell_state_, input_layer_norm_coefficients_,
forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
builder_.getStringAttr("FULL"));
builder_.create<mlir::ReturnOp>(fused_func_op_.getLoc(), lstm_.getResult());
}
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
num_gates_ = couple_input_forget_gates_ ? 3 : 4;
input_ = fused_func_op_.getArgument(0);
bias_ = fused_func_op_.getArgument(2);
weight_ = fused_func_op_.getArgument(1);
weight_type_ = weight_->getType().cast<RankedTensorType>();
if (weight_type_.getRank() != 2) {
return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
}
if (weight_type_.getDimSize(1) % num_gates_ != 0) {
return fused_func_op_.emitError()
<< "Invalid dimension 1 of weight tensor, "
"should be divisible by the number of gates";
}
n_cell_ = weight_type_.getDimSize(1) / num_gates_;
projection_ = fused_func_op_.getArgument(3);
projection_type_ = projection_->getType().cast<RankedTensorType>();
if (projection_type_.getRank() != 2) {
n_output_ = n_cell_;
} else {
n_output_ = projection_type_.getDimSize(1);
}
n_input_ = weight_type_.getDimSize(0) - n_output_;
num_cols_weight_transposed_ = weight_type_.getDimSize(0);
num_cols_projection_transposed_ = projection_type_.getDimSize(0);
bias_slice_shape_ = {n_cell_};
bias_size_values_ = {n_cell_};
weight_slice_shape_ = {1, num_cols_weight_transposed_};
weight_slice_size_input_values_ = {n_cell_, n_input_};
weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
return success();
}
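// Worked example of the shape arithmetic above: with 4 gates, a weight of
// shape [25, 40], and a rank-2 projection of shape [10, 5]:
//   n_cell = 40 / 4 = 10, n_output = 5, n_input = 25 - 5 = 20.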
LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
return fused_func_op_.emitError()
<< "Specified LayerNormalizedLSTMCellSimple was not of the expected "
"interface and cannot not be converted to the fused LSTM op";
}
layer_norm_scale_ = fused_func_op_.getArgument(4);
layer_norm_scale_type_ =
layer_norm_scale_->getType().cast<RankedTensorType>();
if (layer_norm_scale_type_.getRank() != 1) {
return fused_func_op_.emitError()
<< "The layer_norm_scale tensor was not of rank 1";
}
layer_norm_slice_shape_ = {n_cell_};
layer_norm_size_values_ = {n_cell_};
return success();
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
cell_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_cell_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
input_layer_norm_coefficients_ =
couple_input_forget_gates_
? none_
: SliceRankedTensor(
&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_input_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
forget_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_forget_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
output_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_output_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,214 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common utils used by TFLite transformation
// passes to construct fused TFL LSTM ops.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
constexpr char kLstmCellSimple[] = "LSTMCellSimple";
constexpr char kLayerNormalizedLstmCellSimple[] =
"LayerNormalizedLstmCellSimple";
// A utility class that enables the conversion of the LSTMCellSimple composite
// op into a fused TFL LSTM op. The fused op is contained within a FuncOp
// that also contains other supporting ops needed to construct the operands for
// the fused op. The caller provides the containing FuncOp as input with
// arguments specifying the input, weight, projection and bias.
// The weight, projection, bias and layer norm scale all need to be
// RankedTensorType.
// This class sets the layer norm coefficients to NoneType.
class ConvertLSTMCellSimpleToFusedLSTM {
public:
// TODO(b/140053256): The couple_input_forget_gates should be specified on
// FuncOp as an attribute.
explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op,
bool couple_input_forget_gates)
: fused_func_op_(fused_func_op),
couple_input_forget_gates_(couple_input_forget_gates),
builder_(fused_func_op.getBody()) {}
// Not copyable.
ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
delete;
ConvertLSTMCellSimpleToFusedLSTM& operator=(
const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}
// Verify the input FuncOp arguments and initialize internal state.
virtual LogicalResult Initialize();
virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }
// Rewrite the func body with the constructed fused LSTM op.
void RewriteFunc();
protected:
void UpdateFuncSignature();
void GenerateFusedOpOperands();
void SetWeightForInputToCellGate();
void SetWeightForInputToInputGate();
void SetWeightForInputToForgetGate();
void SetWeightForInputToOutputGate();
void SetWeightForRecurrentToCellGate();
void SetWeightForRecurrentToInputGate();
void SetWeightForRecurrentToForgetGate();
void SetWeightForRecurrentToOutputGate();
void SetBiasToCellGate();
void SetBiasToInputGate();
void SetBiasToForgetGate();
void SetBiasToOutputGate();
void SetProjection();
void SetProjectionBias();
void SetInputActivationState();
void SetInputCellState();
virtual void SetCellLayerNormCoefficients();
virtual void SetInputLayerNormCoefficients();
virtual void SetForgetLayerNormCoefficients();
virtual void SetOutputLayerNormCoefficients();
// specified state
FuncOp fused_func_op_;
Value* input_;
Value* weight_;
Value* bias_;
Value* projection_;
bool couple_input_forget_gates_;
// internal state
Value* weight_transposed_;
Value* projection_transposed_;
RankedTensorType weight_type_;
RankedTensorType projection_type_;
int num_gates_;
int n_cell_;
int n_output_;
int n_input_;
int num_cols_weight_transposed_;
int num_cols_projection_transposed_;
// input -> cifg
Value* input2input_;
Value* input2forget_;
Value* input2cell_;
Value* input2output_;
// recurrent -> cifg
Value* rec2input_;
Value* rec2forget_;
Value* rec2cell_;
Value* rec2output_;
// bias -> cifg
Value* bias2input_;
Value* bias2forget_;
Value* bias2cell_;
Value* bias2output_;
// projection
Value* proj_weight_;
Value* proj_bias_;
// state
Value* input_activation_state_;
Value* input_cell_state_;
// layer norm coefficients
Value* input_layer_norm_coefficients_;
Value* forget_layer_norm_coefficients_;
Value* cell_layer_norm_coefficients_;
Value* output_layer_norm_coefficients_;
mlir::TFL::LSTMOp lstm_;
Value* none_;
SmallVector<int64_t, 1> bias_slice_shape_;
SmallVector<int64_t, 1> bias_size_values_;
SmallVector<int64_t, 2> weight_slice_shape_;
SmallVector<int64_t, 2> weight_slice_size_input_values_;
SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
OpBuilder builder_;
};
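// Example usage (a minimal sketch; `func` is assumed to be a FuncOp that
// follows the input/weight/bias/projection argument convention documented
// above, e.g. one tagged with tf._implements = "LSTMCellSimple"):
//
//   ConvertLSTMCellSimpleToFusedLSTM converter(
//       func, /*couple_input_forget_gates=*/false);
//   if (succeeded(converter.Initialize())) converter.RewriteFunc();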
// A utility class that enables the conversion of the
// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
// fused op is contained within a FuncOp that also contains other supporting ops
// needed to construct the operands for the fused op. The caller provides the
// containing FuncOp as input with arguments specifying the input, weight,
// projection, bias and layer norm scale. The weight, projection, bias and
// layer norm scale all need to be RankedTensorType.
// This class overrides the layer norm coefficient setters from the base class.
class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
: public ConvertLSTMCellSimpleToFusedLSTM {
public:
// TODO(b/140053256): The couple_input_forget_gates should be specified on
// FuncOp as an attribute.
explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
mlir::FuncOp fused_func_op, bool couple_input_forget_gates)
: ConvertLSTMCellSimpleToFusedLSTM(fused_func_op,
couple_input_forget_gates) {}
// Not copyable.
ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}
llvm::StringRef GetCompositeOpName() override {
return kLayerNormalizedLstmCellSimple;
}
LogicalResult Initialize() override;
protected:
void SetCellLayerNormCoefficients() override;
void SetInputLayerNormCoefficients() override;
void SetForgetLayerNormCoefficients() override;
void SetOutputLayerNormCoefficients() override;
private:
// specified state
Value* layer_norm_scale_;
// internal state
RankedTensorType layer_norm_scale_type_;
SmallVector<int64_t, 1> layer_norm_slice_shape_;
SmallVector<int64_t, 1> layer_norm_size_values_;
};
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_

View File

@ -0,0 +1,222 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include <memory>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/platform/test.h"
namespace mlir {
namespace TFL {
FuncOp createFusedFunc(mlir::Builder* builder) {
SmallVector<int64_t, 2> input_shape{1, 2};
SmallVector<int64_t, 2> weight_shape{3, 12};
SmallVector<int64_t, 1> bias_shape{2};
SmallVector<int64_t, 2> projection_shape{1, 2};
SmallVector<int64_t, 1> layer_norm_scale{4};
SmallVector<int64_t, 2> output_shape{1, 2};
auto input_type = builder->getTensorType(input_shape, builder->getF32Type());
auto weight_type =
builder->getTensorType(weight_shape, builder->getF32Type());
auto bias_type = builder->getTensorType(bias_shape, builder->getF32Type());
auto projection_type =
builder->getTensorType(projection_shape, builder->getF32Type());
auto layer_norm_scale_type =
builder->getTensorType(layer_norm_scale, builder->getF32Type());
auto output_type =
builder->getTensorType(output_shape, builder->getF32Type());
SmallVector<mlir::Type, 4> input_types{input_type, weight_type, bias_type,
projection_type,
layer_norm_scale_type};
auto func_type = builder->getFunctionType(input_types, output_type);
auto func =
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
builder->getContext()),
"fused_func", func_type, {});
func.addEntryBlock();
return func;
}
// TODO(ashwinm): Revisit whether this test should be moved to a
// FileCheck-based test pass once the pass that consumes lstm_utils to stack
// the layers is in place.
class LstmUtilsTest : public ::testing::Test {
protected:
LstmUtilsTest() {}
void SetUp() override {
builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
fused_lstm_func_ = createFusedFunc(builder_.get());
}
void TearDown() override {
fused_lstm_func_.erase();
builder_.reset();
}
FuncOp fused_lstm_func_;
mlir::MLIRContext context_;
std::unique_ptr<mlir::Builder> builder_;
};
TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, false);
auto result = convert.Initialize();
EXPECT_FALSE(failed(result));
convert.RewriteFunc();
fused_lstm_func_.dump();
// verify transpose
EXPECT_EQ(
fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
convert.GetCompositeOpName());
EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto transpose_op = fused_lstm_func_.getBody().front().begin();
transpose_op++;
EXPECT_EQ(transpose_op->getOperand(0)
->getType()
.cast<RankedTensorType>()
.getDimSize(0),
3);
EXPECT_EQ(transpose_op->getOperand(0)
->getType()
.cast<RankedTensorType>()
.getDimSize(1),
12);
EXPECT_EQ(
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
0),
12);
EXPECT_EQ(
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
1),
3);
auto return_op = fused_lstm_func_.getBody().back().rbegin();
EXPECT_EQ(return_op->getName().getStringRef(),
mlir::ReturnOp::getOperationName());
return_op++;
EXPECT_EQ(return_op->getName().getStringRef(),
mlir::TFL::LSTMOp::getOperationName());
EXPECT_EQ(return_op->getNumOperands(), 24);
EXPECT_EQ(return_op->getNumResults(), 1);
// cifg = false, so input2input is not None.
EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
// input layer norm is None
EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());
// proj_bias is F32
EXPECT_TRUE(return_op->getOperand(17)
->getType()
.cast<RankedTensorType>()
.getElementType()
.isF32());
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();
SmallVector<int64_t, 2> output_shape{1, 2};
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
output_shape.size());
for (int i = 0; i < output_shape.size(); i++) {
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
output_shape[i]);
}
}
TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, true);
auto result = convert.Initialize();
EXPECT_FALSE(failed(result));
convert.RewriteFunc();
fused_lstm_func_.dump();
auto it = fused_lstm_func_.getBody().back().rbegin();
EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
it++;
EXPECT_EQ(it->getName().getStringRef(),
mlir::TFL::LSTMOp::getOperationName());
EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1);
// cifg = true, so input2input is None.
EXPECT_TRUE(it->getOperand(1)->getType().isa<NoneType>());
}
TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
mlir::TFL::ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM convert(
fused_lstm_func_, false);
auto result = convert.Initialize();
EXPECT_FALSE(failed(result));
convert.RewriteFunc();
fused_lstm_func_.dump();
EXPECT_EQ(
fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
convert.GetCompositeOpName());
EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto it = fused_lstm_func_.getBody().back().rbegin();
EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
it++;
EXPECT_EQ(it->getName().getStringRef(),
mlir::TFL::LSTMOp::getOperationName());
EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1);
// cifg = false, so input2input is not None.
EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>());
// input layer norm
EXPECT_FALSE(it->getOperand(20)->getType().isa<NoneType>());
EXPECT_EQ(
it->getOperand(20)->getType().cast<RankedTensorType>().getShape().size(),
1);
EXPECT_EQ(
it->getOperand(20)->getType().cast<RankedTensorType>().getDimSize(0), 3);
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();
SmallVector<int64_t, 2> output_shape{1, 2};
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
output_shape.size());
for (int i = 0; i < output_shape.size(); i++) {
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
output_shape[i]);
}
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,11 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(
["mlir.i"],
visibility = [
"//tensorflow/python:__subpackages__",
],
)

View File

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
namespace tensorflow {
namespace swig {
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbtxt, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// This is an early experimental API; ideally we should return a wrapper
// object around a Python binding to the MLIR module.
string ImportGraphDef(const string &proto, TF_Status* status) {
GraphDef graphdef;
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
GraphDebugInfo debug_info;
NodeSpecs specs;
mlir::MLIRContext context;
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";
}
return MlirModuleToString(*module.ConsumeValueOrDie());
}
} // namespace swig
} // namespace tensorflow
%}
%ignoreall
%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
// Wrap this function
namespace tensorflow {
namespace swig {
static string ImportGraphDef(const string &graphdef, TF_Status* status);
} // namespace swig
} // namespace tensorflow
%insert("python") %{
def import_graphdef(graphdef):
  return str(ImportGraphDef(str(graphdef).encode('utf-8')))
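# Example usage (a hedged sketch; `graph_def` is assumed to be a GraphDef
# proto, or its text-format string, as accepted by LoadProtoFromBuffer above):
#   mlir_text = import_graphdef(graph_def)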
%}
%unignoreall

View File

@ -1,5 +1,5 @@
load("@local_config_mlir//:tblgen.bzl", "gentbl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary")
package(
default_visibility = [":friends"],
@ -123,21 +123,6 @@ cc_library(
"ir/tf_ops.cc.inc",
"ir/tf_ops.h.inc",
"ir/tf_types.cc",
"transforms/bridge.cc",
"transforms/bridge_pass.cc",
"transforms/cluster_formation.cc",
"transforms/cluster_outlining.cc",
"transforms/executor_island_coarsening.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/graph_pruning.cc",
"transforms/optimize.cc",
"transforms/raise_control_flow.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_rewrite_pass.cc",
"translate/control_to_executor_dialect.cc",
"translate/executor_to_control_dialect.cc",
],
hdrs = [
"ir/control_flow_ops.h",
@ -153,11 +138,11 @@ cc_library(
includes = ["include"],
deps = [
":error_util",
":mlir_passthrough_op",
":tensorflow_canonicalize_inc_gen",
":tensorflow_device_ops_inc_gen",
":tensorflow_executor_inc_gen",
":tensorflow_ops_inc_gen",
":tensorflow_optimize_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/core:lib",
"@llvm//:support",
@ -175,12 +160,74 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "tensorflow_passes",
srcs = [
"transforms/bridge.cc",
"transforms/bridge_pass.cc",
"transforms/cluster_formation.cc",
"transforms/cluster_outlining.cc",
"transforms/executor_island_coarsening.cc",
"transforms/fold_switch.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/graph_pruning.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc",
"transforms/raise_control_flow.cc",
"transforms/sink_constant.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_rewrite_pass.cc",
"translate/control_to_executor_dialect.cc",
"translate/executor_to_control_dialect.cc",
],
hdrs = [
"transforms/bridge.h",
"transforms/passes.h",
],
includes = ["include"],
deps = [
":error_util",
":tensorflow",
":tensorflow_optimize_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:TransformUtils",
"@local_config_mlir//:Transforms",
],
# TODO(jpienaar): Merge in the dialect registration.
alwayslink = 1,
)
cc_library(
name = "tensorflow_test_passes",
srcs = [
"transforms/lower_tf_pass.cc",
],
deps = [
":lower_tf_lib",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
],
alwayslink = 1,
)
# Library with TensorFlow dialect static initialization.
cc_library(
name = "tensorflow_dialect_registration",
srcs = ["ir/dialect_registration.cc"],
deps = [
":tensorflow",
":tensorflow_passes",
"@local_config_mlir//:IR",
],
alwayslink = 1,
@ -204,6 +251,7 @@ cc_library(
":mangling_util",
":mlir_roundtrip_flags",
":tensorflow",
":tensorflow_passes",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/compiler/jit:shape_inference_helpers",
"//tensorflow/compiler/xla:status_macros",
@ -252,6 +300,8 @@ cc_library(
"utils/export_utils.cc",
],
hdrs = [
"ir/tf_types.def",
"ir/tf_types.h",
"utils/export_utils.h",
],
deps = [
@ -639,26 +689,73 @@ gentbl(
)
cc_library(
name = "tensorflow_fold_switch",
srcs = [
"transforms/fold_switch.cc",
],
hdrs = [
"transforms/passes.h",
],
copts = ["-std=c++14"],
name = "compile_mlir_util",
srcs = ["utils/compile_mlir_util.cc"],
hdrs = ["utils/compile_mlir_util.h"],
deps = [
":tensorflow",
":convert_type",
":error_util",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Parser",
],
)
tf_cc_test(
name = "compile_mlir_util_test",
size = "small",
srcs = ["utils/compile_mlir_util_test.cc"],
deps = [
":compile_mlir_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
],
)
cc_library(
name = "mlir_passthrough_op",
srcs = ["ops/mlir_passthrough_op.cc"],
deps = [
"//tensorflow/core:framework",
],
alwayslink = 1,
)
tf_gen_op_wrapper_py(
name = "gen_mlir_passthrough_op_py",
out = "gen_mlir_passthrough_op.py",
deps = [":mlir_passthrough_op"],
)
# Library to get rewrite patterns lowering within TensorFlow.
#
# This is a separate library so that external passes can link only this library
# without linking any of the other tensorflow passes.
cc_library(
name = "lower_tf_lib",
srcs = [
"transforms/lower_tf.cc",
],
hdrs = [
"transforms/lower_tf.h",
],
deps = [
":tensorflow",
"@local_config_mlir//:IR",
],
alwayslink = 1,
)

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -48,31 +49,65 @@ namespace {
// If the given tensor has elements of type variant, then returns a new type
// after dropping subtypes info. Otherwise, returns the original type as is.
Type DropVariantSubTypes(Type ty) {
ShapedType shaped_ty = ty.cast<ShapedType>();
Type element_ty = shaped_ty.getElementType();
ShapedType DropVariantSubTypes(ShapedType ty) {
Type element_ty = ty.getElementType();
if (!element_ty.isa<TF::VariantType>()) return ty;
Type variant_ty = TF::VariantType::get(ty.getContext());
if (shaped_ty.hasRank()) {
return RankedTensorType::get(shaped_ty.getShape(), variant_ty);
if (ty.hasRank()) {
return RankedTensorType::get(ty.getShape(), variant_ty);
}
return UnrankedTensorType::get(variant_ty);
}
// If the given tensor has elements of type ref, then returns a new type
// of the shape, but corresponding non-ref type as element type. Otherwise,
// returns the original type as is.
ShapedType DropRefType(ShapedType type) {
Type element_ty = type.getElementType();
TF::TensorFlowRefType ref_type = element_ty.dyn_cast<TF::TensorFlowRefType>();
if (!ref_type) return type;
if (type.hasRank()) {
return RankedTensorType::get(type.getShape(), ref_type.RemoveRef());
}
return UnrankedTensorType::get(ref_type.RemoveRef());
}
} // namespace
//===----------------------------------------------------------------------===//
// TF Executor Dialect
//===----------------------------------------------------------------------===//
namespace {
struct TensorFlowExecutorOpFolderDialectInterface
: public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
// Registered hook to check if the given region, which is attached to an
// operation that is *not* isolated from above (i.e. whose internal regions
// may reference values defined in an enclosing region), should be used when
// materializing constants.
// In the executor dialect we materialize inside an island.
bool shouldMaterializeInto(Region *region) const final {
return isa<tf_executor::IslandOp>(region->getParentOp());
}
};
} // namespace
TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
: Dialect(/*name=*/"tf_executor", context) {
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
>();
addInterfaces<TensorFlowExecutorOpFolderDialectInterface>();
addTypes<ControlType, TokenType>();
}
@ -296,6 +331,23 @@ void Print(IslandOp op, OpAsmPrinter *p) {
p->printOperands(op.getOperands());
*p << ')';
}
// Check if we can print the short "wraps" form: that is if the island
// contains a single operation and the result of this operation are perfectly
// forwarded to the yield.
if (op.getAttrs().empty() &&
std::next(op.GetBody().begin(), 2) == op.GetBody().end()) {
Operation &wrapped_op = op.GetBody().front();
Operation &yield_op = op.GetBody().back();
if (wrapped_op.getNumResults() == yield_op.getNumOperands() &&
std::equal(wrapped_op.getResults().begin(),
wrapped_op.getResults().end(),
yield_op.getOperands().begin())) {
*p << " wraps ";
p->printGenericOp(&op.GetBody().front());
return;
}
}
p->printRegion(op.getOperation()->getRegion(0));
p->printOptionalAttrDict(op.getAttrs());
}
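// For reference, the short "wraps" form printed above looks roughly like
// (schematic):
//   %0:2 = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// where the single wrapped op is printed in generic form and its results are
// forwarded through the implicit yield.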
@ -316,17 +368,22 @@ ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) {
// Parse the body region.
Region &body = *result->addRegion();
// TODO(b/134773778): the custom parser is missing support to implement to
// short syntax right now.
// if (!parser->parseOptionalKeyword("wraps")) {
// body.push_back(new Block);
// Block &block = body.back();
// parser->getBuilder().setInsertionPointToEnd(&block);
// if (parser->parseOperation())
// return failure();
// }
if (parser->parseRegion(body, llvm::None, llvm::None)) return failure();
if (succeeded(parser->parseOptionalKeyword("wraps"))) {
// If we parse the short version of the island, we have an operation in the
// generic form that follows the "wraps" keyword. Parse it inside the region
// and forward all of its results as-is to the yield operation.
body.push_back(new Block);
Block &block = body.back();
Operation *wrapped_op =
parser->parseGenericOperation(&block, block.begin());
if (!wrapped_op) return failure();
OpBuilder builder(parser->getBuilder().getContext());
builder.setInsertionPointToEnd(&block);
builder.create<YieldOp>(result->location,
llvm::to_vector<8>(wrapped_op->getResults()));
} else if (parser->parseRegion(body, llvm::None, llvm::None)) {
return failure();
}
IslandOp::ensureTerminator(body, parser->getBuilder(), result->location);
@ -536,35 +593,43 @@ LogicalResult Verify(MergeOp merge) {
if (data_type.isa<ControlType>())
return merge.emitOpError() << "expects a non-control input";
// Check that all operands can be broadcasted to a common type compatible with
// the result type.
Type broadcasted_type = merge.output()->getType();
// Check that each operand can be individually broadcasted to the output type.
Type output_type = merge.output()->getType();
TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
if (!output_tensor_ty) {
return merge.emitOpError()
<< "expects output to have tensor type but got " << output_type;
}
bool is_output_ref =
output_tensor_ty.getElementType().isa<TF::TensorFlowRefType>();
for (Type operand_type : merge.getOperandTypes()) {
if (operand_type.isa<ControlType>()) break;
// TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
// constraint.
if (!operand_type.isa<TensorType>())
return merge.emitOpError("expects data operands to have tensor type");
// Variant types may have opaque subtypes information that need not match
// between the two types so drop them before computing the broadcasted type.
Type new_broadcasted_type =
OpTrait::util::getBroadcastedType(DropVariantSubTypes(broadcasted_type),
DropVariantSubTypes(operand_type));
if (!new_broadcasted_type)
TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
if (!operand_tensor_ty)
return merge.emitOpError()
<< "expects all operands to be broadcastable"
<< " but got " << broadcasted_type << " vs " << operand_type;
// Use the broadcasted type unless we're losing the rank information here.
// This is because for example starting with a result of tensor<4xf32>, if
// the first operand is unranked, the broadcasted type will be unranked.
// Then any tensor operand will be broadcastable to this unranked type.
if (!broadcasted_type.cast<TensorType>().hasRank() ||
new_broadcasted_type.cast<TensorType>().hasRank())
broadcasted_type = new_broadcasted_type;
}
<< "expects data operands to have tensor type but got "
<< operand_type;
// If output type is a ref type then all operand types should also be of the
// same ref type. However, if the output type is a non-ref type T, operands
// can be tensor of type T or T_REF.
if (is_output_ref &&
!operand_tensor_ty.getElementType().isa<TF::TensorFlowRefType>()) {
return merge.emitOpError()
<< "expects same operand and output element type but got "
<< operand_tensor_ty << " vs " << output_tensor_ty;
}
Type broadcasted_type = OpTrait::util::getBroadcastedType(
DropRefType(DropVariantSubTypes(output_tensor_ty)),
DropRefType(DropVariantSubTypes(operand_tensor_ty)));
if (!broadcasted_type)
return merge.emitOpError()
<< "expects all operands to be broadcastable with output type"
<< " but got " << operand_tensor_ty << " vs " << output_tensor_ty;
}
return success();
}
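// For example, under the rules above a merge whose output type is
// tensor<*xf32> may mix tensor<4xf32> and ref-typed tensor<4x!tf.f32ref>
// operands (schematic type syntax), while a ref-typed output requires every
// data operand to be ref-typed as well.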
@ -1088,6 +1153,35 @@ void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
DropEmptyIslandNoOperandOneDataResult>(context);
}
//===----------------------------------------------------------------------===//
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//
namespace {
// This pattern matches and removes ControlTriggerOps with no control operands.
// Control result users will have their relevant operands removed.
struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
using OpRewritePattern<ControlTriggerOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ControlTriggerOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 0) return matchFailure();
for (auto &use : llvm::make_early_inc_range(op.control()->getUses()))
use.getOwner()->eraseOperand(use.getOperandNumber());
rewriter.replaceOp(op, {nullptr});
return matchSuccess();
}
};
} // anonymous namespace
void ControlTriggerOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<DropEmptyControlTrigger>(context);
}
//===----------------------------------------------------------------------===//
// Folders
//===----------------------------------------------------------------------===//

View File

@ -594,6 +594,8 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger",
let verifier = ?;
let hasCanonicalizer = 1;
let builders = [OpBuilder<
"Builder *builder, OperationState *result, "
"ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",

View File

@ -88,13 +88,13 @@ Inputs must be of same size and shape.
}];
let arguments = (ins
Variadic<TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Variant]>>:$inputs,
Variadic<TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>>:$inputs,
Confined<I64Attr, [IntMinValue<1>]>:$N
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Variant]>:$sum
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -110,12 +110,12 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -123,6 +123,32 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
let hasCanonicalizer = 1;
}
def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
let summary = [{
Computes the "logical and" of elements across dimensions of a tensor.
}];
let description = [{
Reduces `input` along the dimensions given in `axis`. Unless
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
`axis`. If `keep_dims` is true, the reduced dimensions are
retained with length 1.
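For example (an illustrative sketch using the standard `tf.reduce_all` wrapper
for this op):

```python
x = tf.constant([[True, True], [False, False]])
tf.reduce_all(x)     # False
tf.reduce_all(x, 0)  # [False, False]
tf.reduce_all(x, 1)  # [True, False]
```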
}];
let arguments = (ins
I1Tensor:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
I1Tensor:$output
);
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> {
let summary = [{
Computes the "logical or" of elements across dimensions of a tensor.
@ -169,7 +195,7 @@ Usage:
}];
let arguments = (ins
TF_NumberTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$dimension
);
@ -202,7 +228,7 @@ Usage:
}];
let arguments = (ins
TF_NumberTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$dimension
);
@ -261,6 +287,88 @@ window in `value`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
let summary = "Multiplies slices of two tensors in batches.";
let description = [{
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.
The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
r_o = c_x if adj_x else r_x
c_o = r_y if adj_y else c_y
It is computed as:
output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
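For example, with `adj_x = adj_y = False`, an `x` of shape `[2, 3, 5]` and a
`y` of shape `[2, 5, 7]` produce an output of shape `[2, 3, 7]`.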
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> {
let summary = "Multiplies slices of two tensors in batches.";
let description = [{
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.
The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
r_o = c_x if adj_x else r_x
c_o = r_y if adj_y else c_y
It is computed as:
output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
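For example, an `x` of shape `[1, 2, 3]` and a `y` of shape `[4, 3, 5]`
broadcast in the batch dimension to produce an output of shape `[4, 2, 5]`.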
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
let summary = "BatchToSpace for N-D tensors of type T.";
@ -297,14 +405,14 @@ Broadcasting is supported, so `value` may have any number of dimensions.
}];
let arguments = (ins
TF_NumberTensor:$value,
TF_NumberTensor:$bias,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format
);
let results = (outs
TF_NumberTensor:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -332,21 +440,23 @@ gives module error.
For example,
Example 1:
```python
>>> a = [1., 2., 3.]
>>> equality_bitcast = tf.bitcast(a,tf.complex128)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast]
>>> equality_cast = tf.cast(a,tf.complex128)
>>> equality_bitcast = tf.bitcast(a, tf.complex128)
Traceback (most recent call last):
...
InvalidArgumentError: Cannot bitcast from 1 to 18 [Op:Bitcast]
>>> equality_cast = tf.cast(a, tf.complex128)
>>> print(equality_cast)
tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128)
```
Example 2:
```python
>>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8)
<tf.Tensor: ... shape=(4,), dtype=uint8, numpy=array([255, 255, 255, 255], dtype=uint8)>
```
Example 3:
```python
>>> x = [1., 2., 3.]
>>> y = [0., 2., 3.]
>>> equality= tf.equal(x,y)
@ -358,10 +468,9 @@ tf.Tensor([False True True], shape=(3,), dtype=bool)
tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32)
>>> print(equality_bitcast)
tf.Tensor(
[[ 0 0 0 0]
[ 0 0 128 63]
[ 0 0 128 63]], shape=(3, 4), dtype=uint8)
```
[[ 0 0 0 0]
[ 0 0 128 63]
[ 0 0 128 63]], shape=(3, 4), dtype=uint8)
*NOTE*: Bitcast is implemented as a low-level cast, so machines with different
endian orderings will give different results.
@ -393,14 +502,13 @@ and works its way forward.
For example,
```python
>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> sess.run(y)
array([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]], dtype=int32)
```
>>> print(y)
tf.Tensor(
[[1 2 3]
[1 2 3]
[1 2 3]], shape=(3, 3), dtype=int32)
In the above example, the input Tensor with shape `[3]`
is broadcast to an output Tensor with shape `[3, 3]`.
@ -462,6 +570,27 @@ def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [SameOperandsAndResultType]> {
let summary = "Checks a tensor for NaN and Inf values.";
let description = [{
When run, reports an `InvalidArgument` error if `tensor` has any values
that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
}];
let arguments = (ins
TF_FpTensor:$tensor,
StrAttr:$message
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
let summary = "Concatenates tensors along one dimension.";
@ -480,6 +609,10 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> {
@ -501,6 +634,10 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_ConjOp : TF_Op<"Conj", [NoSideEffect]> {
@ -771,12 +908,12 @@ def TF_DivOp : TF_Op<"Div", [Broadcastable, NoSideEffect]>,
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -805,8 +942,7 @@ See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_EqualOp : TF_Op<"Equal", [Broadcastable, Commutative, NoSideEffect]>,
WithBroadcastableCmpOpBuilder {
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
@ -825,8 +961,10 @@ tf.math.equal(x, y) ==> array([True, True])
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$x,
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$y
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
let results = (outs
@ -834,6 +972,15 @@ tf.math.equal(x, y) ==> array([True, True])
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let builders = [
OpBuilder<"Builder* builder, OperationState* result, Value* x, "
"Value* y, BoolAttr incompatible_shape_error">
];
let verifier = [{
return Verify(*this);
}];
}
def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> {
@ -1017,6 +1164,52 @@ values.
}];
}
def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> {
let summary = [{
Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
}];
let description = [{
`[b, d]`, or `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
to 'outputs' tensor of same shape as `inputs`.
`[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Before quantization, `min` and `max` values are adjusted with the following
logic.
It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
the behavior can be unexpected:
If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
This operation has a gradient and thus allows for training `min` and `max`
values.
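For example, with `min = 0.1` and `max = 1.1` (the `0 < min < max` case
above), the adjusted range becomes `min_adj = 0` and `max_adj = 1.0`.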
}];
let arguments = (ins
F32Tensor:$inputs,
F32Tensor:$min,
F32Tensor:$max,
DefaultValuedAttr<I64Attr, "8">:$num_bits,
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
);
let results = (outs
F32Tensor:$outputs
);
let verifier = [{
return Verify(*this);
}];
}
def TF_FillOp : TF_Op<"Fill", [NoSideEffect]> {
let summary = "Creates a tensor filled with a scalar value.";
@ -1082,12 +1275,12 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [Broadcastable, NoSideEffect]>,
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1780,14 +1973,14 @@ retained with length 1.
}];
let arguments = (ins
TF_NumberTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TF_NumberTensor:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1801,7 +1994,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
}];
let arguments = (ins
TF_IntOrFpTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$input,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksize,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
@ -1810,7 +2003,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
);
let results = (outs
TF_IntOrFpTensor:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1850,14 +2043,14 @@ retained with length 1.
}];
let arguments = (ins
TF_NumberTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TF_NumberTensor:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1931,6 +2124,57 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
def TF_MlirPassthroughOpOp : TF_Op<"MlirPassthroughOp", [NoSideEffect]> {
let summary = [{
Wraps an arbitrary MLIR computation expressed as a module with a main() function.
}];
let description = [{
This operation does not have an associated kernel and is not intended to be
executed in a regular TensorFlow session. Instead it is intended to be used for
testing or for special cases where a user intends to pass a custom MLIR
computation through a TensorFlow graph so that custom tooling can process it
downstream (when targeting a different environment, like TensorFlow Lite for
example).
The MLIR module is expected to have a main() function that will be used as an
entry point. The inputs to the operations will be passed as argument to the
main() function and the returned values of the main function mapped to the
outputs.
Example usage:
```
import tensorflow as tf
from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op
mlir_module = '''
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
%add = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
return %add : tensor<10x10xf32>
}
'''
@tf.function
def foo(x, y):
return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
```
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
StrAttr:$mlir_module
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x * y element-wise.";
@ -1941,12 +2185,12 @@ def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>,
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2006,8 +2250,101 @@ def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> {
let results = (outs);
}
def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]>,
WithBroadcastableCmpOpBuilder {
def TF_NonMaxSuppressionV4Op : TF_Op<"NonMaxSuppressionV4", [NoSideEffect]> {
let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
}];
let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translating or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
selected_indices = tf.image.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold, score_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
}];
let arguments = (ins
TensorOf<[F16, F32]>:$boxes,
TensorOf<[F16, F32]>:$scores,
I32Tensor:$max_output_size,
TensorOf<[F16, F32]>:$iou_threshold,
TensorOf<[F16, F32]>:$score_threshold,
DefaultValuedAttr<BoolAttr, "false">:$pad_to_max_output_size
);
let results = (outs
I32Tensor:$selected_indices,
I32Tensor:$valid_outputs
);
TF_DerivedOperandTypeAttr T_threshold = TF_DerivedOperandTypeAttr<3>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_NonMaxSuppressionV5Op : TF_Op<"NonMaxSuppressionV5", [NoSideEffect]> {
let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
}];
let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translating or reflecting the coordinate
system results in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
selected_indices = tf.image.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold, score_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f.
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$boxes,
TensorOf<[F16, F32]>:$scores,
I32Tensor:$max_output_size,
TensorOf<[F16, F32]>:$iou_threshold,
TensorOf<[F16, F32]>:$score_threshold,
TensorOf<[F16, F32]>:$soft_nms_sigma,
DefaultValuedAttr<BoolAttr, "false">:$pad_to_max_output_size
);
let results = (outs
I32Tensor:$selected_indices,
TensorOf<[F16, F32]>:$selected_scores,
I32Tensor:$valid_outputs
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
let summary = "Returns the truth value of (x != y) element-wise.";
let description = [{
@ -2016,8 +2353,10 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$x,
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$y
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
let results = (outs
@ -2025,6 +2364,15 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let builders = [
OpBuilder<"Builder* builder, OperationState* result, Value* x, "
"Value* y, BoolAttr incompatible_shape_error">
];
let verifier = [{
return Verify(*this);
}];
}
def TF_OneHotOp : TF_Op<"OneHot", [NoSideEffect]> {
@ -2121,7 +2469,7 @@ output =
}];
let arguments = (ins
TensorOf<[I32, I64, I8]>:$indices,
TensorOf<[I32, I64, TF_Uint8]>:$indices,
I32Tensor:$depth,
TF_Tensor:$on_value,
TF_Tensor:$off_value,
@ -2176,6 +2524,10 @@ This is the opposite of `unpack`.
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let verifier = [{
return Verify(*this);
}];
}
def TF_PadOp : TF_Op<"Pad", [NoSideEffect]> {
@ -2303,14 +2655,14 @@ retained with length 1.
}];
let arguments = (ins
TF_NumberTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TF_NumberTensor:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2406,7 +2758,8 @@ The above round function rounds the value based on the given round_mode.
DefaultValuedAttr<I64Attr, "8">:$num_bits,
DefaultValuedAttr<BoolAttr, "false">:$range_given,
DefaultValuedAttr<TF_AnyStrAttrOf<["HALF_TO_EVEN", "HALF_UP"]>, "HALF_TO_EVEN">:$round_mode,
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
DefaultValuedAttr<I64Attr, "-1">:$axis
);
let results = (outs
@ -2432,7 +2785,8 @@ tensor, so its value can change during training.
DefaultValuedAttr<BoolAttr, "true">:$signed_input,
DefaultValuedAttr<BoolAttr, "true">:$range_given,
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
DefaultValuedAttr<I64Attr, "-1">:$axis
);
let results = (outs
@ -2550,12 +2904,12 @@ If `x` and `y` are reals, this will return the floating-point division.
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2590,11 +2944,11 @@ def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
}];
let arguments = (ins
TF_IntOrFpTensor:$features
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$features
);
let results = (outs
TF_IntOrFpTensor:$activations
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$activations
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2709,7 +3063,7 @@ Input images can be of different types but output images are always float.
}];
let arguments = (ins
TF_IntOrFpTensor:$images,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images,
I32Tensor:$size,
DefaultValuedAttr<BoolAttr, "false">:$align_corners,
@ -2732,7 +3086,7 @@ Resize `images` to `size` using nearest neighbor interpolation.
}];
let arguments = (ins
TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$images,
TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images,
I32Tensor:$size,
DefaultValuedAttr<BoolAttr, "false">:$align_corners,
@ -2740,7 +3094,7 @@ Resize `images` to `size` using nearest neighbor interpolation.
);
let results = (outs
TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$resized_images
TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$resized_images
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2875,12 +3229,12 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11],
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$tensor,
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$tensor,
TF_I32OrI64Tensor:$axis
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$output
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -3031,6 +3385,10 @@ shape(t) ==> [2, 2, 3]
return Verify(*this);
}];
let builders = [
OpBuilder<"Builder* builder, OperationState* result, Value* input, BoolAttr use32Bit">
];
let hasFolder = 1;
}
@ -3643,12 +4001,12 @@ def TF_SubOp : TF_Op<"Sub", [Broadcastable, NoSideEffect]>,
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -3667,20 +4025,36 @@ retained with length 1.
}];
let arguments = (ins
TF_NumberTensor:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TF_NumberTensor:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
let summary = "Returns the result of a TPU compilation.";
let description = [{
This operation returns the result of a TPU compilation as a serialized
CompilationResultProto, which holds a status and an error message if an error
occurred during compilation.
}];
let arguments = (ins);
let results = (outs
TF_StrTensor:$output
);
}
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes hyperbolic tangent of `x` element-wise.";
@ -3750,6 +4124,23 @@ def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> {
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_TensorListLengthOp : TF_Op<"TensorListLength", [NoSideEffect]> {
let summary = "Returns the number of tensors in the input tensor list.";
let description = [{
input_handle: the input list
length: the number of tensors in the list
}];
let arguments = (ins
TF_VariantTensor:$input_handle
);
let results = (outs
I32Tensor:$length
);
}
def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> {
let summary = [{
Returns a list which has the passed-in `Tensor` as its last element and the other elements of the given list in `input_handle`.
@ -3921,12 +4312,12 @@ Python Semantics.
}];
let arguments = (ins
TF_NumberTensor:$x,
TF_NumberTensor:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TF_NumberTensor:$z
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -4068,7 +4459,7 @@ where(input) ==> [[0, 0, 0],
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input
);
let results = (outs

View File

@ -178,7 +178,8 @@ def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
def TF_NumberOrStrTensor : TensorOf<[TF_AnyNumber, TF_Str]>;
def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>;
def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>;
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions

View File

@ -19,13 +19,16 @@ limitations under the License.
#include <functional>
#include <numeric>
#include <string>
#include <type_traits>
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
@ -34,6 +37,7 @@ limitations under the License.
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
@ -71,10 +75,19 @@ static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) {
// Returns true if the given `value` has at least the specified rank or has
// unranked type.
static inline bool HasRankAtLeast(Value *value, int64_t rank) {
auto type = value->getType();
Type type = value->getType();
if (auto ranked_type = type.dyn_cast<RankedTensorType>())
return ranked_type.getRank() >= rank;
return type.isa<UnrankedTensorType>();
return true;
}
// Returns true if the given `value` has at most the specified rank or has
// unranked type.
static inline bool HasRankAtMost(Value *value, int64_t rank) {
Type type = value->getType();
if (auto ranked_type = type.dyn_cast<RankedTensorType>())
return ranked_type.getRank() <= rank;
return true;
}
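// Example: a value of type tensor<4xi32> has rank 1, so HasRankAtMost(value, 1)
// holds; unranked values such as tensor<*xi32> conservatively pass both the
// at-least and at-most checks.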
// Returns true if the given pair of TensorFlow types can be cast to one
@ -95,6 +108,85 @@ static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1;
}
// Returns the tf.Equal/tf.NotEqual result type given the `x` and `y` inputs.
// If `incompatible_shape_error` is true, reports an error if `x` and `y` have
// incompatible shapes. Otherwise, returns a tensor type with unknown rank.
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value *x,
Value *y, BoolAttr incompatible_shape_error) {
auto result_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
if (!result_type) {
if (incompatible_shape_error.getValue()) {
mlir::emitError(loc, "non-broadcastable operands");
} else {
result_type = builder->getTensorType(builder->getI1Type());
}
}
return result_type;
}
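// Example (hypothetical types): tensor<2x1xf32> and tensor<2x3xf32> broadcast
// to shape [2, 3], so a ranked result type of that shape is deduced. With
// incompatible_shape_error = false, non-broadcastable inputs such as
// tensor<2xf32> and tensor<3xf32> yield an unranked i1 tensor instead.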
// Verifies that the given types are cast compatible. If not, emits an
// appropriate error for the given op. If mask_one_dim is set to true, then the
// types are allowed to have one mismatching dimension. Masking one of the
// dimensions is useful for ops like Concat that require all ranked inputs to
// have the same rank and matching dimension sizes for all but one of the
// dimensions.
static LogicalResult VerifyTypesCompatibility(
Operation::operand_type_range types, bool mask_one_dim, Operation *op) {
constexpr int64_t kUninitialized = -1;
int64_t common_rank = kUninitialized;
llvm::SmallVector<int64_t, 4> common_dims;
int64_t dim_to_mask = kUninitialized;
// Initialize common_rank with rank of the first ranked type and verify that
// following ranked types have the same rank.
// Similarly, initialize each of the dimensions with the first type that has
// the dimension size available and verify that all following types have the
// same size for the dimension. However, if mask_one_dim is true, note down
// the dimension index on the first mismatch and ignore dimension at that
// index in following types.
for (Type ty : types) {
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) continue;
int64_t rank = ranked_ty.getRank();
if (common_rank == kUninitialized) {
common_rank = rank;
common_dims.resize(common_rank, kUninitialized);
} else if (common_rank != rank) {
return op->emitError()
<< "operand type " << ranked_ty
<< " is not compatible with preceding operands; expected rank: "
<< common_rank;
}
for (int64_t i = 0, e = common_rank; i != e; i++) {
if (i == dim_to_mask) continue;
int64_t dim = ranked_ty.getDimSize(i);
if (dim == kUninitialized) continue;
int64_t &common_dim = common_dims[i];
if (common_dim == kUninitialized) {
common_dim = dim;
} else if (common_dim != dim) {
// If mask_one_dim is true, do not emit an error if this is the only
// dimension with mismatches. Note down the dimension to mask it from
// the following types.
if (mask_one_dim && dim_to_mask == kUninitialized) {
dim_to_mask = i;
continue;
}
return op->emitError() << "operand type " << ranked_ty
<< " is not compatible with preceding operands; "
"expected dimension at index "
<< i << ": " << common_dim;
}
}
}
return success();
}
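// Example (hypothetical operand types): with mask_one_dim set to true, the
// types tensor<2x3xf32>, tensor<2x5xf32>, and tensor<2x?xf32> are compatible;
// the first mismatch (3 vs. 5 at dimension 1) masks that dimension for the
// remaining types. A tensor<4x3xf32> operand would still fail on dimension 0.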
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
@ -176,6 +268,36 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<CastSameType>(context);
}
//===----------------------------------------------------------------------===//
// ConcatOp and ConcatV2Op
//===----------------------------------------------------------------------===//
template <typename OpT, typename = typename std::enable_if_t<
llvm::is_one_of<OpT, ConcatOp, ConcatV2Op>::value>>
static LogicalResult Verify(OpT op) {
// TODO(hinsu): Convert variadic length attributes to derived attributes.
Operation::operand_range values = op.values();
auto num_values = std::distance(values.begin(), values.end());
int64_t attr_N = op.N().getLimitedValue();
if (num_values != attr_N) {
return op.emitOpError()
<< "requires attribute 'N' to match the number of inputs; expected: "
<< num_values << " Found: " << attr_N;
}
int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
Value *axis = *op.getODSOperands(axis_idx).begin();
if (!HasRankAtMost(axis, 1)) {
return op.emitOpError(
"requires axis to be of scalar type (or vector type for older "
"versions)");
}
return VerifyTypesCompatibility(values,
/*mask_one_dim=*/true, op.getOperation());
}
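// Example: "tf.ConcatV2"(%a, %b, %axis) {N = 2} carries the variadic values
// group (%a, %b) followed by the axis, so axis_idx is 1; for the older
// "tf.Concat", the axis is the leading ODS operand group, hence axis_idx 0.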
//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
@ -257,6 +379,26 @@ static LogicalResult Verify(EmptyTensorListOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// EqualOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(EqualOp op) {
// If we allow inputs to have incompatible types, then there is nothing to do.
if (!op.incompatible_shape_error()) return success();
// Otherwise, check inputs are broadcastable.
return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
op.getOperation());
}
void EqualOp::build(Builder *builder, OperationState *result, Value *x,
Value *y, BoolAttr incompatible_shape_error) {
auto result_type = DeduceEqualCmpOpType(builder, result->location, x, y,
incompatible_shape_error);
return build(builder, result, result_type, x, y, incompatible_shape_error);
}
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxArgsOp
//===----------------------------------------------------------------------===//
@ -276,12 +418,6 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
"," + Twine(std::to_string(rmax)) + "]");
}
// Range must straddle zero.
if (rmin > 0.0 || rmax < 0.0) {
return op.emitOpError("range failed to straddle zero: [" +
Twine(std::to_string(rmin)) + "," +
Twine(std::to_string(rmax)) + "]");
}
int64_t num_bits = op.num_bits().getSExtValue();
if (num_bits < 2 || num_bits > 16) {
return op.emitOpError(
@ -308,6 +444,37 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxVarsPerChannelOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
if (!isOfRankedFloatTensorType(op.min(), 1))
return op.emitOpError("requires min to be a 1d float tensor");
if (!isOfRankedFloatTensorType(op.max(), 1))
return op.emitOpError("requires max to be a 1d float tensor");
Value *inputs = op.inputs();
if (!HasRankAtLeast(inputs, 1) ||
inputs->getType().isa<UnrankedTensorType>()) {
return op.emitError("requires inputs to be at least 1d float tensor");
}
auto inputsType = inputs->getType().cast<ShapedType>();
int depth = inputsType.getDimSize(inputsType.getRank() - 1);
if (op.min()->getType().cast<ShapedType>().getDimSize(0) != depth ||
op.max()->getType().cast<ShapedType>().getDimSize(0) != depth) {
return op.emitOpError(
"requires min and max to have same size as last dimension of inputs");
}
int64_t num_bits = op.num_bits().getSExtValue();
if (num_bits < 2 || num_bits > 16) {
return op.emitOpError(
"requires num_bits to be between 2 and 16, inclusive");
}
return success();
}
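// Example: inputs of type tensor<1x2x3x8xf32> require min and max to be of
// type tensor<8xf32>, i.e. one quantization range per channel along the last
// dimension; a tensor<1xf32> min would fail the size check above.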
//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//
@ -471,6 +638,74 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<NegNested>(context);
}
//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(NotEqualOp op) {
// If we allow inputs to have incompatible types, then there is nothing to do.
if (!op.incompatible_shape_error()) return success();
// Otherwise, check inputs are broadcastable.
return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
op.getOperation());
}
void NotEqualOp::build(Builder *builder, OperationState *result, Value *x,
Value *y, BoolAttr incompatible_shape_error) {
auto result_type = DeduceEqualCmpOpType(builder, result->location, x, y,
incompatible_shape_error);
return build(builder, result, result_type, x, y, incompatible_shape_error);
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PackOp op) {
// TODO(hinsu): Convert variadic length attributes to derived attributes.
Operation::operand_range values = op.values();
auto num_values = std::distance(values.begin(), values.end());
int64_t attr_N = op.N().getLimitedValue();
if (num_values != attr_N) {
return op.emitOpError()
<< "requires attribute 'N' to match the number of inputs; expected: "
<< num_values << " Found: " << attr_N;
}
if (failed(VerifyTypesCompatibility(values,
/*mask_one_dim=*/false,
op.getOperation()))) {
return failure();
}
int64_t inputs_rank = -1;
for (Value *value : values) {
if (auto ty = value->getType().dyn_cast<RankedTensorType>()) {
// Exit early as input types are verified to be compatible so all ranked
// tensors have the same rank.
inputs_rank = ty.getRank();
break;
}
}
if (inputs_rank == -1) return success();
// The values can be packed along any of the dimensions between 0 and the
// inputs' rank, inclusive. Also, negative axis values wrap around, so the
// valid axis range is [-(R+1), R+1).
int64_t range_begin = -inputs_rank - 1; // Inclusive
int64_t range_end = inputs_rank + 1; // Exclusive
int64_t axis = op.axis().getLimitedValue();
if (axis < range_begin || axis >= range_end) {
return op.emitError() << "attribute 'axis' should be within range ["
<< range_begin << ", " << range_end
<< "); actual value: " << axis;
}
return success();
}
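// Example: packing two tensor<3x5xf32> values gives rank R = 2 inputs, so the
// valid axis range is [-3, 3); "tf.Pack" with axis = 0 and N = 2 then produces
// a tensor<2x3x5xf32>.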
//===----------------------------------------------------------------------===//
// ReciprocalOp
//===----------------------------------------------------------------------===//
@ -731,6 +966,16 @@ OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
return b.getDenseElementsAttr(resultType, dimensions);
}
void ShapeOp::build(Builder *builder, OperationState *result, Value *input,
BoolAttr use32Bit) {
auto rankedTensorType = input->getType().dyn_cast<RankedTensorType>();
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
auto out_type = use32Bit.getValue() ? builder->getIntegerType(32)
: builder->getIntegerType(64);
return ShapeOp::build(builder, result,
builder->getTensorType({rank}, out_type), input);
}
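// Example: for an input of type tensor<4x?x8xf32> with use32Bit = true, the
// deduced result type is tensor<3xi32>; an unranked input maps rank -1 to a
// dynamic dimension, giving tensor<?xi64> when use32Bit = false.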
//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("MlirPassthroughOp")
.Attr("mlir_module: string")
.Attr("Tinputs : list(type) >= 0")
.Input("inputs: Tinputs")
.Attr("Toutputs : list(type) >= 0")
.Output("outputs: Toutputs")
.Doc(R"doc(
Wraps an arbitrary MLIR computation expressed as a module with a main() function.
This operation does not have an associated kernel and is not intended to be
executed in a regular TensorFlow session. Instead, it is intended for testing
or for special cases where a user wants to pass a custom MLIR computation
through a TensorFlow graph so that custom tooling can process it downstream
(for example, when targeting a different environment such as TensorFlow Lite).
The MLIR module is expected to have a main() function that will be used as the
entry point. The inputs to the operation are passed as arguments to the
main() function, and the values returned by the main() function are mapped to
the outputs.
Example usage:
```
import tensorflow as tf
from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op
mlir_module = '''
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
%add = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
return %ret : tensor<10x10xf32>
}
'''
@tf.function
def foo(x, y):
return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
```
)doc");
} // namespace tensorflow

View File

@ -236,8 +236,8 @@ func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) ->
%1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1>
return %1: tensor<8x16xi1>
// CHECK: %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0
// CHECK: %[[NE:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true}
// CHECK: return %[[NE]]
}
// CHECK-LABEL: testLogicalNotOfNotEqual
@ -246,8 +246,8 @@ func @testLogicalNotOfNotEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>)
%1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1>
return %1: tensor<8x16xi1>
// CHECK: %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0
// CHECK: %[[NE:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true}
// CHECK: return %[[NE]]
}
// CHECK-LABEL: testLogicalNotOfGreater

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s
// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s -dump-input-on-failure
// Simple case, single device cluster.
@ -72,11 +72,8 @@ module {
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
func @argliveinotherislands(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = tf_executor.graph {
// CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island {
%1:2 = tf_executor.island {
%3 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
tf_executor.yield %3 : tensor<?xi32>
}
// CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island wraps "tf.D"
%1:2 = tf_executor.island wraps "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%2:2 = tf_executor.island {
// CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"

View File

@ -90,16 +90,13 @@ module {
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
func @multiplelaunches(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = tf_executor.graph {
%1:2 = tf_executor.island {
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func}
%2 = "tf_device.launch"() ( {
%1:2 = tf_executor.island wraps
// CHECK: %[[A_OUTPUT:[0-9]*]]:2 = {{.*}} "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func}
"tf_device.launch"() ( {
%3 = "tf.A"() : () -> tensor<?xi32>
"tf_device.return"(%3) : (tensor<?xi32>) -> ()
}) {device = "tpu0"} : () -> tensor<?xi32>
// CHECK: tf_executor.yield %[[A_OUTPUT]]
tf_executor.yield %2 : tensor<?xi32>
}
// CHECK: tf_executor.fetch %[[A_OUTPUT]]#0
tf_executor.fetch %1#0 : tensor<?xi32>
}
return %0 : tensor<?xi32>

View File

@ -11,14 +11,8 @@ func @islands_with_control(tensor<*xf32>) -> tensor<*xf32> {
}
// CHECK-NEXT: %[[GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) wraps "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.fetch %[[ADD]]#0 : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[GRAPH]] : tensor<*xf32>
@ -45,40 +39,19 @@ func @LoopTest() {
}
// CHECK-NEXT: tf_executor.graph {
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[CONST]]#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island {
// CHECK-NEXT: "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
// CHECK-NEXT: tf_executor.yield
// CHECK-NEXT: }
// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island wraps "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
// CHECK-NEXT: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi1>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island wraps "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[LESS:[0-9]*]]#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xi32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island wraps "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
// CHECK-NEXT: tf_executor.fetch

View File

@ -285,9 +285,9 @@ func @empty_island_no_operand_no_data_result() {
return
}
// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island {
// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island
// CHECK-NEXT: "tf.opA"
// CHECK: tf_executor.island(%[[ISLAND_0]]) {
// CHECK: tf_executor.island(%[[ISLAND_0]])
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: tf_executor.island
@ -313,9 +313,9 @@ func @empty_island_one_operand_no_data_result() {
return
}
// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island {
// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island
// CHECK-NEXT: "tf.opA"
// CHECK: tf_executor.island(%[[ISLAND_1]]) {
// CHECK: tf_executor.island(%[[ISLAND_1]])
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: tf_executor.island
@ -342,8 +342,34 @@ func @empty_island_no_operand_one_data_no_control_result(%arg0 : tensor<i1>) {
return
}
// CHECK: tf_executor.island {
// CHECK: tf_executor.island
// CHECK-NEXT: "tf.opA"(%[[ARG_0]])
// CHECK: tf_executor.island {
// CHECK-NEXT: "tf.opB"(%[[ARG_0]])
// CHECK-NOT: tf_executor.island
// Test empty control trigger with no operands is removed.
// Control result users should also have their respective operands removed.
// CHECK-LABEL: func @empty_control_trigger
func @empty_control_trigger() {
tf_executor.graph {
%0 = tf_executor.ControlTrigger {}
%1 = tf_executor.island(%0) {
%3 = "tf.opA"() : () -> tensor<i1>
tf_executor.yield
}
%2 = tf_executor.island(%0, %1) {
%4 = "tf.opB"() : () -> tensor<i1>
tf_executor.yield
}
tf_executor.fetch
}
return
}
// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island
// CHECK-NEXT: "tf.opA"
// CHECK: tf_executor.island(%[[ISLAND_0]])
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: tf_executor.island

View File

@ -89,9 +89,7 @@ func @empty_islands(%arg0 : tensor<i1>, %arg1 : tensor<i1>) -> (tensor<i1>, tens
return %0#0, %0#1 : tensor<i1>, tensor<i1>
}
// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island {
// CHECK-NEXT: %[[OP_A:[0-9]*]]:2 = "tf.opA"(%[[ARG_1]], %[[ARG_0]])
// CHECK-NEXT: tf_executor.yield %[[OP_A]]#0, %[[OP_A]]#1 : tensor<i1>, tensor<i1>
// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island wraps "tf.opA"(%[[ARG_1]], %[[ARG_0]])
// CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<i1>, tensor<i1>
@ -228,9 +226,7 @@ func @islands_interleaved(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> (tensor<i
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: %{{[0-9]*}} = "tf.opE"(%[[ARG_0]])
// CHECK-NEXT: tf_executor.yield %[[OP_C]] : tensor<i32>
// CHECK: tf_executor.island {
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]])
// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor<i32>
// CHECK: tf_executor.island wraps "tf.opF"(%[[ARG_1]])
// CHECK: tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor<i32>, tensor<i32>
@ -279,13 +275,9 @@ func @merge_islands_only() {
return
}
// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %[[OP_A:.*]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<i32>
// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island wraps "tf.opA"
// CHECK: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[ISLAND_0]]#0
// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island {
// CHECK-NEXT: "tf.opB"()
// CHECK-NEXT: tf_executor.yield
// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island wraps "tf.opB"()
// CHECK: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0
// CHECK-NEXT: %[[ISLAND_2:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
@ -322,9 +314,7 @@ func @simple_potential_cycle() {
return
}
// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island wraps "tf.opA"
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1
// CHECK-NEXT: tf_executor.island(%[[CT]]) {
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"
@ -384,9 +374,7 @@ func @merge_into_nested_data_result() {
// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA"
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: [[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) {
// CHECK-NEXT: [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor<1xf32>
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) wraps "tf.opB"(%[[OP_A]])
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
// CHECK: tf_executor.yield
@ -422,18 +410,14 @@ func @merge_islands_inner_graph() {
return
}
// CHECK: tf_executor.island {
// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
// CHECK: tf_executor.island {
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK: tf_executor.island wraps "tf.opA"
// CHECK: tf_executor.island wraps "tf_executor.graph"() ( {
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: "tf.opB"
// CHECK-NEXT: [[OP_C:[0-9]*]] = "tf.opC"
// CHECK-NEXT: [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<1xf32>
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
// CHECK: tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32>
// Test merging islands with control island operands and island results only if
@ -454,7 +438,7 @@ func @merge_islands_closest_control() {
return
}
// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island {
// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island
// CHECK: tf_executor.ControlTrigger %[[ISLAND]]
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]]) {
// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]])

View File

@ -0,0 +1,23 @@
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input=fail
// Test that a constant stays inside an island after canonicalization
// CHECK-LABEL: func @constant_in_island
func @constant_in_island(%arg0 : tensor<i1>) -> tensor<f32> {
%0 = tf_executor.graph {
// CHECK: tf_executor.island
// CHECK: tf.Const{{.*}}2.0
%1:2 = tf_executor.island {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
tf_executor.yield %0 : tensor<f32>
}
// Uses two islands for no other reason than preventing canonicalization from
// eliminating the graph entirely.
%2:2 = tf_executor.island(%1#1) {
%4 = "tf.opB"(%1#0) : (tensor<f32>) -> tensor<f32>
tf_executor.yield %4 : tensor<f32>
}
tf_executor.fetch %2#0 : tensor<f32>
}
return %0 : tensor<f32>
}

View File

@ -97,3 +97,24 @@ func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
}
return %fetches : tensor<*xf32>
}
// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect.
// CHECK-LABEL: func @ref_tf_executor_ops
func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
// CHECK: _tf.Enter
%0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control)
// CHECK: _tf.Exit
%1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref>
// CHECK: _tf.Switch
%2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control)
// CHECK: _tf.Merge
%3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
// CHECK: _tf.NextIteration.source
%4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref>
// CHECK: _tf.NextIteration.sink
tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref>
tf_executor.fetch %0#0 : tensor<4x!tf.f32ref>
}
return %result : tensor<4x!tf.f32ref>
}

View File

@ -39,13 +39,10 @@ versions {
# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "Add"}} {
# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island
# CHECK-NEXT: "tf.Placeholder.input"(%arg0)
# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0)
# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island
# CHECK-NEXT: "tf.Placeholder.input"(%arg1)
# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg1)
# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island
# CHECK-NEXT: "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0)
# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island wraps "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0)
# CHECK: fetch %[[add]]#0

View File

@ -41,8 +41,7 @@ library {
}
# Drop the control dependency on arg for the node "test"
# CHECK-LABEL: func @foo
# CHECK: tf_executor.island {
# CHECK-NEXT: "tf.Const"()
# CHECK: tf_executor.island wraps "tf.Const"()
node_def {
name: "test"
op: "Const"

View File

@ -6,12 +6,9 @@
# CHECK: func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> (tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "args_0, args_1", outputs = "rets_0_RetVal, rets_1_RetVal"}} {
# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island {
# CHECK: "tf.Const"
# CHECK: %[[ISLAND_1:[0-9]]]:2 = tf_executor.island {
# CHECK: "tf.Identity"(%[[ISLAND_0]]#0)
# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island {
# CHECK: "tf.StatefulPartitionedCall"
# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island wraps "tf.Const"
# CHECK: %[[ISLAND_1:[0-9]]]:2 = tf_executor.island wraps "tf.Identity"(%[[ISLAND_0]]#0)
# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island wraps "tf.StatefulPartitionedCall"
# CHECK-SAME: f = @[[FUNC:[a-z0-9]*]]
# CHECK: tf_executor.fetch %[[ISLAND_1]]#0, %[[ISLAND_2]]#0 : tensor<f32>, tensor<f32>
# CHECK: func @[[FUNC]](%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>

View File

@ -5,7 +5,7 @@
# FetchOp.
# Match the island containing the "tf.Neg", capture the output
# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg"
# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island wraps "tf.Neg"
# Check that the tf.Neg control is passed to the fetch
# CHECK: tf_executor.fetch {{.*}} %[[ISLAND_0]]#1 : tensor<*xf32>, !tf_executor.control

View File

@ -5,7 +5,7 @@
# FetchOp.
# Match the island containing the "tf.Neg", capture the output
# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg"
# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island wraps "tf.Neg"
# Check that the tf.Neg data output and control are passed to the fetch
# CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<*xf32>, !tf_executor.control

View File

@ -0,0 +1,110 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
# Verify that the _input_shapes attribute of the FunctionDef is respected.
# This also checks that the output type is correctly inferred from those
# shapes.
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
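# For reference: the single empty `shape {}` entry in the _input_shapes
# attribute below denotes a scalar, which is why identity_function's DT_INT32
# argument and result are imported as tensor<i32>.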
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
experimental_debug_info {
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
experimental_debug_info {
}
}
node {
name: "If"
op: "If"
input: "Placeholder"
input: "Placeholder_1"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "else_branch"
value {
func {
name: "identity_function"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "identity_function"
}
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "identity_function"
input_arg {
name: "identity_input"
type: DT_INT32
}
output_arg {
name: "identity_output"
type: DT_INT32
}
}
ret {
key: "identity_output"
value: "identity_input"
}
attr {
key: "_input_shapes"
value {
list {
shape {
}
}
}
}
}
}
versions {
producer: 29
min_consumer: 12
}

View File

@ -0,0 +1,177 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
# Verify that the _output_shapes attribute of ReadVariableOp is used to get
# variable types.
# This also checks that the output type is correctly inferred from those
# shapes.
# CHECK: func @__inference_some_function_130(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
# CHECK: "tf.ReadVariableOp"(%arg0) {{.*}} : (tensor<*x!tf.resource>) -> tensor<f32>
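# For reference: the empty `shape {}` entries under _output_shapes denote
# scalars, so the ReadVariableOp result is imported as tensor<f32> even though
# the resource argument itself is unranked.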
node {
name : "Variable"
op : "VarHandleOp"
attr {
key : "shape"
value {
shape {
}
}
}
attr {
key : "dtype"
value {
type : DT_FLOAT
}
}
attr {
key : "shared_name"
value {
s: "Variable"
}
}
attr {
key : "_output_shapes"
value {
list {
shape {
}
}
}
}
}
node {
name : "StatefulPartitionedCall"
op : "StatefulPartitionedCall"
input : [ "Variable" ]
attr {
key : "f"
value {
func {
name: "__inference_some_function_13"
}
}
}
attr {
key : "config_proto"
value {
s: "\n\x07\n\x03GPU\x10\x00\n\x07\n\x03\x43PU\x10\x01\x32\x02J\x00\x38\x01"
}
}
attr {
key : "Tout"
value {
list {
type : [ DT_FLOAT ]
}
}
}
attr {
key : "_gradient_op_type"
value {
s: "PartitionedCall-29"
}
}
attr {
key : "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key : "Tin"
value {
list {
type : [ DT_RESOURCE ]
}
}
}
}
library {
function {
signature {
name: "__inference_some_function_13"
input_arg {
name : "readvariableop_resource"
type : DT_RESOURCE
}
output_arg {
name : "identity"
type : DT_FLOAT
}
is_stateful : true
control_output: [ "ReadVariableOp" ]
}
node_def {
name : "ReadVariableOp"
op : "ReadVariableOp"
input : [ "readvariableop_resource" ]
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key : "dtype"
value {
type : DT_FLOAT
}
}
attr {
key : "_output_shapes"
value {
list {
shape {
}
}
}
}
}
node_def {
name : "Identity"
op : "Identity"
input : [ "ReadVariableOp:value:0", "^ReadVariableOp" ]
attr {
key : "T"
value {
type : DT_FLOAT
}
}
attr {
key : "_output_shapes"
value {
list {
shape {
}
}
}
}
}
ret {
key : "identity"
value: "Identity:output:0"
}
attr {
key : "_input_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
control_ret {
key : "ReadVariableOp"
value: "ReadVariableOp"
}
arg_attr {
key : 0x00000000
value {
}
}
}
}
versions {
producer : 148
min_consumer : 12
}

View File

@ -7,8 +7,7 @@
# CHECK: "tf.Placeholder.input"(%arg0)
# CHECK: tf.Relu
# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island
# CHECK-NEXT: tf.Identity
# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island wraps "tf.IdentityN"
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>
node {

View File

@ -0,0 +1,101 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s
# CHECK:"tf.MlirPassthroughOp"
# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
node {
name: "x"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 10
}
}
}
}
}
node {
name: "y"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "y"
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 10
}
}
}
}
}
node {
name: "MlirPassthroughOp"
op: "MlirPassthroughOp"
input: "x"
input: "y"
attr {
key: "Tinputs"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Toutputs"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "mlir_module"
value {
s: "\nfunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\n %add = \"tf.Add\"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\n %ret = \"magic.op\"(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\n return %ret : tensor<10x10xf32>\n}\n"
}
}
}
node {
name: "Identity"
op: "Identity"
input: "MlirPassthroughOp"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 148
}

View File

@ -13,15 +13,12 @@ func @foo(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// The IsolatePlacerInspectionRequiredOpsPass adds Identities for each input/output of function-calling ops.
// Capture the result of input to function call.
// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island
// CHECK-NEXT: "tf.VarHandleOp"()
// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.VarHandleOp"()
// Test for the presence of Identity op between input and function call.
// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island
// CHECK-NEXT: "tf.Identity"([[VARIABLE_REG]]#0)
// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"([[VARIABLE_REG]]#0)
// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island
// CHECK-NEXT: "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0)
// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0)
// CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]]
// Match the inserted Identity op for call output.

View File

@ -0,0 +1,25 @@
// RUN: tf-opt %s -test-tf-lower-tf | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: simple_pack
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32>, %[[ARG1:.*]]: tensor<3x5xf32>
func @simple_pack(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<2x3x5xf32> {
// CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<3x5xf32>, tensor<i64>) -> tensor<1x3x5xf32>
// CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<3x5xf32>, tensor<i64>) -> tensor<1x3x5xf32>
// CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) {N = 2 : i64} : (tensor<1x3x5xf32>, tensor<1x3x5xf32>, tensor<i64>) -> tensor<2x3x5xf32>
%0 = "tf.Pack"(%arg0, %arg1) {N = 2 : i64} : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32>
}
// CHECK-LABEL: pack_with_unranked
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>, %[[ARG1:.*]]: tensor<*xf32>
func @pack_with_unranked(%arg0: tensor<?x5xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<-2> : tensor<i64>}
// CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<?x5xf32>, tensor<i64>) -> tensor<?x1x5xf32>
// CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
// CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) {N = 2 : i64} : (tensor<?x1x5xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
%0 = "tf.Pack"(%arg0, %arg1) {axis = -2 : i64, N = 2 : i64} : (tensor<?x5xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -0,0 +1,15 @@
// RUN: tf-opt -tf-materialize-passthrough-op %s | FileCheck %s --dump-input=fail
// Check that the MlirPassthroughOp is eliminated and replaced by its attached
// MLIR module.
// CHECK-LABEL: func @main
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<10xf32>)
// CHECK-NEXT: %[[ADD:.*]] = "tf.Add"(%[[ARG0]], %[[ARG1]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK-NEXT: %[[MAGIC:.*]] = "magic.op"(%[[ADD]], %[[ADD]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
// CHECK-NEXT: return %[[MAGIC]]
%0 = "tf.MlirPassthroughOp"(%arg0, %arg1) {Tinputs = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Toutputs = ["tfdtype$DT_FLOAT"], device = "", mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32>
}

View File

@ -0,0 +1,23 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
// Verify the ops generated when Ref type is used in a while loop.
func @main() {
// CHECK: op: "RefEnter"
// CHECK: op: "RefMerge"
// CHECK: op: "RefSwitch"
// CHECK: op: "RefExit"
// CHECK: op: "RefNextIteration"
%0:2 = "_tf.NextIteration.source"() {device = "", T = "tfdtype$DT_INT32"} : () -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/NextIteration")
%1:2 = "_tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<i32>} : () -> (tensor<!tf.int32ref>, !_tf.control) loc("Ref_Variable")
%2:2 = "_tf.Enter"(%1#0) {device = "", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<!tf.int32ref>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Enter")
%3:3 = "_tf.Merge"(%2#0, %0#0) {device = "", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<*x!tf.int32ref>) -> (tensor<*x!tf.int32ref>, tensor<i32>, !_tf.control) loc("while/Merge")
%4:2 = "_tf.Const"(%3#2) {device = "", dtype = "tfdtype$DT_INT32", value = dense<10> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) loc("while/Less/y")
%5:2 = "_tf.Less"(%3#0, %4#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<i32>) -> (tensor<*xi1>, !_tf.control) loc("while/Less")
%6:2 = "_tf.LoopCond"(%5#0) {device = ""} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control) loc("while/LoopCond")
%7:3 = "_tf.Switch"(%3#0, %6#0) {device = "", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*x!tf.int32ref>, tensor<i1>) -> (tensor<*x!tf.int32ref>, tensor<*x!tf.int32ref>, !_tf.control) loc("while/Switch")
%8:2 = "_tf.Exit"(%7#1) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Exit")
%10:2 = "_tf.Const"(%7#2) {device = "", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) loc("while/Add/y")
%11:2 = "_tf.AssignAdd"(%7#0, %10#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<i32>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Add")
%12 = "_tf.NextIteration.sink"(%11#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>) -> !_tf.control loc("while/NextIteration")
return
}

View File

@ -0,0 +1,41 @@
// RUN: tf-opt %s -tf-device-constant-sinking | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @sink_const
func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor<f32>) {
// Verify that the constants are sunk into the tf_device.launch region using
// them and removed if no other use is left.
// Only the 2.0 and 3.0 constants are removed; the 4.0 constant has a use in
// the return.
// CHECK-NOT:"tf.Const"2.0
// CHECK-NOT:"tf.Const"3.0
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
%3 = tf_executor.graph {
%res, %ctl = tf_executor.island {
%3 = "tf_device.launch"() ({
// In the device region, check that the 3 constants are materialized and
// remapped to the uses.
// CHECK: tf_device.launch
// CHECK-DAG: %[[CST2:.*]] = "tf.Const"{{.*}}2.0
// CHECK-DAG: %[[CST3:.*]] = "tf.Const"{{.*}}3.0
// CHECK-DAG: %[[CST4:.*]] = "tf.Const"{{.*}}4.0
// CHECK-NOT:"tf.Const"
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg0, %[[CST2]])
// CHECK-NEXT: %[[MUL2:.*]] = "tf.Mul"(%[[MUL1]], %[[CST2]])
// CHECK-NEXT: %[[MUL3:.*]] = "tf.Mul"(%[[MUL2]], %[[CST3]])
// CHECK-NEXT: = "tf.Mul"(%[[MUL3]], %[[CST4]])
%3 = "tf.Mul"(%arg0, %0) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
%4 = "tf.Mul"(%3, %0) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
%5 = "tf.Mul"(%4, %1) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
%6 = "tf.Mul"(%5, %2) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
"tf_device.return"(%6) : (tensor<16xf32>) -> ()
}) {device = "tpu0"} : () -> tensor<16xf32>
tf_executor.yield %3 : tensor<16xf32>
}
tf_executor.fetch %res : tensor<16xf32>
}
return %3, %2 : tensor<16xf32>, tensor<f32>
}

View File

@ -83,6 +83,15 @@ func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> {
// -----
// CHECK-LABEL: func @testReverseV2
func @testReverseV2(%arg0: tensor<2x4x3x!tf.uint8>, %arg1: tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> {
// CHECK: tf.ReverseV2
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3x!tf.uint8>, tensor<1xi32>) -> tensor<2x4x3x!tf.uint8>
return %0 : tensor<2x4x3x!tf.uint8>
}
// -----
func @testIdentityWrongType(%arg0: tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> {
// expected-error @+1 {{requires all operands to be either same as or ref type of results}}
%0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref>
@ -459,6 +468,37 @@ func @testInvalidFakeQuantWithMinMaxVarsWrongMaxType(tensor<8x8x8x8xf32>, tensor
// -----
// Test valid tf.FakeQuantWithMinMaxVarsPerChannel
// CHECK-LABEL: func @FakeQuantWithMinMaxVarsPerChannel
func @FakeQuantWithMinMaxVarsPerChannel(tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> {
^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>):
// CHECK: "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
return %0 : tensor<1x2x3x8xf32>
}
// -----
// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel
func @FakeQuantWithMinMaxVarsPerChannel_ranked_inputs(tensor<f32>, tensor<8xf32>, tensor<8xf32>) -> tensor<f32> {
^bb0(%arg0: tensor<f32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>):
// expected-error @+1 {{requires inputs to be at least 1d float tensor}}
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<8xf32>, tensor<8xf32>) -> tensor<f32>
return %0 : tensor<f32>
}
// -----
// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel
func @FakeQuantWithMinMaxVarsPerChannel_mismatch_min_max(tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> {
^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<8xf32>):
// expected-error @+1 {{requires min and max to have same size as last dimension of inputs}}
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
return %0 : tensor<1x2x3x8xf32>
}
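// Illustrative sketch (editor addition, not part of this commit): the verifier
// messages above only constrain the rank (at least 1) and the last dimension,
// so a rank-2 input with matching 4-element min/max should also pass. The
// function name is hypothetical.
func @FakeQuantWithMinMaxVarsPerChannel_rank2_sketch(%arg0: tensor<2x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<2x4xf32> {
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<2x4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}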
// -----
// Test valid tf.FusedBatchNorm
// CHECK-LABEL: func @testFusedBatchNorm
func @testFusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> {
@ -944,25 +984,25 @@ func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> {
// -----
// Test valid tf.ConcatV2
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32> {
// CHECK: %0 = "tf.ConcatV2"(%arg0, %arg0, %arg1) {N = 2 : i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
// CHECK: %0 = "tf.ConcatV2"(%arg0, %arg0, %arg1) {N = 2 : i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// tf.ConcatV2 with wrong 'axis' element type
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xf32>) -> tensor<?xf32> {
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<f32>) -> tensor<?xf32> {
// expected-error @+1 {{operand #2 must be tensor of 32/64-bit integer values}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xf32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// tf.ConcatV2 missing required 'axis' operand
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32> {
func @testConcatV2() -> tensor<?xf32> {
// expected-error @+1 {{expected 1 or more operands}}
%0 = "tf.ConcatV2"() {N = 0: i64} : () -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -971,9 +1011,165 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32
// -----
// tf.ConcatV2 with less than required number of values for the variadic operand
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32> {
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
// expected-error @+1 {{attribute 'N' failed to satisfy constraint: 64-bit integer attribute whose minimal value is 2}}
%0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
// expected-error @+1 {{requires attribute 'N' to match the number of inputs; expected: 2 Found: 3}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 3: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
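// Illustrative sketch (editor addition, not part of this commit): 'N' counts
// only the values being concatenated, not the trailing axis operand, so three
// inputs plus an axis take N = 3. The function name is hypothetical.
func @testConcatV2_three_inputs_sketch(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
%0 = "tf.ConcatV2"(%arg, %arg, %arg, %axis) {N = 3: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}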
// -----
func @testAll(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
// CHECK-LABEL: testAll
// CHECK: %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
}
// -----
func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor<i64>) -> tensor<i1> {
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i64>) -> tensor<i1>
return %0 : tensor<i1>
// CHECK-LABEL: testAll64
// CHECK: %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i64>) -> tensor<i1>
}
// -----
func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor<f32>) -> tensor<i1> {
// expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit integer values}}
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<f32>) -> tensor<i1>
return %0 : tensor<i1>
}
// -----
func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor<f32>) -> tensor<i32> {
// expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit integer values}}
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor<f32>) -> tensor<i32>
return %0 : tensor<i32>
}
// -----
func @testEqualOpIncompatibleShapeTrue(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<5xi1> {
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tf.Equal"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<4xf32>) -> tensor<5xi1>
return %0 : tensor<5xi1>
}
// -----
// CHECK-LABEL: testEqualOpIncompatibleShapeFalse
func @testEqualOpIncompatibleShapeFalse(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<*xi1> {
// CHECK: tf.Equal
%0 = "tf.Equal"(%x, %y) {incompatible_shape_error = false} : (tensor<5xf32>, tensor<4xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}
// -----
func @testNotEqualOpIncompatibleShapeTrue(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<5xi1> {
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tf.NotEqual"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<4xf32>) -> tensor<5xi1>
return %0 : tensor<5xi1>
}
// -----
// CHECK-LABEL: testNotEqualOpIncompatibleShapeFalse
func @testNotEqualOpIncompatibleShapeFalse(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<*xi1> {
// CHECK: tf.NotEqual
%0 = "tf.NotEqual"(%x, %y) {incompatible_shape_error = false} : (tensor<5xf32>, tensor<4xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}
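// Illustrative sketch (editor addition, not part of this commit): with
// broadcast-compatible shapes, such as a vector against a scalar, tf.Equal
// should verify even when incompatible_shape_error is true. The function name
// is hypothetical.
func @testEqualOpBroadcastCompatible_sketch(%x: tensor<5xf32>, %y: tensor<f32>) -> tensor<5xi1> {
%0 = "tf.Equal"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<f32>) -> tensor<5xi1>
return %0 : tensor<5xi1>
}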
// -----
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1x1xi32>) -> tensor<*xf32> {
// expected-error @+1 {{requires axis to be of scalar type (or vector type for older versions)}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1x1xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @testConcat(%arg: tensor<8x16xf32>, %axis: tensor<1x1xi32>) -> tensor<*xf32> {
// expected-error @+1 {{requires axis to be of scalar type (or vector type for older versions)}}
%0 = "tf.Concat"(%axis, %arg, %arg) {N = 2: i64} : (tensor<1x1xi32>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
// expected-error @+1 {{operand type 'tensor<8xf32>' is not compatible with preceding operands; expected rank: 2}}
%0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8xf32>, tensor<i32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Valid Concat operation with concat axis 1 or -1.
func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8x8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
%0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x8xf32>, tensor<i32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
// expected-error @+1 {{operand type 'tensor<16x8xf32>' is not compatible with preceding operands; expected dimension at index 1: 16}}
%0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<16x8xf32>, tensor<i32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Valid Concat operation with concat axis 1 or -1.
func @testConcatV2(%arg0: tensor<8x8xf32>, %arg1: tensor<?x4xf32>, %arg2: tensor<*xf32>, %arg3: tensor<8x?xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
%0 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %arg3, %axis) {N = 4: i64} : (tensor<8x8xf32>, tensor<?x4xf32>, tensor<*xf32>, tensor<8x?xf32>, tensor<i32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Valid Pack operation.
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> {
%0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> {
// expected-error @+1 {{requires attribute 'N' to match the number of inputs; expected: 2 Found: 1}}
%0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 1: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x2xf32>) -> tensor<*xf32> {
// expected-error @+1 {{operand type 'tensor<4x2xf32>' is not compatible with preceding operands; expected dimension at index 1: 8}}
%0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x2xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
// expected-error @+1 {{attribute 'axis' should be within range [-3, 3); actual value: 3}}
%0 = "tf.Pack"(%arg0, %arg1) {axis = 3 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
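// Illustrative sketch (editor addition, not part of this commit): packing two
// rank-2 tensors yields a rank-3 result, so 'axis' must lie in [-3, 3);
// axis = -1, addressing the new trailing dimension, should therefore verify.
// The function name is hypothetical.
func @testPack_negative_axis_sketch(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> {
%0 = "tf.Pack"(%arg0, %arg1) {axis = -1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}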

View File

@ -271,6 +271,39 @@ func @merge_with_variant_type(%arg0: tensor<!tf.variant>, %arg1: tensor<!tf.vari
return %result : tensor<!tf.variant<tensor<8xf32>>>
}
// CHECK-LABEL: func @merge_with_ref_type
func @merge_with_ref_type(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%result = tf_executor.graph {
// CHECK: tf_executor.Merge{{.*}}(tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4xf32>, tensor<i32>, !tf_executor.control)
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4xf32>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<4xf32>
}
return %result : tensor<4xf32>
}
// CHECK-LABEL: func @merge_with_dynamic_shape
func @merge_with_dynamic_shape(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> tensor<?xf32> {
%result = tf_executor.graph {
// CHECK: tf_executor.Merge{{.*}}(tensor<2xf32>, tensor<3xf32>) -> (tensor<?xf32>, tensor<i32>, !tf_executor.control)
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xf32>) -> (tensor<?xf32>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<?xf32>
}
return %result : tensor<?xf32>
}
// CHECK-LABEL: func @merge_with_unranked_shape
func @merge_with_unranked_shape(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> tensor<*xf32> {
%result = tf_executor.graph {
// CHECK: tf_executor.Merge{{.*}}(tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<*xf32>
}
return %result : tensor<*xf32>
}
// CHECK-LABEL: func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%result = tf_executor.graph {

View File

@ -490,7 +490,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
%value, %idx, %ctlMerge = "tf_executor.Merge"(%true, %false, %arg1) : (tensor<*xf32>, tensor<*xf32>, tensor<i1>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<*xf32>' vs 'tensor<i1>'}}
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<i1>' vs 'tensor<*xf32>'}}
tf_executor.fetch %value : tensor<*xf32>
}
return %result : tensor<*xf32>
@ -502,7 +502,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xf32>, tensor<4xf32>) -> (tensor<8xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8xf32>' vs 'tensor<4xf32>'}}
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<4xf32>' vs 'tensor<8xf32>'}}
tf_executor.fetch %value : tensor<8xf32>
}
return %result : tensor<8xf32>
@ -514,7 +514,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32>
func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>) -> tensor<8x!tf.variant> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*x!tf.variant>, tensor<4x!tf.variant>) -> (tensor<8x!tf.variant>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8x!tf.variant>' vs 'tensor<4x!tf.variant>'}}
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<4x!tf.variant>' vs 'tensor<8x!tf.variant>'}}
tf_executor.fetch %value : tensor<8x!tf.variant>
}
return %result : tensor<8x!tf.variant>
@ -522,6 +522,18 @@ func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>)
// -----
// Check that if the result is a ref type, all operands need to be ref types too.
func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
tf_executor.fetch %value : tensor<4x!tf.f32ref>
}
return %result : tensor<4x!tf.f32ref>
}
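// Illustrative sketch (editor addition, not part of this commit): the
// ref-consistent counterpart of the invalid case above; when every operand is
// a ref type matching the ref result, the Merge should verify. The function
// name is hypothetical.
func @valid_ref_merge_sketch(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<4x!tf.f32ref>
}
return %result : tensor<4x!tf.f32ref>
}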
// -----
// Check that merge data inputs can't appear after a control input.
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%result = tf_executor.graph {

View File

@ -11,9 +11,9 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
@ -68,9 +68,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @nested_func
@ -112,9 +111,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @referenced_func
@ -155,9 +153,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: @referenced_func1
@ -206,9 +203,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-COUNT-2: call @referenced_func
@ -251,9 +247,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func0} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func0
@ -263,9 +258,8 @@ module {
%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func1} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster1"
// CHECK-SAME: module
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func1
@ -303,9 +297,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
@ -315,9 +308,8 @@ module {
%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster1"
// CHECK-SAME: module
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
@ -351,9 +343,8 @@ module {
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @referenced_func2
@ -404,3 +395,44 @@ module {
}
// -----
// Tests that TPUCompilationResult operations are properly rewritten
// CHECK-LABEL: func @tpu_compilation_result
func @tpu_compilation_result(%arg0: tensor<?xi32>) -> (tensor<?xi32>, tensor<!tf.string>, tensor<!tf.string>) {
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"
%1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
%compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
%compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
// CHECK: return %[[EXECUTE_OUTPUT]], %[[COMPILE_OUTPUT]]#0, %[[COMPILE_OUTPUT]]#0
return %1, %compile_result, %compile_result2 : tensor<?xi32>, tensor<!tf.string>, tensor<!tf.string>
}
func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Tests that TPUReplicatedInput and TPUReplicatedOutput operations are properly rewritten
func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> {
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%arg0, %arg1
%0 = "tf.TPUReplicatedInput"(%arg0) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
%1 = "tf.TPUReplicatedInput"(%arg1) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
%2 = "tf_device.launch_func"(%0, %1) {device = "", _tpu_replicate = "cluster", func = @_func} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
%3 = "tf.TPUReplicatedOutput"(%2) {num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
return %3 : tensor<0xf32>
}
func @_func(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
%0 = "tf.Const"() {value = dense<3.000000e+00> : tensor<0xf32>} : () -> tensor<0xf32>
return %0 : tensor<0xf32>
}

Some files were not shown because too many files have changed in this diff.