Merge branch 'official_master' into no_mem_opt_if_jit_on
commit 67c50b0914
WORKSPACE (22 changed lines)
@@ -49,34 +49,34 @@ remote_config_workspace()
 # Apple and Swift rules.
 http_archive(
     name = "build_bazel_rules_apple",
-    sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1",
-    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"],
+    sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
+    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
 )  # https://github.com/bazelbuild/rules_apple/releases
 http_archive(
     name = "build_bazel_rules_swift",
-    sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
-    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"],
+    sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
+    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
 )  # https://github.com/bazelbuild/rules_swift/releases
 http_archive(
     name = "build_bazel_apple_support",
-    sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
-    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
+    sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
+    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
 )  # https://github.com/bazelbuild/apple_support/releases
 http_archive(
     name = "bazel_skylib",
-    sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
-    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
+    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
+    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
 )  # https://github.com/bazelbuild/bazel-skylib/releases
 http_archive(
     name = "com_github_apple_swift_swift_protobuf",
     type = "zip",
-    strip_prefix = "swift-protobuf-1.5.0/",
-    urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
+    strip_prefix = "swift-protobuf-1.6.0/",
+    urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
 )  # https://github.com/apple/swift-protobuf/releases
 http_file(
     name = "xctestrunner",
     executable = 1,
-    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
+    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
 )  # https://github.com/google/xctestrunner/releases
 # Use `swift_rules_dependencies` to fetch the toolchains. With the
 # `git_repository` rules above, the following call will skip redefining them.
@@ -3,56 +3,56 @@ package(default_visibility = ["//visibility:public"])
 filegroup(
     name = "gcc",
     srcs = [
-        "bin/arm-linux-gnueabihf-gcc",
+        "bin/arm-rpi-linux-gnueabihf-gcc",
     ],
 )

 filegroup(
     name = "ar",
     srcs = [
-        "bin/arm-linux-gnueabihf-ar",
+        "bin/arm-rpi-linux-gnueabihf-ar",
     ],
 )

 filegroup(
     name = "ld",
     srcs = [
-        "bin/arm-linux-gnueabihf-ld",
+        "bin/arm-rpi-linux-gnueabihf-ld",
     ],
 )

 filegroup(
     name = "nm",
     srcs = [
-        "bin/arm-linux-gnueabihf-nm",
+        "bin/arm-rpi-linux-gnueabihf-nm",
     ],
 )

 filegroup(
     name = "objcopy",
     srcs = [
-        "bin/arm-linux-gnueabihf-objcopy",
+        "bin/arm-rpi-linux-gnueabihf-objcopy",
     ],
 )

 filegroup(
     name = "objdump",
     srcs = [
-        "bin/arm-linux-gnueabihf-objdump",
+        "bin/arm-rpi-linux-gnueabihf-objdump",
     ],
 )

 filegroup(
     name = "strip",
     srcs = [
-        "bin/arm-linux-gnueabihf-strip",
+        "bin/arm-rpi-linux-gnueabihf-strip",
     ],
 )

 filegroup(
     name = "as",
     srcs = [
-        "bin/arm-linux-gnueabihf-as",
+        "bin/arm-rpi-linux-gnueabihf-as",
     ],
 )
@@ -1388,11 +1388,18 @@ def main():
   if (environ_cp.get('TF_NEED_CUDA') == '1' and
       'TF_CUDA_CONFIG_REPO' not in environ_cp):

+    tensor_rt_question = (
+        'Do you wish to build TensorFlow with TensorRT support? NB! There ' +
+        'are known ODR violations between TensorRT and cuDNN that may result ' +
+        'in application crashes and/or data corruption. Please see ' +
+        'https://github.com/tensorflow/tensorflow/issues/32480 for details.')
+
     set_action_env_var(
         environ_cp,
         'TF_NEED_TENSORRT',
         'TensorRT',
         False,
+        question=tensor_rt_question,
         bazel_config_name='tensorrt')

   environ_save = dict(environ_cp)
@@ -159,7 +159,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
                                               TF_Status* status) {
   Session* session;
   status->status = NewSession(opt->options, &session);
-  if (TF_GetCode(status) == TF_OK) {
+  if (status->status.ok()) {
     return new TF_DeprecatedSession({session});
   } else {
     DCHECK_EQ(nullptr, session);
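[Editor's note] The recurring change in this file swaps `TF_GetCode(status) == TF_OK` for `status->status.ok()`: instead of round-tripping through the C API's code accessor, call sites now query the tensorflow::Status that TF_Status wraps directly. A minimal, self-contained analogue of the pattern (the types below are simplified stand-ins, not the real TensorFlow definitions):

// status_check_sketch.cc -- toy analogue of the TF_Status refactor.
#include <iostream>
#include <string>

enum TF_Code { TF_OK = 0, TF_INVALID_ARGUMENT = 3 };

struct Status {  // stand-in for tensorflow::Status
  TF_Code code = TF_OK;
  std::string message;
  bool ok() const { return code == TF_OK; }
};

struct TF_Status {  // the C API wrapper owns a Status member named `status`
  Status status;
};

TF_Code TF_GetCode(const TF_Status* s) { return s->status.code; }

int main() {
  TF_Status s;
  bool ok_old = (TF_GetCode(&s) == TF_OK);  // old style: via the C accessor
  bool ok_new = s.status.ok();              // new style used throughout this diff
  std::cout << ok_old << " " << ok_new << "\n";  // prints "1 1"
  return 0;
}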
@@ -332,7 +332,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
     // TODO(nolivia): check this on a subset of the graph instead of all of
     // it.
     status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
-    if (TF_GetCode(status) != TF_OK) {
+    if (!status->status.ok()) {
       session->graph->mu.unlock();
       return false;
     }
@@ -352,7 +352,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
     *graph_def.mutable_library() = graph.flib_def().ToProto();
     session->graph->mu.unlock();
     status->status = session->session->Extend(std::move(graph_def));
-    if (TF_GetCode(status) != TF_OK) {
+    if (!status->status.ok()) {
       // Contract is we always delete input_values[i].
       return false;
     }
@@ -382,7 +382,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
   const int ninputs = input_pairs->size();
   for (int i = 0; i < ninputs; ++i) {
     status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
-    if (TF_GetCode(status) != TF_OK) return false;
+    if (!status->status.ok()) return false;
   }
   return true;
 }
@@ -439,7 +439,7 @@ static void TF_Run_Helper(
     // Serialize back to upstream client, who now owns the new buffer
     if (run_metadata != nullptr) {
       status->status = MessageToBuffer(run_metadata_proto, run_metadata);
-      if (TF_GetCode(status) != TF_OK) return;
+      if (!status->status.ok()) return;
     }
   } else {
     // NOTE(zongheng): PRun does not support RunOptions yet.
@@ -459,7 +459,7 @@ static void TF_Run_Helper(
       continue;
     }
     c_outputs[i] = TF_TensorFromTensor(src, status);
-    if (TF_GetCode(status) != TF_OK) return;
+    if (!status->status.ok()) return;
   }
 }
@@ -516,7 +516,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
   string new_handle;
   status->status = s->session->PRunSetup(input_names, output_names,
                                          target_oper_names, &new_handle);
-  if (TF_GetCode(status) == TF_OK) {
+  if (status->status.ok()) {
     char* buf = new char[new_handle.size() + 1];
     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
     *handle = buf;
@@ -555,7 +555,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
   status->status = tensorflow::LoadLibrary(
       library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
       &lib_handle->op_list.length);
-  if (TF_GetCode(status) != TF_OK) {
+  if (!status->status.ok()) {
     delete lib_handle;
     return nullptr;
   }
@@ -983,7 +983,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
                       TF_Tensor* value, TF_Status* status) {
   Tensor t;
   status->status = TF_TensorToTensor(value, &t);
-  if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
+  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
 }

 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
@@ -993,13 +993,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
   std::vector<Tensor> t;
   t.reserve(num_values);

-  for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
+  for (int i = 0; i < num_values && status->status.ok(); ++i) {
     Tensor v;
     status->status = TF_TensorToTensor(values[i], &v);
     t.emplace_back(v);
   }

-  if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
+  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
 }

 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
@@ -1048,11 +1048,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
   status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
                                                /*consume=*/true);

-  if (TF_GetCode(status) == TF_OK) {
+  if (status->status.ok()) {
     // Run shape inference function for newly added node.
     status->status = desc->graph->refiner.AddNode(ret);
   }
-  if (TF_GetCode(status) == TF_OK) {
+  if (status->status.ok()) {
     // Add the node to the name-to-node mapping.
     desc->graph->name_map[ret->name()] = ret;
   } else if (ret != nullptr) {
@@ -1101,7 +1101,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
   NameRangeMap name_ranges;
   status->status =
       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
-  if (TF_GetCode(status) != TF_OK) return -1;
+  if (!status->status.ok()) return -1;
   auto iter = name_ranges.find(arg_name);
   if (iter == name_ranges.end()) {
     status->status = InvalidArgument("Output arg '", arg_name, "' not found");
@@ -1123,7 +1123,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
   NameRangeMap name_ranges;
   status->status =
       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
-  if (TF_GetCode(status) != TF_OK) return -1;
+  if (!status->status.ok()) return -1;
   auto iter = name_ranges.find(arg_name);
   if (iter == name_ranges.end()) {
     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
@@ -1142,6 +1142,16 @@ TF_Output TF_OperationInput(TF_Input oper_in) {
   return {ToOperation(edge->src()), edge->src_output()};
 }

+void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
+                           int max_inputs) {
+  for (auto* edge : oper->node.in_edges()) {
+    if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
+      inputs[edge->dst_input()] = {ToOperation(edge->src()),
+                                   edge->src_output()};
+    }
+  }
+}
+
 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
   int count = 0;
   for (const auto* edge : oper_out.oper->node.out_edges()) {
@@ -1221,7 +1231,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
                                             TF_Status* status) {
   TF_AttrMetadata metadata;
   const auto* attr = GetAttrValue(oper, attr_name, status);
-  if (TF_GetCode(status) != TF_OK) return metadata;
+  if (!status->status.ok()) return metadata;
   switch (attr->value_case()) {
 #define SINGLE_CASE(kK, attr_type, size_expr) \
   case tensorflow::AttrValue::kK:             \
@@ -1328,7 +1338,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
                                void* value, size_t max_length,
                                TF_Status* status) {
   const auto* attr = GetAttrValue(oper, attr_name, status);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   if (attr->value_case() != tensorflow::AttrValue::kS) {
     status->status =
         InvalidArgument("Attribute '", attr_name, "' is not a string");
@@ -1346,7 +1356,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
                                    int max_values, void* storage,
                                    size_t storage_size, TF_Status* status) {
   const auto* attr = GetAttrValue(oper, attr_name, status);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   if (attr->value_case() != tensorflow::AttrValue::kList) {
     status->status =
         InvalidArgument("Value for '", attr_name, "' is not a list");
@@ -1379,7 +1389,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
   void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
                   int max_values, TF_Status* status) {                       \
     const auto* attr = GetAttrValue(oper, attr_name, status);                \
-    if (TF_GetCode(status) != TF_OK) return;                                 \
+    if (!status->status.ok()) return;                                        \
     if (attr->value_case() != tensorflow::AttrValue::kList) {                \
       status->status =                                                       \
           InvalidArgument("Value for '", attr_name, "' is not a list.");     \
@@ -1401,7 +1411,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
   PartialTensorShape shape;
   status->status =
       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   auto len = std::min(shape.dims(), num_dims);
   for (int i = 0; i < len; ++i) {
     value[i] = shape.dim_size(i);
@@ -1415,7 +1425,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
   std::vector<PartialTensorShape> shapes;
   status->status =
       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
   int64_t* p = storage;
   int storage_left = storage_size;
@@ -1443,7 +1453,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
                                          const char* attr_name,
                                          TF_Buffer* value, TF_Status* status) {
   const auto* attr = GetAttrValue(oper, attr_name, status);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   if (attr->value_case() != tensorflow::AttrValue::kShape) {
     status->status =
         InvalidArgument("Value for '", attr_name, "' is not a shape.");
@@ -1457,7 +1467,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
                                              TF_Buffer** values, int max_values,
                                              TF_Status* status) {
   const auto* attr = GetAttrValue(oper, attr_name, status);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   if (attr->value_case() != tensorflow::AttrValue::kList) {
     status->status =
         InvalidArgument("Value for '", attr_name, "' is not a list");
@@ -1467,7 +1477,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
   for (int i = 0; i < len; ++i) {
     values[i] = TF_NewBuffer();
     status->status = MessageToBuffer(attr->list().shape(i), values[i]);
-    if (TF_GetCode(status) != TF_OK) {
+    if (!status->status.ok()) {
       // Delete everything allocated to far, the operation has failed.
       for (int j = 0; j <= i; ++j) {
         TF_DeleteBuffer(values[j]);
@@ -1482,7 +1492,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
   *value = nullptr;
   Tensor t;
   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   *value = TF_TensorFromTensor(t, status);
 }

@@ -1491,7 +1501,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
                                    TF_Status* status) {
   std::vector<Tensor> ts;
   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   const auto len = std::min(max_values, static_cast<int>(ts.size()));
   for (int i = 0; i < len; ++i) {
     values[i] = TF_TensorFromTensor(ts[i], status);
@@ -1502,7 +1512,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
                                    TF_Buffer* output_attr_value,
                                    TF_Status* status) {
   const auto* attr = GetAttrValue(oper, attr_name, status);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;
   status->status = MessageToBuffer(*attr, output_attr_value);
 }

@@ -1583,7 +1593,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
   {
     mutex_lock l(graph->mu);
     status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
-    if (TF_GetCode(status) != TF_OK) return;
+    if (!status->status.ok()) return;
   }
   status->status = MessageToBuffer(*op_def, output_op_def);
 }
@@ -1701,7 +1711,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
   tensorflow::ImportGraphDefResults results;
   status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
                                               &graph->refiner, &results);
-  if (TF_GetCode(status) != TF_OK) return;
+  if (!status->status.ok()) return;

   // Add new nodes to name_map
   for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
@@ -1755,7 +1765,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
   auto results = new TF_ImportGraphDefResults();
   mutex_lock l(graph->mu);
   GraphImportGraphDefLocked(graph, def, options, results, status);
-  if (TF_GetCode(status) != TF_OK) {
+  if (!status->status.ok()) {
     delete results;
     return nullptr;
   }
@@ -1813,7 +1823,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
   TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
   // TODO(skyewm): set placeholder shape
   TF_Operation* oper = TF_FinishOperation(desc, status);
-  if (TF_GetCode(status) != TF_OK) return false;
+  if (!status->status.ok()) return false;
   *input = {oper, 0};
   return true;
 }
@@ -1958,7 +1968,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
   TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
                            body_graph, body_inputs, body_outputs, name};

-  if (TF_GetCode(status) != TF_OK) {
+  if (!status->status.ok()) {
     FreeWhileResources(&params);
     return EmptyWhileParams();
   }
@@ -2160,7 +2170,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
                           TF_Status* status) {
   Session* session;
   status->status = NewSession(opt->options, &session);
-  if (TF_GetCode(status) == TF_OK) {
+  if (status->status.ok()) {
     TF_Session* new_session = new TF_Session(session, graph);
     if (graph != nullptr) {
       mutex_lock l(graph->mu);
@@ -2208,7 +2218,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
   status->status =
       tensorflow::LoadSavedModel(session_options->options, run_options_proto,
                                  export_dir, tag_set, &bundle);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
+  if (!status->status.ok()) return nullptr;

   // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
   // extends using GraphDefs. The Graph instance is different, but equivalent
@@ -2221,11 +2231,11 @@ TF_Session* TF_LoadSessionFromSavedModel(
   GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                             import_opts, &results, status);
   TF_DeleteImportGraphDefOptions(import_opts);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
+  if (!status->status.ok()) return nullptr;

   if (meta_graph_def != nullptr) {
     status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
+    if (!status->status.ok()) return nullptr;
   }

   TF_Session* session = new TF_Session(bundle.session.release(), graph);
@@ -2325,7 +2335,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
   string new_handle;
   status->status = session->session->PRunSetup(input_names, output_names,
                                                target_names, &new_handle);
-  if (TF_GetCode(status) == TF_OK) {
+  if (status->status.ok()) {
     char* buf = new char[new_handle.size() + 1];
     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
     *handle = buf;
@@ -2387,9 +2397,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
       tensor, graph->refiner, *graph->graph.op_registry(),
       graph->graph.versions().producer(), &evaluated, &result_tensor);
   if (evaluated) {
-    DCHECK(TF_GetCode(status) == TF_OK);
+    DCHECK(status->status.ok());
     *result = TF_TensorFromTensor(result_tensor, status);
-    if (TF_GetCode(status) != TF_OK) evaluated = false;
+    if (!status->status.ok()) evaluated = false;
   }
   return evaluated;
 }
@@ -2444,7 +2454,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,

   TF_Buffer* ret = TF_NewBuffer();
   status->status = MessageToBuffer(*api_def, ret);
-  if (TF_GetCode(status) != TF_OK) {
+  if (!status->status.ok()) {
     TF_DeleteBuffer(ret);
     return nullptr;
   }
@@ -2456,7 +2466,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
   tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
   TF_Buffer* ret = TF_NewBuffer();
   status->status = MessageToBuffer(kernel_list, ret);
-  if (TF_GetCode(status) != TF_OK) {
+  if (!status->status.ok()) {
     TF_DeleteBuffer(ret);
     return nullptr;
   }
@@ -2468,7 +2478,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
       tensorflow::GetRegisteredKernelsForOp(name);
   TF_Buffer* ret = TF_NewBuffer();
   status->status = MessageToBuffer(kernel_list, ret);
-  if (TF_GetCode(status) != TF_OK) {
+  if (!status->status.ok()) {
     TF_DeleteBuffer(ret);
     return nullptr;
   }
@@ -2498,7 +2508,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len,

   std::unique_ptr<tensorflow::ServerInterface> out_server;
   status->status = tensorflow::NewServer(server_def, &out_server);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
+  if (!status->status.ok()) return nullptr;

   return new TF_Server(std::move(out_server));
 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
@@ -435,6 +435,15 @@ TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper,
 // producer.index) to consumer.oper's input (given by consumer.index).
 TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in);

+// Get list of all inputs of a specific operation. `inputs` must point to
+// an array of length at least `max_inputs` (ideally set to
+// TF_OperationNumInputs(oper)). Beware that a concurrent
+// modification of the graph can increase the number of inputs of
+// an operation.
+TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper,
+                                                 TF_Output* inputs,
+                                                 int max_inputs);
+
 // Get the number of current consumers of a specific output of an
 // operation. Note that this number can change when new operations
 // are added to the graph.
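[Editor's note] A caller-side sketch for the new entry point. The helper GetAllInputs below is invented for illustration; TF_OperationNumInputs and TF_OperationAllInputs are the real C API functions shown in this diff:

#include <vector>
#include "tensorflow/c/c_api.h"

// Collect the producer of every input edge of `oper`, indexed by input slot.
std::vector<TF_Output> GetAllInputs(TF_Operation* oper) {
  const int num_inputs = TF_OperationNumInputs(oper);
  std::vector<TF_Output> inputs(num_inputs);   // length >= max_inputs, per the docs
  TF_OperationAllInputs(oper, inputs.data(), num_inputs);
  return inputs;
}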
@@ -510,10 +510,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
   return createTFEDequeue(ctx, TF_VARIANT, queue, status);
 }

-static void CheckOk(TF_Status* status) {
-  CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
-}
-
 void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
   auto* status = TF_NewStatus();
   if (!TFE_TensorHandleIsConcrete(handle)) {
@@ -9,7 +9,6 @@ load(
 )
 load(
     "//tensorflow/core/platform:default/build_config.bzl",
-    "tf_additional_device_tracer_test_flags",
    "tf_kernel_tests_linkstatic",
 )
 load(
@@ -27,6 +26,7 @@ tf_cuda_library(
         "c_api.cc",
         "c_api_debug.cc",
         "c_api_experimental.h",
+        "c_api_internal.cc",
         "c_api_internal.h",
     ],
     hdrs = ["c_api.h"],
@@ -237,8 +237,7 @@ tf_cuda_cc_test(
     srcs = [
         "c_api_experimental_test.cc",
     ],
-    args =
-        ["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
+    args = ["--heap_check=local"],
     extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
@@ -33,7 +33,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
-#include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/platform.h"  // NOLINT
 #ifdef TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -61,6 +61,7 @@ limitations under the License.
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -100,32 +101,34 @@ string DeviceName(const tensorflow::Device* d) {
 tensorflow::Status GetAllRemoteDevices(
     const std::vector<string>& remote_workers,
     tensorflow::WorkerCacheInterface* worker_cache,
-    std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
+    std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
   std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
-  tensorflow::Status status;
-  // TODO(nareshmodi) do this in parallel instead of serially.
-  for (const string& remote_worker : remote_workers) {
-    tensorflow::Notification n;
+  tensorflow::mutex remote_devices_mu;
+  int num_remote_workers = remote_workers.size();
+  tensorflow::BlockingCounter counter(num_remote_workers);
+  std::vector<tensorflow::Status> statuses(num_remote_workers);
+  for (int i = 0; i < num_remote_workers; i++) {
     tensorflow::NewRemoteDevices(
-        tensorflow::Env::Default(), worker_cache, remote_worker,
-        [&status, &n, &remote_devices](
+        tensorflow::Env::Default(), worker_cache, remote_workers[i],
+        [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
-          status = s;
+          statuses[i] = s;
          if (s.ok()) {
+            tensorflow::mutex_lock l(remote_devices_mu);
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
-          n.Notify();
+          counter.DecrementCount();
        });
-    n.WaitForNotification();
   }
-  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
-      new tensorflow::StaticDeviceMgr(std::move(remote_devices)));
-
-  TF_RETURN_IF_ERROR(status);
-
+  counter.Wait();
+  for (int i = 0; i < num_remote_workers; i++) {
+    TF_RETURN_IF_ERROR(statuses[i]);
+  }
+  auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
+  TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
   *device_mgr = std::move(remote_device_mgr);
   return tensorflow::Status::OK();
 }
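[Editor's note] The rewrite above replaces a serial Notification-per-worker loop with a fan-out/fan-in pattern: issue every request up front, record one Status per worker, and block once on a counter. A self-contained sketch of the pattern (the BlockingCounter below is a simplified stand-in for tensorflow::BlockingCounter, and std::thread stands in for the async RPC):

#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

class BlockingCounter {  // simplified stand-in for tensorflow::BlockingCounter
 public:
  explicit BlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> l(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

int main() {
  const int num_workers = 3;
  BlockingCounter counter(num_workers);
  std::vector<int> statuses(num_workers, -1);  // one result slot per worker
  std::vector<std::thread> rpcs;
  for (int i = 0; i < num_workers; ++i) {
    rpcs.emplace_back([i, &statuses, &counter] {
      statuses[i] = 0;           // stand-in for the RPC completion callback
      counter.DecrementCount();  // signal that this worker finished
    });
  }
  counter.Wait();  // fan-in: block until every callback has run
  for (auto& t : rpcs) t.join();
  return 0;
}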
@@ -135,11 +138,15 @@ tensorflow::Status CreateRemoteContexts(
     int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
     const tensorflow::eager::CreateContextRequest& base_request) {
-  for (int i = 0; i < remote_workers.size(); i++) {
+  int num_remote_workers = remote_workers.size();
+  tensorflow::BlockingCounter counter(num_remote_workers);
+  std::vector<tensorflow::Status> statuses(num_remote_workers);
+  for (int i = 0; i < num_remote_workers; i++) {
     const string& remote_worker = remote_workers[i];

     tensorflow::eager::CreateContextRequest request(base_request);
-    tensorflow::eager::CreateContextResponse response;
+    tensorflow::eager::CreateContextResponse* response =
+        new tensorflow::eager::CreateContextResponse();
     request.set_context_id(context_id);
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
@@ -159,16 +166,17 @@ tensorflow::Status CreateRemoteContexts(
       return tensorflow::errors::Internal(
           "Cannot find a client for the given target:", remote_worker);
     }
-    tensorflow::Notification n;
-    tensorflow::Status status;
-    // TODO(nareshmodi) do this in parallel instead of serially.
     eager_client->CreateContextAsync(
-        &request, &response, [&status, &n](const tensorflow::Status& s) {
-          status = s;
-          n.Notify();
+        &request, response,
+        [i, &statuses, &counter, response](const tensorflow::Status& s) {
+          statuses[i] = s;
+          delete response;
+          counter.DecrementCount();
        });
-    n.WaitForNotification();
-    TF_RETURN_IF_ERROR(status);
   }
+  counter.Wait();
+  for (int i = 0; i < num_remote_workers; i++) {
+    TF_RETURN_IF_ERROR(statuses[i]);
+  }
   return tensorflow::Status::OK();
 }
@@ -215,7 +223,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
       remote_workers.end());

-  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
+  std::unique_ptr<tensorflow::DynamicDeviceMgr> remote_device_mgr;
   LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
       remote_workers, grpc_server->master_env()->worker_cache,
       &remote_device_mgr));
@@ -247,7 +255,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   LOG_AND_RETURN_IF_ERROR(
       CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
                            server_def, remote_eager_workers.get(),
-                           ctx->context->Executor()->Async(), base_request));
+                           ctx->context->Executor().Async(), base_request));

   tensorflow::RemoteRendezvous* r =
       grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
@@ -564,7 +572,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
     const tensorflow::Tensor* t = nullptr;
     tensorflow::TensorHandle* h_cpu = nullptr;
     status->status = EagerCopyToDevice(
-        handle, handle->Context(), handle->Context()->Executor(),
+        handle, handle->Context(), &handle->Context()->Executor(),
         handle->Context()->HostCPU(), false, &h_cpu);
     if (!status->status.ok()) {
       return nullptr;
@@ -596,33 +604,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {

 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  const char* name = op_or_function_name;  // Shorthand
-  const tensorflow::AttrTypeMap* types;
-  bool is_function = false;
-  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
-  if (!status->status.ok()) {
-    return nullptr;
-  }
-  if (!is_function) {
-    const tensorflow::OpDef* op_def;
-    status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
-    if (!status->status.ok()) {
-      return nullptr;
-    }
-    return new TFE_Op(ctx, name, false, types,
-                      new TFE_OpInferenceContext(op_def));
-  }
-  if (!ctx->context->FindFunctionByName(name)) {
-    status->status = tensorflow::errors::NotFound(
-        "'", name,
-        "' is neither a type of a primitive operation nor a name "
-        "of a function registered in binary running on ",
-        tensorflow::port::Hostname(),
-        ". Make sure the operation or function is "
-        "registered in the binary running in this process.");
-    return nullptr;
-  }
-  return new TFE_Op(ctx, name, true, types, nullptr);
+  return NewOrResetOp(ctx, op_or_function_name, status,
+                      /* op_to_reset= */ nullptr);
 }

 void TFE_DeleteOp(TFE_Op* op) { delete op; }
@@ -916,7 +899,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
     return nullptr;
   }
   status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
-                                                 ctx->context->Executor(),
+                                                 &ctx->context->Executor(),
                                                  device, false, &handle);
   if (status->status.ok()) {
     return new TFE_TensorHandle(handle);
@@ -967,7 +950,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,

 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                   TF_Status* status) {
-  status->status = ctx->context->Executor()->WaitForAllPendingNodes();
+  status->status = ctx->context->Executor().WaitForAllPendingNodes();
   if (!status->status.ok()) return;
   tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
   status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
@@ -979,9 +962,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                 TF_Status* status) {
   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
   for (const auto& attr : func.attr()) {
-    if (TF_GetCode(status) != TF_OK) return nullptr;
+    if (!status->status.ok()) return nullptr;
     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
+    if (!status->status.ok()) return nullptr;
   }
   return func_op;
 }
@@ -1029,7 +1012,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
     } break;
     case tensorflow::AttrValue::kFunc: {
       const auto func_op = GetFunc(ctx, default_value.func(), status);
-      if (TF_GetCode(status) != TF_OK) return;
+      if (!status->status.ok()) return;
       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
       // require TFE_Op* and just convert it internally a NameAttrValue, so
       // consider adding an overload to the C API to make this case easier.
@@ -28,6 +28,16 @@ limitations under the License.

 using tensorflow::string;

+void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
+                 TF_Status* status, TFE_Op* op_to_reset) {
+  if (op_to_reset) {
+    NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
+  } else {
+    TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                 "op_to_reset should not be nullptr");
+  }
+}
+
 void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
   op->operation.ConsumeInput(h->handle);
 }
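[Editor's note] A hypothetical usage sketch for the new TFE_OpReset entry point (ctx and status setup are elided; "MatMul" is an arbitrary example op name):

// Reuse one TFE_Op allocation across iterations instead of
// TFE_DeleteOp + TFE_NewOp each time.
TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
// ... set attrs/inputs, TFE_Execute(op, ...) ...
TFE_OpReset(ctx, "MatMul", status, op);  // re-initializes `op` in place
// Passing nullptr as op_to_reset sets TF_INVALID_ARGUMENT, per the
// implementation above.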
@@ -597,5 +607,5 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
 }

 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
-  return new TFE_Executor(ctx->context->Executor());
+  return new TFE_Executor(&ctx->context->Executor());
 }
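[Editor's note] This hunk and the related call-site changes follow from EagerContext::Executor() apparently now returning a reference rather than a pointer (inferred from the `&...Executor()` pattern throughout this diff), so callers that need a pointer take its address:

// Call-site pattern after the change (ctx is a TFE_Context*):
tensorflow::EagerExecutor& executor = ctx->context->Executor();  // reference
status->status = executor.WaitForAllPendingNodes();              // use directly
TFE_Executor* wrapper = new TFE_Executor(&executor);             // take address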
@@ -22,6 +22,10 @@ limitations under the License.
 extern "C" {
 #endif

+TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
+                                       const char* op_or_function_name,
+                                       TF_Status* status, TFE_Op* op_to_reset);
+
 TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
                                               TF_Status* status);
@@ -84,11 +84,6 @@ void ExecuteWithProfiling(bool async) {
     string profile_proto_str = profile_proto.DebugString();
     if (!gpu_device_name.empty()) {
       EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
-      // device name with "stream:all" is collected by Device Tracer.
-#ifndef TENSORFLOW_USE_ROCM
-      // ROCm platform does not yet support stream level tracing
-      EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
-#endif
     }
     // "/host:CPU" is collected by TraceMe
     EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
tensorflow/c/eager/c_api_internal.cc (new file, 58 lines)
@@ -0,0 +1,58 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/eager/c_api_internal.h"
+
+#include "tensorflow/core/platform/host_info.h"
+
+TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
+                     TF_Status* status, TFE_Op* op_to_reset) {
+  const char* name = op_or_function_name;  // Shorthand
+  const tensorflow::AttrTypeMap* types;
+  bool is_function = false;
+  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+  auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
+                             bool is_function,
+                             TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
+    if (op_to_reset) {
+      op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
+      return op_to_reset;
+    } else {
+      return new TFE_Op(ctx, name, is_function, types, inference_ctx);
+    }
+  };
+
+  if (!is_function) {
+    const tensorflow::OpDef* op_def;
+    status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
+    if (!status->status.ok()) {
+      return nullptr;
+    }
+    return create_or_reset(false, new TFE_OpInferenceContext(op_def));
+  }
+  if (!ctx->context->FindFunctionByName(name)) {
+    status->status = tensorflow::errors::NotFound(
+        "'", name,
+        "' is neither a type of a primitive operation nor a name "
+        "of a function registered in binary running on ",
+        tensorflow::port::Hostname(),
+        ". Make sure the operation or function is "
+        "registered in the binary running in this process.");
+    return nullptr;
+  }
+  return create_or_reset(true, nullptr);
+}
@@ -133,10 +133,25 @@ struct TFE_Op {
       : operation(ctx->context, op, is_function, t),
         inference_ctx(inference_ctx) {}

+  void Clear() {
+    operation.Clear();
+    inference_ctx.reset();
+  }
+
+  void Reset(TFE_Context* ctx, const char* op, bool is_function,
+             const tensorflow::AttrTypeMap* t,
+             TFE_OpInferenceContext* infer_ctx) {
+    operation.Reset(ctx->context, op, is_function, t, nullptr);
+    inference_ctx.reset(infer_ctx);
+  }
+
   tensorflow::EagerOperation operation;
   std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
 };

+TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
+                     TF_Status* status, TFE_Op* op_to_reset = nullptr);
+
 struct TFE_Profiler {
   explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
@@ -1069,10 +1069,13 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
     // still fail.
     TF_SetStatus(status, TF_OK, "");
     TFE_DeleteTensorHandle(retvals[0]);
+    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    EXPECT_NE(TF_OK, TF_GetCode(status));
+    TF_SetStatus(status, TF_OK, "");
     retvals[0] = nullptr;
     TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
     EXPECT_NE(TF_OK, TF_GetCode(status));
-    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
     TFE_ExecutorClearError(executor);
     TFE_ExecutorWaitForAllPendingNodes(executor, status);
     ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -292,7 +292,9 @@ string ToCamelCase(const string& str) {
   bool cap = true;
   while (i < str.size()) {
     const char c = str[i++];
-    if (c == joiner) {
+    if (c == '>') {
+      cap = true;
+    } else if (c == joiner) {
       cap = true;
     } else if (cap) {
       result += toupper(c);
@@ -304,6 +306,21 @@ string ToCamelCase(const string& str) {
   return result;
 }

+string SeparateNamespaces(const string& str) {
+  string result;
+  const char joiner = '_';
+  size_t i = 0;
+  while (i < str.size()) {
+    const char c = str[i++];
+    if (c == '>') {
+      result += joiner;
+    } else {
+      result += c;
+    }
+  }
+  return result;
+}
+
 // Returns a <string, bool> pair. The string is the C++ type name to be used for
 // attr_type when defining an object of that type. The bool is a flag to
 // indicate whether to treat the type as const when accepting the C++ type as an
@@ -549,7 +566,7 @@ struct OpInfo {
 OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
                const std::vector<string>& aliases)
     : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
-  op_name = api_def.endpoint(0).name();
+  op_name = SeparateNamespaces(api_def.endpoint(0).name());
   InferOpAttributes(graph_op_def, &inferred_input_attrs);
   has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
   arg_types.push_back("const ::tensorflow::Scope&");
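[Editor's note] A self-contained sketch of the SeparateNamespaces behavior added above: a '>' in an endpoint name becomes '_' in the generated C++ op name. The endpoint "Ns>MyOp" is hypothetical, used only for illustration:

#include <cassert>
#include <string>
using std::string;

// Same logic as the function added in this diff, condensed.
string SeparateNamespaces(const string& str) {
  string result;
  const char joiner = '_';
  for (const char c : str) result += (c == '>') ? joiner : c;
  return result;
}

int main() {
  assert(SeparateNamespaces("Ns>MyOp") == "Ns_MyOp");  // '>' -> '_'
  assert(SeparateNamespaces("MyOp") == "MyOp");        // unchanged otherwise
  return 0;
}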
@@ -127,8 +127,29 @@ cc_library(
 )

 tf_cc_test(
-    name = "loader_test",
-    srcs = ["loader_test.cc"],
+    name = "saved_model_bundle_test",
+    srcs = ["saved_model_bundle_test.cc"],
     data = [
         ":saved_model_half_plus_two",
     ],
+    linkstatic = 1,
+    deps = [
+        ":constants",
+        ":loader",
+        ":signature_constants",
+        ":tag_constants",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+tf_cc_test(
+    name = "saved_model_bundle_lite_test",
+    srcs = ["saved_model_bundle_lite_test.cc"],
+    data = [
+        ":saved_model_half_plus_two",
+    ],
@@ -299,6 +299,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,

 }  // namespace

+SavedModelBundleInterface::~SavedModelBundleInterface() {}
+
 Status LoadSavedModel(const SessionOptions& session_options,
                       const RunOptions& run_options, const string& export_dir,
                       const std::unordered_set<string>& tags,
@@ -323,6 +325,133 @@ Status LoadSavedModel(const SessionOptions& session_options,
   return status;
 }

+namespace {
+// Session wrapper that prevents calls to Session::Create(), Session::Extend(),
+// and the deprecated partial-run methods.
+//
+// Limiting the available methods on a returned Session gives us the option
+// to replace the Session with a cut-down implementation, without breaking any
+// users.
+class LiteSessionWrapper : public Session {
+ public:
+  explicit LiteSessionWrapper(std::unique_ptr<Session> wrapped)
+      : wrapped_(std::move(wrapped)) {}
+
+  Status Create(const GraphDef& graph) override {
+    return errors::Unimplemented("Session::Create()");
+  }
+  Status Create(GraphDef&& graph) override {
+    return errors::Unimplemented("Session::Create()");
+  }
+
+  Status Extend(const GraphDef& graph) override {
+    return errors::Unimplemented("Session::Extend()");
+  }
+  Status Extend(GraphDef&& graph) override {
+    return errors::Unimplemented("Session::Extend()");
+  }
+
+  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
+             const std::vector<string>& output_tensor_names,
+             const std::vector<string>& target_node_names,
+             std::vector<Tensor>* outputs) override {
+    return wrapped_->Run(inputs, output_tensor_names, target_node_names,
+                         outputs);
+  }
+
+  Status Create(const RunOptions& run_options, const GraphDef& graph) override {
+    return errors::Unimplemented("Session::Create()");
+  }
+  Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
+    return errors::Unimplemented("Session::Extend()");
+  }
+  Status Create(const RunOptions& run_options, GraphDef&& graph) override {
+    return errors::Unimplemented("Session::Create()");
+  }
+  Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
+    return errors::Unimplemented("Session::Extend()");
+  }
+  Status Close(const RunOptions& run_options) override {
+    return wrapped_->Close(run_options);
+  }
+
+  Status Run(const RunOptions& run_options,
+             const std::vector<std::pair<string, Tensor>>& inputs,
+             const std::vector<string>& output_tensor_names,
+             const std::vector<string>& target_node_names,
+             std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
+    return wrapped_->Run(run_options, inputs, output_tensor_names,
+                         target_node_names, outputs, run_metadata);
+  }
+
+  Status PRunSetup(const std::vector<string>& input_names,
+                   const std::vector<string>& output_names,
+                   const std::vector<string>& target_nodes,
+                   string* handle) override {
+    return errors::Unimplemented("Session::PRunSetup()");
+  }
+
+  Status PRun(const string& handle,
+              const std::vector<std::pair<string, Tensor>>& inputs,
+              const std::vector<string>& output_names,
+              std::vector<Tensor>* outputs) override {
+    return errors::Unimplemented("Session::PRun()");
+  }
+
+  Status ListDevices(std::vector<DeviceAttributes>* response) override {
+    return wrapped_->ListDevices(response);
+  }
+
+  Status Close() override { return wrapped_->Close(); }
+
+  Status MakeCallable(const CallableOptions& callable_options,
+                      CallableHandle* out_handle) override {
+    return wrapped_->MakeCallable(callable_options, out_handle);
+  }
+
+  Status RunCallable(CallableHandle handle,
+                     const std::vector<Tensor>& feed_tensors,
+                     std::vector<Tensor>* fetch_tensors,
+                     RunMetadata* run_metadata) override {
+    return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
+                                 run_metadata);
+  }
+
+  Status RunCallable(
+      CallableHandle handle, const std::vector<Tensor>& feed_tensors,
+      std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
+      const thread::ThreadPoolOptions& threadpool_options) override {
+    return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
+                                 run_metadata, threadpool_options);
+  }
+
+  Status ReleaseCallable(CallableHandle handle) override {
+    return wrapped_->ReleaseCallable(handle);
+  }
+
+ private:
+  const std::unique_ptr<Session> wrapped_;
+};
+}  // namespace
+
+Status LoadSavedModel(const SessionOptions& session_options,
+                      const RunOptions& run_options, const string& export_dir,
+                      const std::unordered_set<string>& tags,
+                      SavedModelBundleLite* const bundle) {
+  SavedModelBundle legacy_bundle;
+  SessionOptions rewritten_options(session_options);
+  rewritten_options.config.mutable_experimental()
+      ->set_optimize_for_static_graph(true);
+  // TODO(mrry): Consider specializing the session creation to reduce peak
+  // RAM consumption by using `Session::Create(GraphDef&&)`.
+  TF_RETURN_IF_ERROR(LoadSavedModel(session_options, run_options, export_dir,
+                                    tags, &legacy_bundle));
+  *bundle = SavedModelBundleLite(
+      absl::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
+      std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
+  return Status::OK();
+}
+
 bool MaybeSavedModelDirectory(const string& export_dir) {
   const string saved_model_pb_path =
       io::JoinPath(export_dir, kSavedModelFilenamePb);
@@ -27,31 +27,96 @@ limitations under the License.

 namespace tensorflow {

-/// SavedModel representation once the SavedModel is loaded from storage.
-struct SavedModelBundle {
-  std::unique_ptr<Session> session;
-  MetaGraphDef meta_graph_def;
+/// Represents a SavedModel that is loaded from storage.
+class SavedModelBundleInterface {
+ public:
+  virtual ~SavedModelBundleInterface();
+
+  /// Returns the TensorFlow Session that can be used to interact with the
+  /// SavedModel.
+  virtual Session* GetSession() const = 0;
+
+  /// Returns a map from signature name to SignatureDef for all signatures in
+  /// in the SavedModel.
+  virtual const protobuf::Map<string, SignatureDef>& GetSignatures() const = 0;
+};
+
+/// SavedModel representation once the SavedModel is loaded from storage.
+///
+/// NOTE: Prefer to use SavedModelBundleLite in new code, as it consumes less
+/// RAM.
+struct SavedModelBundle : public SavedModelBundleInterface {
   /// A TensorFlow Session does not Close itself on destruction. To avoid
   /// resource leaks, we explicitly call Close on Sessions that we create.
-  ~SavedModelBundle() {
+  ~SavedModelBundle() override {
     if (session) {
       session->Close().IgnoreError();
     }
   }

   SavedModelBundle() = default;
+
+  Session* GetSession() const override { return session.get(); }
+  const protobuf::Map<string, SignatureDef>& GetSignatures() const override {
+    return meta_graph_def.signature_def();
+  }
+
+  std::unique_ptr<Session> session;
+  MetaGraphDef meta_graph_def;
 };

-/// Loads a SavedModel from the specified export directory. The meta graph def
+// A version of SavedModelBundle that avoids storing a potentially large
+// MetaGraphDef. Prefer to use SavedModelBundleLite in new code.
+class SavedModelBundleLite : public SavedModelBundleInterface {
+ public:
+  SavedModelBundleLite() = default;
+  SavedModelBundleLite& operator=(SavedModelBundleLite&& other) = default;
+
+  SavedModelBundleLite(std::unique_ptr<Session> session,
+                       protobuf::Map<string, SignatureDef> signatures)
+      : session_(std::move(session)), signatures_(std::move(signatures)) {}
+
+  /// A TensorFlow Session does not Close itself on destruction. To avoid
+  /// resource leaks, we explicitly call Close on Sessions that we create.
+  ~SavedModelBundleLite() override {
+    if (session_) {
+      session_->Close().IgnoreError();
+    }
+  }
+
+  Session* GetSession() const override { return session_.get(); }
+  const protobuf::Map<string, SignatureDef>& GetSignatures() const override {
+    return signatures_;
+  }
+
+ private:
+  std::unique_ptr<Session> session_;
+  protobuf::Map<string, SignatureDef> signatures_;
+};
+
+/// Loads a SavedModel from the specified export directory. The MetaGraphDef
 /// to be loaded is identified by the supplied tags, corresponding exactly to
-/// the set of tags used at SavedModel build time. Returns a SavedModel bundle
-/// with a session and the requested meta graph def, if found.
+/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in
+/// *bundle with a session and the requested MetaGraphDef, if found.
+///
+/// NOTE: Prefer the overload that takes a SavedModelBundleLite* in new code.
 Status LoadSavedModel(const SessionOptions& session_options,
                       const RunOptions& run_options, const string& export_dir,
                       const std::unordered_set<string>& tags,
                       SavedModelBundle* const bundle);

+/// Loads a SavedModel from the specified export directory. The MetaGraphDef
+/// to be loaded is identified by the supplied tags, corresponding exactly to
+/// the set of tags used at SavedModel build time. Stores a SavedModel bundle
+/// in *bundle with a session created from the requested MetaGraphDef if found.
+///
+/// This overload creates a SavedModelBundleLite, which consumes less RAM than
+/// an equivalent SavedModelBundle.
+Status LoadSavedModel(const SessionOptions& session_options,
+                      const RunOptions& run_options, const string& export_dir,
+                      const std::unordered_set<string>& tags,
+                      SavedModelBundleLite* const bundle);
+
 /// Checks whether the provided directory could contain a SavedModel. Note that
 /// the method does not load any data by itself. If the method returns `false`,
 /// the export directory definitely does not contain a SavedModel. If the method
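[Editor's note] A usage sketch for the SavedModelBundleLite overload declared above (the helper LoadLite and its export_dir argument are invented for illustration; the loader API names come from this diff):

#include <string>
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// Load with the lighter-weight bundle; signatures remain available
// without keeping the full MetaGraphDef resident.
tensorflow::Status LoadLite(const std::string& export_dir) {
  tensorflow::SavedModelBundleLite bundle;
  tensorflow::SessionOptions session_options;
  tensorflow::RunOptions run_options;
  TF_RETURN_IF_ERROR(tensorflow::LoadSavedModel(
      session_options, run_options, export_dir,
      {tensorflow::kSavedModelTagServe}, &bundle));
  const auto& signatures = bundle.GetSignatures();  // name -> SignatureDef
  (void)signatures;  // e.g. signatures.at("serving_default")
  return tensorflow::Status::OK();
}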
tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc (new file, 244 lines; view truncated below)
@ -0,0 +1,244 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

constexpr char kTestDataPbTxt[] =
    "cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataMainOp[] =
    "cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] =
    "cc/saved_model/testdata/half_plus_two/00000123";
constexpr char kTestDataInitOpV2[] =
    "cc/saved_model/testdata/half_plus_two_v2/00000123";

class LoaderTest : public ::testing::Test {
 protected:
  LoaderTest() {}

  string MakeSerializedExample(float x) {
    tensorflow::Example example;
    auto* feature_map = example.mutable_features()->mutable_feature();
    (*feature_map)["x"].mutable_float_list()->add_value(x);
    return example.SerializeAsString();
  }

  void ValidateAssets(const string& export_dir,
                      const SavedModelBundleLite& bundle) {
    const string asset_directory =
        io::JoinPath(export_dir, kSavedModelAssetsDirectory);
    const string asset_filename = "foo.txt";
    const string asset_filepath = io::JoinPath(asset_directory, asset_filename);
    TF_EXPECT_OK(Env::Default()->FileExists(asset_filepath));

    std::vector<Tensor> path_outputs;
    TF_ASSERT_OK(
        bundle.GetSession()->Run({}, {"filename_tensor:0"}, {}, &path_outputs));
    ASSERT_EQ(1, path_outputs.size());

    test::ExpectTensorEqual<tstring>(
        test::AsTensor<tstring>({"foo.txt"}, TensorShape({})), path_outputs[0]);
  }

  void CheckSavedModelBundle(const string& export_dir,
                             const SavedModelBundleLite& bundle) {
    ValidateAssets(export_dir, bundle);
    // Retrieve the regression signature from the bundle.
    const auto& signature_def = bundle.GetSignatures().at("regress_x_to_y");

    const string input_name = signature_def.inputs().at(kRegressInputs).name();
    const string output_name =
        signature_def.outputs().at(kRegressOutputs).name();

    std::vector<tstring> serialized_examples;
    for (float x : {0, 1, 2, 3}) {
      serialized_examples.push_back(MakeSerializedExample(x));
    }

    // Validate the half plus two behavior.
    Tensor input =
        test::AsTensor<tstring>(serialized_examples, TensorShape({4}));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(bundle.GetSession()->Run({{input_name, input}}, {output_name},
                                          {}, &outputs));
    ASSERT_EQ(outputs.size(), 1);
    test::ExpectTensorEqual<float>(
        outputs[0],
        test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
  }
};

// Test for resource leaks related to TensorFlow session closing requirements
// when loading and unloading large numbers of SavedModelBundles.
// TODO(sukritiramesh): Increase run iterations and move outside of the test
// suite.
TEST_F(LoaderTest, ResourceLeakTest) {
  SavedModelBundleLite bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  for (int i = 0; i < 100; ++i) {
    TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                                {kSavedModelTagServe}, &bundle));
    CheckSavedModelBundle(export_dir, bundle);
  }
}

TEST_F(LoaderTest, TagMatch) {
  SavedModelBundleLite bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  CheckSavedModelBundle(export_dir, bundle);
}

TEST_F(LoaderTest, NoTagMatch) {
  SavedModelBundleLite bundle;
  RunOptions run_options;
  SessionOptions session_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  Status st = LoadSavedModel(session_options, run_options, export_dir,
                             {"missing-tag"}, &bundle);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(absl::StrContains(
      st.error_message(),
      "Could not find meta graph def matching supplied tags: { missing-tag }"))
      << st.error_message();
}

TEST_F(LoaderTest, NoTagMatchMultiple) {
  SavedModelBundleLite bundle;
  RunOptions run_options;
  SessionOptions session_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  Status st = LoadSavedModel(session_options, run_options, export_dir,
                             {kSavedModelTagServe, "missing-tag"}, &bundle);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(absl::StrContains(
      st.error_message(),
      "Could not find meta graph def matching supplied tags: "))
      << st.error_message();
}

TEST_F(LoaderTest, SessionCreationFailure) {
  SavedModelBundleLite bundle;
  // Use invalid SessionOptions to cause session creation to fail. Default
  // options work, so provide an invalid value for the target field.
  SessionOptions session_options;
  constexpr char kInvalidTarget[] = "invalid target";
  session_options.target = kInvalidTarget;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  Status st = LoadSavedModel(session_options, run_options, export_dir,
                             {kSavedModelTagServe}, &bundle);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
      << st.error_message();
}

TEST_F(LoaderTest, PbtxtFormat) {
  SavedModelBundleLite bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  CheckSavedModelBundle(export_dir, bundle);
}

TEST_F(LoaderTest, MainOpFormat) {
  SavedModelBundleLite bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  CheckSavedModelBundle(export_dir, bundle);
}

TEST_F(LoaderTest, InvalidExportPath) {
  SavedModelBundleLite bundle;
  RunOptions run_options;
  SessionOptions session_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
  Status st = LoadSavedModel(session_options, run_options, export_dir,
                             {kSavedModelTagServe}, &bundle);
  EXPECT_FALSE(st.ok());
}

TEST_F(LoaderTest, MaybeSavedModelDirectory) {
  // Valid SavedModel directory.
  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  EXPECT_TRUE(MaybeSavedModelDirectory(export_dir));

  // Directory that does not exist.
  const string missing_export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
  EXPECT_FALSE(MaybeSavedModelDirectory(missing_export_dir));

  // Directory that exists but is an invalid SavedModel location.
  const string invalid_export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model");
  EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}

TEST_F(LoaderTest, SavedModelInitOpV2Format) {
  SavedModelBundleLite bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  CheckSavedModelBundle(export_dir, bundle);
}

}  // namespace
}  // namespace tensorflow
@ -71,8 +71,7 @@ class LoaderTest : public ::testing::Test {
                             const SavedModelBundle& bundle) {
    ValidateAssets(export_dir, bundle);
    // Retrieve the regression signature from meta graph def.
    const auto signature_def_map = bundle.meta_graph_def.signature_def();
    const auto signature_def = signature_def_map.at("regress_x_to_y");
    const auto& signature_def = bundle.GetSignatures().at("regress_x_to_y");

    const string input_name = signature_def.inputs().at(kRegressInputs).name();
    const string output_name =
@ -1,5 +1,5 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")

@ -38,7 +38,7 @@ cc_library(
        ":xla_cpu_device",
        ":xla_cpu_jit",
        "//tensorflow/compiler/plugin",
    ] + if_cuda([
    ] + if_cuda_or_rocm([
        ":xla_gpu_device",
        ":xla_gpu_jit",
    ]),
@ -61,7 +61,7 @@ cc_library(
cc_library(
    name = "xla_gpu_jit",
    visibility = ["//visibility:public"],
    deps = if_cuda([
    deps = if_cuda_or_rocm([
        ":jit_compilation_passes",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -130,17 +130,24 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
  return uncompilable_nodes;
}

bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") return false;
  if (node.type_string() == "SymbolicGradient") {
    *uncompilable_reason =
        "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }
  if (node.type_string() == "Const") {
    // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
    // registered Const KernelDef says that it does, to support no-op Assert for
    // tfcompile.
    const AttrValue* attr = node.attrs().Find("dtype");
    if (attr != nullptr && attr->type() == DT_STRING) {
      *uncompilable_reason =
          "Const op with type DT_STRING is not supported by XLA.";
      return false;
    }
  }
@ -150,10 +157,16 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
  // such nodes out of XLA clusters.
  if (HasForwardedRefInput(node)) {
    VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
    *uncompilable_reason = "Identity with unsafe cast.";
    return false;
  }

  return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
  Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
  if (!s.ok()) {
    *uncompilable_reason = s.error_message();
    return false;
  }
  return true;
}
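As a sketch of the new contract (illustrative only; `checker` and `node` stand in for a RecursiveCompilabilityChecker instance and a graph node in scope), a caller now supplies a string and reads the reason only on failure:

    string reason;
    if (!checker.HasXLAKernel(node, &reason)) {
      VLOG(2) << "Rejecting " << node.name() << ": " << reason;
    }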

// Tests whether 'if_node' is compilable. Every operator in the then_branch and
@ -336,16 +349,17 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
    return false;
  }

  string uncompilable_reason;
  if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
    if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                          encapsulating_function, uncompilable_nodes)) {
      LogNotCompilable(node, "unsupported function");
      return false;
    }
  } else if (!HasXLAKernel(node)) {
    absl::string_view uncompilable_reason = "unsupported op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
  } else if (!HasXLAKernel(node, &uncompilable_reason)) {
    MaybeMarkUncompilableNode(
        absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
        encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }
@ -247,7 +247,8 @@ class RecursiveCompilabilityChecker {
           absl::c_any_of(node.output_types(), is_variant);
  }

  bool HasXLAKernel(const Node& node) const;
  bool HasXLAKernel(const Node& node,
                    string* uncompilable_reason = nullptr) const;

  static void MaybeMarkUncompilableNode(
      const absl::string_view reason,
@ -125,7 +125,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
  const auto& uncompilable_nodes_inside_function = node_info_it->second.second;
  ASSERT_EQ(1, uncompilable_nodes_inside_function.size());
  const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0);
  EXPECT_EQ("unsupported op", uncompilable_node_info.uncompilable_reason);
  EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason,
                                "unsupported op"));
  ASSERT_EQ(1, uncompilable_node_info.stack_trace.size());
  ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name);
}
@ -167,7 +168,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
  EXPECT_EQ("D", node_stack.at(0).name);
  EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name);
  EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
  EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
  EXPECT_TRUE(
      absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}

TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
@ -246,7 +248,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
               stacktrace_second_node_info.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
  EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
  EXPECT_TRUE(
      absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}

TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
@ -322,7 +325,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
               stacktrace_second_node_info.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
  EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason);
  EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
                                "unsupported op"));

  NameAttrList function_two;
  function_two.set_name(kUncompilableFunctionTwoName);
@ -345,7 +349,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
               node_two_stacktrace_second_node.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
  EXPECT_EQ("unsupported op", uncompilable_node_two.uncompilable_reason);
  EXPECT_TRUE(absl::StrContains(uncompilable_node_two.uncompilable_reason,
                                "unsupported op"));
}

}  // namespace
@ -45,7 +45,8 @@ cc_library(
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:lhlo",
tensorflow/compiler/mlir/g3doc/README.md (new file)
@ -0,0 +1,3 @@
# TensorFlow MLIR

These are the docs for: https://www.tensorflow.org/mlir
tensorflow/compiler/mlir/g3doc/_book.yaml (new file)
@ -0,0 +1,24 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
  path: /resources
  is_default: true
  menu:
  - include: /resources/_menu_toc.yaml
lower_tabs:
  # Subsite tabs
  other:
  - name: Guide & Tutorials
    contents:
    - title: Overview
      path: /mlir/overview
    - heading: Dialects
    - title: TensorFlow
      path: /mlir/tf_ops
    - title: TensorFlow Lite
      path: /mlir/tfl_ops

- include: /_upper_tabs_right.yaml
tensorflow/compiler/mlir/g3doc/_index.yaml (new file)
@ -0,0 +1,48 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
  custom_css_path: /site-assets/css/style.css
  rows:
  - heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
    items:
    - description: >
        The MLIR project defines a common intermediate representation (IR) that
        unifies the infrastructure required to execute high performance machine
        learning models in TensorFlow and similar ML frameworks. This project
        will include the application of HPC techniques, along with integration of
        search algorithms like reinforcement learning. MLIR aims to reduce the
        cost to bring up new hardware, and improve usability for existing
        TensorFlow users.

    - code_block: |
        <pre class = "prettyprint">
        // Syntactically similar to LLVM:
        func @testFunction(%arg0: i32) {
          %x = call @thingToCall(%arg0) : (i32) -> i32
          br ^bb1
        ^bb1:
          %y = addi %x, %x : i32
          return %y : i32
        }
        </pre>

  - classname: devsite-landing-row-cards
    items:
    - heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
      youtube_id: qzljG6DKgic
      buttons:
      - label: Watch the video
        path: https://www.youtube.com/watch?v=qzljG6DKgic
    - heading: "A new intermediate representation and compiler framework"
      image_path: /resources/images/tf-logo-card-16x9.png
      path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
      buttons:
      - label: Read on TensorFlow blog
        path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
    - heading: TensorFlow MLIR on GitHub
      image_path: /resources/images/github-card-16x9.png
      path: https://github.com/tensorflow/mlir
      buttons:
      - label: View on GitHub
        path: https://github.com/tensorflow/mlir
tensorflow/compiler/mlir/g3doc/_project.yaml (new file)
@ -0,0 +1,11 @@
name: TensorFlow MLIR
breadcrumb_name: MLIR
home_url: /mlir/
parent_project_metadata_path: /_project.yaml
description: >
  MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
use_site_branding: true
hide_from_products_list: true
content_license: cc-apache
buganizer_id: 443907
include: /_project_included.yaml
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg (new image, 148 KiB; diff suppressed)
tensorflow/compiler/mlir/g3doc/overview.md (new file)
@ -0,0 +1,5 @@
# MLIR overview

## Overview

<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>
tensorflow/compiler/mlir/g3doc/tf_ops.md (new file, 2761 lines; diff suppressed because it is too large)
tensorflow/compiler/mlir/g3doc/tfl_ops.md (new file, 1606 lines; diff suppressed because it is too large)
@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
load(
    "@local_config_mlir//:tblgen.bzl",
    "gentbl",
@ -185,6 +185,41 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "lstm_utils",
    srcs = [
        "utils/lstm_utils.cc",
    ],
    hdrs = [
        "utils/lstm_utils.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
)

tf_cc_test(
    name = "lstm_utils_test",
    size = "small",
    srcs = ["utils/lstm_utils_test.cc"],
    deps = [
        ":lstm_utils",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
)

cc_library(
    name = "tensorflow_lite_legalize_tf",
    srcs = [
@ -198,9 +233,11 @@ cc_library(
        "transforms/prepare_composite_functions_tf.cc",
        "transforms/prepare_tf.cc",
        "transforms/trim_functions_tf.cc",
        "transforms/unroll_batch_matmul.cc",
    ],
    hdrs = [
        "transforms/passes.h",
        "transforms/unroll_batch_matmul.h",
    ],
    deps = [
        ":common",
@ -249,6 +286,7 @@ cc_library(
    name = "tensorflow_lite_quantize",
    srcs = [
        "transforms/generated_quantize.inc",
        "transforms/load_quantization_recipe.cc",
        "transforms/post_quantize.cc",
        "transforms/prepare_quantize.cc",
        "transforms/quantize.cc",
@ -521,7 +559,7 @@ cc_library(
        ":tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
@ -99,7 +99,10 @@ using xla::StatusOr;
template <typename T>
using BufferOffset = flatbuffers::Offset<T>;

using CustomOptionsOffset = BufferOffset<flatbuffers::Vector<uint8_t>>;
template <typename T>
using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;

using CustomOptionsOffset = VectorBufferOffset<uint8_t>;

namespace error = tensorflow::error;
namespace tfl = mlir::TFL;
@ -415,6 +418,15 @@ class Translator {

  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);

  // Builds Metadata with the given `name` and buffer `content`.
  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
                                               StringRef content);

  // Encodes the `tfl.metadata` dictionary attribute of the module to the
  // metadata section in the final model.
  Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
  CreateMetadataVector();

  // Uses the tf.entry_function attribute (if set) to initialize the op to name
  // mapping.
  void InitializeNamesFromAttribute(FuncOp fn);
@ -977,6 +989,36 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
      /*name=*/builder_.CreateString(fn.getName().str()));
}

BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
                                                         StringRef content) {
  auto buffer_index = buffers_.size();
  auto buffer_data = builder_.CreateVector(
      reinterpret_cast<const uint8_t*>(content.data()), content.size());
  buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
  return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
}

Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
Translator::CreateMetadataVector() {
  auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
  if (!dict_attr) return VectorBufferOffset<BufferOffset<tflite::Metadata>>();

  std::vector<BufferOffset<tflite::Metadata>> metadata;
  for (const auto& named_attr : dict_attr) {
    StringRef name = named_attr.first;
    mlir::Attribute attr = named_attr.second;
    if (auto content = attr.dyn_cast<StringAttr>()) {
      metadata.push_back(BuildMetadata(name, content.getValue()));
    } else {
      module_.emitError(
          "all values in tfl.metadata's dictionary key-value pairs should be "
          "string attributes");
      return llvm::None;
    }
  }
  return builder_.CreateVector(metadata);
}
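As a concrete illustration (the attribute values here are hypothetical), a module carrying tfl.metadata = {min_runtime_version = "1.14.0"} would yield one Metadata entry named min_runtime_version whose buffer stores the raw bytes of "1.14.0"; a non-string value, such as an integer attribute, would trip the emitError path above and abort the translation.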

Optional<std::string> Translator::Translate(ModuleOp module,
                                            bool emit_builtin_tflite_ops,
                                            bool emit_select_tf_ops,
@ -1024,12 +1066,17 @@ Optional<std::string> Translator::TranslateInternal() {
  } else {
    model_description = "MLIR Converted.";
  }

  // Build the model and finish the model building process.
  auto description = builder_.CreateString(model_description.data());
  VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
  auto metadata = CreateMetadataVector();
  if (!metadata) return llvm::None;

  auto model = tflite::CreateModel(
      builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
      builder_.CreateVector(subgraphs), description,
      builder_.CreateVector(buffers_));
      builder_.CreateVector(buffers_), metadata_buffer, *metadata);
  tflite::FinishModelBuffer(builder_, model);

  // Return serialized string for the built FlatBuffer.
@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

#include <algorithm>
#include <cstdint>

#include "llvm/ADT/APFloat.h"
@ -30,6 +31,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
@ -1167,6 +1169,54 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  return DenseElementsAttr::get(result_type, new_values);
}

static LogicalResult Verify(TransposeOp op) {
  auto input_type = op.x()->getType().cast<ShapedType>();
  auto perm_type = op.perm()->getType().cast<ShapedType>();
  auto output_type = op.y()->getType().cast<ShapedType>();
  if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
    if (perm_type.getNumElements() != input_type.getRank()) {
      return op.emitOpError(
          "perm tensor elements size is not equal to input tensor rank");
    }
  }

  DenseIntElementsAttr perm;
  if (!matchPattern(op.perm(), m_Constant(&perm))) {
    return success();
  }

  int index = 0;
  llvm::SmallVector<int64_t, 4> axes;
  for (auto axis_int : perm.getValues<APInt>()) {
    const int64_t axis = axis_int.getSExtValue();
    if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
      return op.emitOpError(
          llvm::formatv("perm[{0}] must be in [0, rank)", index));
    }
    if (std::count(axes.begin(), axes.end(), axis) > 0) {
      return op.emitOpError(
          llvm::formatv("perm[{0}] cannot have duplicated axis", index));
    }
    axes.push_back(axis);
    index++;
  }

  if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
    llvm::SmallVector<int64_t, 4> transposed_shape;
    for (int64_t axis : axes) {
      transposed_shape.push_back(input_type.getDimSize(axis));
    }
    auto expected_output_type =
        RankedTensorType::get(transposed_shape, input_type.getElementType());
    if (output_type != expected_output_type) {
      return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
                                          expected_output_type, output_type));
    }
  }

  return success();
}
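As a worked example of the checks above (illustrative types): with x : tensor<2x3x4xf32> and a constant perm of [2, 0, 1], the verifier expects y : tensor<4x2x3xf32>; a perm of [2, 2, 1] is rejected for the duplicated axis, and [3, 0, 1] for falling outside [0, rank).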

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@ -132,15 +132,35 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

class TFL_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;

// TODO: Some of these could be generalized and/or moved to more general
// location.
// Returns true if the n-th operand has unknown rank or has rank m.
class TFL_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
    Or<[TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ")->getType().cast<ShapedType>().getRank() == " # m>]>>;

// CPred version of TFL_OperandHasRank.
class TFL_OperandHasRankPred<int n, int m> :
  Or<[TFL_OperandIsUnrankedPred<n>,
    CPred<"$_op.getOperand(" # n #
    ")->getType().cast<ShapedType>().getRank() == " # m>]>;
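For instance, TFL_OperandHasRank<1, 1> (used below by tfl.transpose for its perm operand) accepts operand 1 when it is either unranked or exactly 1-D.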

// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
  CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
    # dim>]>;

class TFL_OperandDimEquals<int n, int dim, int size> : And<[
  TFL_OperandIsRankedAndHasDimPred<n, dim>,
  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
    ".getShape()[" # dim # " ] == " # size>]>;

// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
@ -155,6 +175,32 @@ class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
    "$_op.getOperand(" # y #
    ")->getType().cast<ShapedType>().getShape()[0]">>;

// True if x_shape[dim] == y_shape[dim].
class TFL_DimOfOperandEqualsDimOfOperandPred<int x, int y, int dim> : And<[
  TFL_OperandIsRankedAndHasDimPred<x, dim>,
  TFL_OperandIsRankedAndHasDimPred<y, dim>,
  CPred<"$_op.getOperand(" # x #
    ")->getType().cast<ShapedType>().getShape()[" # dim # "] == "
    "$_op.getOperand(" # y #
    ")->getType().cast<ShapedType>().getShape()[" # dim # "]">]>;

// Select operands must satisfy one of the following constraints:
// All inputs are unranked/scalars
// OR
// All inputs are ranked AND have equal dim[0] AND X & Y have same rank.
def SelectShapeConstraints :
  PredOpTrait<"Select operands meet shape criteria",
    Or<[
      And<[
        TFL_OperandHasRankPred<0, 0>,
        TFL_OperandHasRankPred<1, 0>,
        TFL_OperandHasRankPred<2, 0>]>,
      And<[
        TFL_DimOfOperandEqualsDimOfOperandPred<0, 1, 0>,
        TFL_DimOfOperandEqualsDimOfOperandPred<0, 2, 0>,
        CPred<"$_op.getOperand(1)->getType().cast<ShapedType>().getRank() == "
          "$_op.getOperand(2)->getType().cast<ShapedType>().getRank()">]>]>>;
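Concretely (hypothetical types): select(cond : tensor<2xi1>, x : tensor<2x3xf32>, y : tensor<2x3xf32>) satisfies the second clause, since all three operands share dim[0] == 2 and x and y have equal rank, while three scalar or unranked operands satisfy the first clause.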

// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
  TCOpResIsShapedTypePred<i, j>,
@ -315,7 +361,7 @@ def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {

// TODO(haoliang): Implement legalization pass after pattern rewrite generator
// supports variadic inputs.
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "add_n operator";

  let description = [{
@ -323,11 +369,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
  }];

  let arguments = (ins
    Variadic<TensorOf<[F32, I32]>>:$inputs
    Variadic<TensorOf<[F32, I32, QI16, QUI16]>>:$inputs
  );

  let results = (outs
    TensorOf<[F32, I32]>:$sum
    TensorOf<[F32, I32, QI16, QUI16]>:$sum
  );
}

@ -680,6 +726,117 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
  let hasOptions = 0;
}

// These ops are named NonMaxSuppressionV4 & NonMaxSuppressionV5 to be
// consistent with TensorFlow's naming. They are NOT 'versions' of NMS in the
// sense that one is an incremental change over the other.
// In reality NonMaxSuppressionV5 implements Soft Non Max Suppression and
// NonMaxSuppressionV4 performs hard NMS.

def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [
    NoSideEffect,
    // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
    TFL_OperandHasRank<0, 2>,
    PredOpTrait<"boxes should have dim[1] == 4",
      TFL_OperandDimEquals<0, 1, 4>>,
    // Operand 1 (scores) should be a 1-dim tensor
    TFL_OperandHasRank<1, 1>,
    // Other operands are scalar params.
    TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
    TFL_OperandHasRank<4, 0>]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translating or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
  }];

  let arguments = (ins
    TFL_FpTensor:$boxes,
    TFL_FpTensor:$scores,
    I32Tensor:$max_output_size,
    TFL_FpTensor:$iou_threshold,
    TFL_FpTensor:$score_threshold
  );

  let results = (outs
    I32Tensor:$selected_indices,
    I32Tensor:$valid_outputs
  );
}

def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [
    NoSideEffect,
    // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
    TFL_OperandHasRank<0, 2>,
    PredOpTrait<"boxes should have dim[1] == 4",
      TFL_OperandDimEquals<0, 1, 4>>,
    // Operand 1 (scores) should be a 1-dim tensor
    TFL_OperandHasRank<1, 1>,
    // Other operands are scalar params.
    TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
    TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translating or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (cf.
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
  }];

  let arguments = (ins
    TFL_FpTensor:$boxes,
    TFL_FpTensor:$scores,
    I32Tensor:$max_output_size,
    TFL_FpTensor:$iou_threshold,
    TFL_FpTensor:$score_threshold,
    TFL_FpTensor:$soft_nms_sigma
  );

  let results = (outs
    I32Tensor:$selected_indices,
    TFL_FpTensor:$selected_scores,
    I32Tensor:$valid_outputs
  );
}

def TFL_NotEqualOp : TFL_Op<"not_equal", [
    Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
  let summary = "Not_equal operator";
@ -987,11 +1144,11 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect]> {
  }];

  let arguments = (ins
    TensorOf<[F32, QUI8, QI8, I8]>:$input,
    TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input,
    TFL_AFAttr:$fused_activation_function
  );

  let results = (outs TensorOf<[F32, QUI8, QI8, I8]>:$output);
  let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output);

  let hasOptions = 1;

@ -1100,9 +1257,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
    Computes element-wise Sigmoid of input
  }];

  let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x);
  let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);

  let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y);
  let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
}

def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
@ -1441,7 +1598,7 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
  let hasOptions = 0b1;
}

def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Packs a list of tensors along a dimension into one tensor";

  let description = [{
@ -1472,14 +1629,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
  }];

  let arguments = (ins
    Variadic<TensorOf<[F32, I8, I16, I32, I64]>>:$values,
    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,

    I32Attr:$values_count,
    I32Attr:$axis
  );

  let results = (outs
    TensorOf<[F32, I8, I16, I32, I64]>:$output
    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
  );

  let verifier = [{ return Verify(*this); }];
@ -1777,8 +1934,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
}

def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
    // TODO(jpienaar): This is too retrictive, rank 1 input is also allowed.
    SameOperandsAndResultShape,
    SelectShapeConstraints,
    PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
    PredOpTrait<"operands and result have same element type",
      TCresVTEtIsSameAsOp<0, 1>>]> {
@ -1836,7 +1992,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
  let summary = "Softmax operator";

  let description = [{
    Computes element-wise softmax activiations with the following formula
    Computes element-wise softmax activations with the following formula

      exp(input) / tf.reduce_sum(exp(input * beta), dim)
  }];
@ -1942,9 +2098,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
    Computes element-wise Hyperbolic tangent of input
  }];

  let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$x);
  let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);

  let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$y);
  let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}

def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@ -1999,8 +2155,6 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
  let hasOptions = 1;
}

// TODO: Verify result shape a permutation of the first input shape's
// dimensions.
def TFL_TransposeOp : TFL_Op<"transpose",
    [NoSideEffect,
     TFL_OperandHasRank<1,1>,
@ -2025,6 +2179,8 @@ def TFL_TransposeOp : TFL_Op<"transpose",
    AnyTensor:$y
  );

  let verifier = [{ return Verify(*this); }];

  let hasFolder = 1;
}

@ -2342,7 +2498,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
  let hasOptions = 1;
}

def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
def TFL_CastOp : TFL_Op<"cast", [
    NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
  let summary = "Cast operator";

  let description = [{
@ -2629,6 +2786,10 @@ Ba et al. “Layer Normalization”

  let results = (outs AnyTensor:$output);

  // TODO(fengliuai): customize printer and parser to not display
  // empty region.
  let regions = (region AnyRegion:$internal);

  let hasOptions = 1;

  let verifier = [{ return Verify(*this); }];

@ -549,32 +549,42 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(

void QuantizationDriver::PreprocessConstantOps() {
  fn_.walk([&](ConstantOp cst) {
    // Non-float tensors are neither weights or require quantization.
    if (!cst.getType().cast<ShapedType>().getElementType().isa<FloatType>()) {
      return;
    }
    // Non-float tensors are neither weights nor require quantization.
    auto type = cst.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return;

    Value *value = cst.getResult();
    SmallVector<std::pair<Operation *, int>, 4> bias_users;
    bool used_as_weight = false;
    for (auto &use : value->getUses()) {
      auto spec = GetQuantSpec(use.getOwner());
      auto biases = spec->biases_params;
      Operation *user = use.getOwner();
      int operand_num = use.getOperandNumber();

      // The user doesn't use this value as a bias operand nor require same
      // scale.
      // If the user doesn't use this value as a bias operand or require same
      // scale, then this constant is considered to be a weight.
      if (biases.find(operand_num) == biases.end() &&
          !spec->requires_same_scale) {
        weights_.insert(cst);
        used_as_weight = true;
      } else {
        bias_users.push_back({user, operand_num});
      }
    }
    builder_.setInsertionPoint(cst);
    for (int i = 1; i < bias_users.size(); ++i) {

    // If the constant is used as a weight, this constant will be duplicated for
    // each bias user, so it isn't shared with the weight usage. Otherwise, the
    // first bias user can use the original constant and the rest use the
    // duplications, so we pop bias user from the set.
    if (used_as_weight) {
      weights_.insert(cst);
    } else {
      bias_users.pop_back();
      builder_.setInsertionPoint(cst);
    }
    for (auto bias_user : bias_users) {
      auto copied = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
      bias_users[i].first->setOperand(bias_users[i].second, copied.getResult());
      bias_user.first->setOperand(bias_user.second, copied.getResult());
    }
  });
}
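To make the duplication rule concrete, here is a toy, self-contained sketch (plain C++, not MLIR; the types and names are illustrative assumptions) of which uses of a constant receive a fresh copy:

    #include <cstddef>
    #include <vector>

    // Toy model of the rule above: each use of a float constant is either a
    // bias use or a weight-like use. If the constant is also used as a weight,
    // every bias use gets a fresh copy; otherwise the last recorded bias use
    // keeps the original (mirroring bias_users.pop_back() above) and the rest
    // are copied.
    struct Use { bool is_bias; };

    std::vector<bool> NeedsCopy(const std::vector<Use>& uses) {
      bool used_as_weight = false;
      for (const Use& u : uses) used_as_weight |= !u.is_bias;

      // Find the bias use that keeps the original, if any.
      std::ptrdiff_t keeper = -1;
      if (!used_as_weight) {
        for (std::size_t i = 0; i < uses.size(); ++i)
          if (uses[i].is_bias) keeper = static_cast<std::ptrdiff_t>(i);
      }

      std::vector<bool> copy(uses.size(), false);
      for (std::size_t i = 0; i < uses.size(); ++i)
        if (uses[i].is_bias && static_cast<std::ptrdiff_t>(i) != keeper)
          copy[i] = true;
      return copy;
    }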
@ -32,7 +32,7 @@ static Type GetQuantizedType(Builder builder, Type input_type, double min,
                             double max, int storage_type_width,
                             bool narrow_range, bool is_signed) {
  auto converter =
      quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
      quant::ExpressedToQuantizedConverter::forInputType(input_type);

  quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
      builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,
@ -13,6 +13,11 @@ func @extractSimpleOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extractPackedInputOphint
|
||||
func @extractPackedInputOphint() {
|
||||
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
@ -30,6 +35,11 @@ func @extractPackedInputOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extractFirstInputOphint
|
||||
func @extractFirstInputOphint() {
|
||||
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
@ -46,6 +56,11 @@ func @extractFirstInputOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_first"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extractLastInputOphint
|
||||
func @extractLastInputOphint() {
|
||||
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
@ -62,6 +77,11 @@ func @extractLastInputOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_last"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extractPackOneInputOphint
|
||||
func @extractPackOneInputOphint() {
|
||||
// CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0) : (tensor<1x16x1xf32>) -> tensor<1x1x16x1xf32>
|
||||
@ -75,13 +95,16 @@ func @extractPackOneInputOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_pack_input_one"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extractStackInputOutputOphint
|
||||
func @extractStackInputOutputOphint() {
|
||||
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
||||
// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
|
||||
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
|
||||
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
@ -98,11 +121,14 @@ func @extractStackInputOutputOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack_input_output"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extractMultipleInputsOutputsOphint
|
||||
func @extractMultipleInputsOutputsOphint() {
|
||||
// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
||||
// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: %[[MULTI_INPUT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
||||
|
||||
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
|
||||
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
@ -119,21 +145,33 @@ func @extractMultipleInputsOutputsOphint() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation"}
|
||||
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation_stack"}
|
||||
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation_first"}
|
||||
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation_last"}
|
||||
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation_pack_input_one"}
|
||||
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation_stack_input_output"}
|
||||
// CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
||||
// CHECK: attributes {_tflite_function_name = "cool_activation_multiple_input_output"}
|
||||
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32], _tflite_function_name = "cool_activation_multiple_input_output"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: inputsAfterOutputs
|
||||
func @inputsAfterOutputs() {
|
||||
// CHECK: %[[PLACE_HOLDER:[0-9]*]] = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
|
||||
// CHECK: %[[INPUT_PROCESS:[0-9]*]] = "tf.Sigmoid"(%[[PLACE_HOLDER]]) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @d6266124d2dd11e9b52cdc4a3e957995(%0, %1, %[[INPUT_PROCESS]]) : (tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
%0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<f32>) -> tensor<f32>
%2 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
%3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%4 = "tf.Add"(%3, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "Add"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
%5 = "tf.Identity"(%4) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%6 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
%7 = "tf.Sigmoid"(%6) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%8 = "tf.Identity"(%7) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 2 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-2-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%9 = "tf.Add"(%5, %8) {T = "tfdtype$DT_FLOAT", device = "", name = "Add_1"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%10 = "tf.Identity"(%9) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return
}

// CHECK: func @d6266124d2dd11e9b52cdc4a3e957995(tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32], _tflite_function_name = "CustomOp"}

// -----

@ -50,7 +50,7 @@ func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor
func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i32 {
%0 = "tf.Squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = constant dense<[2, 5]> : tensor<2xi32>
%2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
@ -119,8 +119,8 @@ func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
}

func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%arg1 = constant dense<-0.1> : tensor<f32>
%arg2 = constant dense<0.2> : tensor<f32>
%arg1 = "tf.Const"() { value = dense<-0.1> : tensor<f32> } : () -> tensor<f32>
%arg2 = "tf.Const"() { value = dense<0.2> : tensor<f32> } : () -> tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>

@ -153,6 +153,14 @@ func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}

func @placeholder_int(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<i32>) -> tensor<i32>
return %0: tensor<i32>

// CHECK-LABEL: @placeholder_int
// CHECK-NEXT: "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
}

func @placeholder_min(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
@ -409,7 +417,7 @@ func @gatherNdHigherRankIndices(%arg0 : tensor<4x3x2xf32>, %arg1 : tensor<2x2xi3
}

func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
return %1 : tensor<1x3x5x20xf32>

@ -418,7 +426,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
}

func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>

@ -427,7 +435,7 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x
}

func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>

@ -509,6 +517,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) ->
// CHECK: return %0 : tensor<8xf32>
}

func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>

// CHECK-LABEL: select_multidim
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8x3xf32>
}

func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
@ -518,6 +535,15 @@ func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>)
// CHECK: return %0 : tensor<8xf32>
}

func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>

// CHECK-LABEL: select_v2_multidim
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8x3xf32>
}

func @sin(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
@ -536,7 +562,7 @@ func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?
}

func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor<i32>) -> (tensor<2xf32>, tensor<2xi32>)
return %1#0, %1#1: tensor<2xf32>, tensor<2xi32>

@ -546,7 +572,7 @@ func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
}

func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>

@ -556,7 +582,7 @@ func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
}

func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor<i32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>)
return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>

@ -566,7 +592,7 @@ func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2
}

func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>

@ -671,7 +697,7 @@ func @pow(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x1x1xf32>) -> tensor<2x1x3xf3

func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>):
%cst = constant dense<[1, 2]> : tensor<2xi32>
%cst = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> } : () -> tensor<2xi32>
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32>
return %0 : tensor<2x6xf32>

@ -682,7 +708,7 @@ func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {

func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
%cst = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>

@ -858,8 +884,8 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
}

func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
%0 = constant dense<[1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>

// CHECK-LABEL: concat2Tensors
@ -867,8 +893,8 @@ func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi
}

func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>

// CHECK-LABEL: concat3Tensors
@ -876,8 +902,8 @@ func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2
}

func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>

// CHECK-LABEL: concatv2With3Tensors
@ -1093,3 +1119,35 @@ func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> {
// CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32>
// CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
}

func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>

// CHECK-LABEL: non_max_suppression_v4
// CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}

func @non_max_suppression_v4_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>

// CHECK-LABEL: non_max_suppression_v4_no_pad
// CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}

func @non_max_suppression_v5(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
%0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0#0 : tensor<2xi32>

// CHECK-LABEL: non_max_suppression_v5
// CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}

func @non_max_suppression_v5_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
%0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0#0 : tensor<2xi32>

// CHECK-LABEL: non_max_suppression_v5_no_pad
// CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}

@ -0,0 +1,107 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?xf32>, %arg7: tensor<?xf32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
%0 = "tfl.lstm"(%arg0, // input
%arg1, %arg2, %arg3, %arg4, // weights
%arg5, %arg6, %arg7, %arg8, // recurrent weights
%arg9, %arg10, %arg11, // cell weights
%arg12, %arg13, %arg14, %arg15, // bias
%arg16, %arg17, // projection weight and bias
%arg18, %arg19, // stateful
%arg20, %arg21, %arg22, %arg23 // layer norm coefficients
) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<? xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

// CHECK-NEXT: "tfl.lstm"
// CHECK-NEXT: %[[cst:.*]] = constant unit

// input gate
// CHECK-NEXT: %[[in1:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in2:.*]] = "tfl.fully_connected"(%arg18, %arg5, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in3:.*]] = "tfl.mul"(%arg19, %arg9)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in4:.*]] = "tfl.add_n"(%[[in1]], %[[in2]], %[[in3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in5:.*]] = "tfl.l2_normalization"(%[[in4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in6:.*]] = tfl.add %[[in4]], %[[in5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in7:.*]] = "tfl.fully_connected"(%[[in6]], %arg20, %arg12)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in8:.*]] = "tfl.logistic"(%[[in7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

// forget gate
// CHECK-NEXT: %[[fo1:.*]] = "tfl.fully_connected"(%arg0, %arg2, %[[cst]])
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo2:.*]] = "tfl.fully_connected"(%arg18, %arg6, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo3:.*]] = "tfl.mul"(%arg19, %arg10)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo4:.*]] = "tfl.add_n"(%[[fo1]], %[[fo2]], %[[fo3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo5:.*]] = "tfl.l2_normalization"(%[[fo4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo6:.*]] = tfl.add %[[fo4]], %[[fo5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo7:.*]] = "tfl.fully_connected"(%[[fo6]], %arg21, %arg13)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo8:.*]] = "tfl.logistic"(%[[fo7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

// cell gate
// CHECK-NEXT: %[[ce1:.*]] = "tfl.fully_connected"(%arg0, %arg3, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce2:.*]] = "tfl.fully_connected"(%arg18, %arg7, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce3:.*]] = "tfl.add_n"(%[[ce1]], %[[ce2]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce4:.*]] = "tfl.l2_normalization"(%[[ce3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce5:.*]] = tfl.add %[[ce3]], %[[ce4]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce6:.*]] = "tfl.fully_connected"(%[[ce5]], %arg22, %arg14)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce7:.*]] = "tfl.tanh"(%[[ce6]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

// CHECK-NEXT: %[[ac1:.*]] = "tfl.mul"(%[[fo8]], %arg19)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac2:.*]] = tfl.mul %[[in8]], %[[ce7]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac3:.*]] = tfl.add %[[ac1]], %[[ac2]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>

// output gate
// CHECK-NEXT: %[[ou1:.*]] = "tfl.fully_connected"(%arg0, %arg4, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou2:.*]] = "tfl.fully_connected"(%arg18, %arg8, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou3:.*]] = "tfl.mul"(%[[ac3]], %arg11)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou4:.*]] = "tfl.add_n"(%[[ou1]], %[[ou2]], %[[ou3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou5:.*]] = "tfl.l2_normalization"(%[[ou4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou6:.*]] = tfl.add %[[ou4]], %[[ou5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou7:.*]] = "tfl.fully_connected"(%[[ou6]], %arg23, %arg15)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou8:.*]] = "tfl.logistic"(%[[ou7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

// output activation
// CHECK-NEXT: %[[ac4:.*]] = "tfl.tanh"(%[[ac3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac5:.*]] = tfl.mul %[[ac4]], %[[ou8]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac6:.*]] = "tfl.fully_connected"(%[[ac5]], %arg16, %arg17)
// CHECK-SAME: (tensor<?x!quant.any<i16:f32>>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x!quant.any<i8:f32>>
// CHECK-NEXT: %[[ac7:.*]] = "tf_quant.pseudo_return"(%[[ac6]]) : (tensor<?x!quant.any<i8:f32>>) -> tensor<?x!quant.any<i8:f32>>
// CHECK-NEXT: })
// CHECK-NEXT: return

return %0 : tensor<?xf32>
}
@ -143,6 +143,19 @@ func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: t
// CHECK: return [[RESULT]] : tensor<?x10xf32>
}

func @tensorlistLength(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>) -> (tensor<i32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListLength"(%0) : (tensor<!tf.variant<tensor<10xf32>>>) -> tensor<i32>
return %1: tensor<i32>

// CHECK-LABEL: tensorlistLength
// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>)
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) {{.*}} -> tensor<2xi32>
// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) {validate_indices = true} : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
// CHECK: return [[RESULT]] : tensor<i32>
}

func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>

@ -278,6 +278,6 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
%21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
%22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
%23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
%24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
}
@ -0,0 +1,31 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s

module attributes {
tfl.metadata = {key1 = "value1", key2 = "value2"}
} {
func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} {
^bb0(%arg0: tensor<3x2xi32>):
%0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%2 = "tfl.sub" (%0, %1) {fused_activation_function = "NONE"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %2 : tensor<3x2xi32>
}
}

// CHECK: buffers: [ {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "key1",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "key2",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
@ -1,5 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
// | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s

func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
@ -9,7 +8,7 @@ func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf
return %2 : tensor<40x40xf32>
}

// CHECK-NEXT: operators: [ {
// CHECK: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, -1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,

@ -103,7 +103,7 @@ func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x
// test invalid AddN
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}}
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
return %0 : tensor<? x f16>
}
@ -537,7 +537,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
// test invalid Logistic input
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
@ -591,8 +591,9 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -600,8 +601,9 @@ func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -610,7 +612,7 @@ func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %
// test invalid none type applied to a tensor type arg
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -619,7 +621,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test violation of projection weight and projection bias pred op trait
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -628,7 +630,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test invalid kernel type
func @testLstmWithInvalidKernelType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -652,6 +654,15 @@ func @testSelect(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi

// -----

// test select with multi-dim inputs
// CHECK-LABEL: testSelectMultiDim
func @testSelectMultiDim(%cond : tensor<?xi1>, %arg0 : tensor<?x4xi32>, %arg1 : tensor<?x4xi32>) -> tensor<?x4xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?x4xi32>,tensor<?x4xi32>) -> tensor<?x4xi32>
return %0 : tensor<?x4xi32>
}

// -----

func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi32>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
@ -660,6 +671,14 @@ func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>

// -----

func @testSelectWithUnsupportedShapes(%cond : tensor<2xi1>, %arg0 : tensor<3xi32>, %arg1 : tensor<3xi32>) -> tensor<3xi32> {
// expected-error @+1 {{failed to verify that Select operands meet shape criteria}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<2xi1>,tensor<3xi32>,tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}

// -----

func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
@ -762,6 +781,21 @@ func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> te

// -----

// CHECK-LABEL: testPadQuantizedU8
func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}

// CHECK-LABEL: testPadQuantizedI8
func @testPadQuantizedI8(%arg0: tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// -----

// CHECK-LABEL: testPadV2
func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
@ -817,6 +851,20 @@ func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) ->

// -----

func @packQuantizedU8(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<u8:f32, 0.1>>, tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<u8:f32, 0.1>>
}

func @packQuantizedI8(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<i8:f32, 0.1>>, tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1>>
}

// -----

func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
@ -1101,6 +1149,63 @@ func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) ->
}


// -----

func @transpose_perm_size(%arg0 : tensor<2x2xi32>, %arg1 : tensor<3xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{perm tensor elements size is not equal to input tensor rank}}
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<3xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}


// -----

func @transpose_unranked_shape(%arg0 : tensor<*xi32>) -> tensor<2x2xi32> {
%cst = constant dense<[1, 0]> : tensor<2xi32>
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<*xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}


// -----

func @transpose_dynamic_shape(%arg0 : tensor<2x?xi32>) -> tensor<?x2xi32> {
%cst = constant dense<[1, 0]> : tensor<2xi32>
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<?x2xi32>
return %0 : tensor<?x2xi32>
}


// -----

func @transpose_perm_axis_invalid(%arg0 : tensor<2x2xi32>) -> tensor<2x2xi32> {
%cst = constant dense<[1, -1]> : tensor<2xi32>
// expected-error @+1 {{perm[1] must be in [0, rank)}}
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}


// -----

func @transpose_perm_axis_duplicated(%arg0 : tensor<2x2xi32>) -> tensor<2x2xi32> {
%cst = constant dense<[1, 1]> : tensor<2xi32>
// expected-error @+1 {{perm[1] cannot have duplicated axis}}
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}


// -----

func @transpose_output_type_bad(%arg0 : tensor<3x4x5x6xi32>) -> tensor<3x4x5x6xi32> {
%cst = constant dense<[0, 3, 1, 2]> : tensor<4xi32>
// expected-error @+1 {{expect output type tensor<3x6x4x5xi32>, got tensor<3x4x5x6xi32>}}
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<3x4x5x6xi32>
return %0 : tensor<3x4x5x6xi32>
}


// -----

func @transpose_element_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> {
@ -1643,3 +1748,33 @@ func @testSplitVOpWithValidSizeSplitsNegative(%arg0 : tensor<16x4xf32>) -> (tens

return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>
}

// -----

func @testNonMaxSuppressionV4WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}

// -----

func @testNonMaxSuppressionV4WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v4' op failed to verify that boxes should have dim[1] == 4}}
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}

// -----

func @testNonMaxSuppressionV5WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}

// -----

func @testNonMaxSuppressionV5WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v5' op failed to verify that boxes should have dim[1] == 4}}
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}

@ -292,16 +292,3 @@ func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> t
// CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
// CHECK: return %3
}

// CHECK-LABEL: @InvalidL2NormalizePatternMorethan1Dimension
// Input has higher rank, it should be limited to 1D only.
func @InvalidL2NormalizePatternMorethan1Dimension(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.sqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
return %3: tensor<2x2xf32>
// CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
// CHECK: return %3
}

@ -373,21 +373,12 @@ func @QuantizeConstant() -> tensor<2x3xf32> {
// CHECK: return %1 : tensor<2x3xf32>
}

// CHECK-LABEL: NotQuantizeNonZeroSplat
func @NotQuantizeNonZeroSplat() -> tensor<2x3xf32> {
%cst = constant dense<2.0> : tensor<2x3xf32>
return %cst : tensor<2x3xf32>
// CHECK-LABEL: NotQuantizeNoneType
func @NotQuantizeNoneType() -> none {
%cst = constant unit
return %cst : none

// CHECK-NEXT: %[[cst:.*]] = constant dense<2.000000e+00>
// CHECK-NEXT: return %[[cst]]
}

// CHECK-LABEL: NotQuantizeNonZeroScalar
func @NotQuantizeNonZeroScalar() -> tensor<f32> {
%cst = constant dense<2.0> : tensor<f32>
return %cst : tensor<f32>

// CHECK-NEXT: %[[cst:.*]] = constant dense<2.000000e+00>
// CHECK-NEXT: %[[cst:.*]] = constant unit
// CHECK-NEXT: return %[[cst]]
}

@ -433,6 +424,32 @@ func @QuantizeSharedBiases(
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}

// CHECK-LABEL: QuantizeSharedBiases2
func @QuantizeSharedBiases2(
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>,
%arg2: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> (tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>) {
%cst = constant dense<1.0> : tensor<32xf32>
%1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
%add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform<u8:f32, 1.0>>

%5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> tensor<32x3x3x3xf32>
%conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
%7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
return %3, %7 : tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>

// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}

@ -63,7 +63,7 @@ func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8
  return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>

// CHECK-LABEL: fusedBatchNorm
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<1.000000e-03>
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
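For reference, the decomposition these CHECK comments trace is the standard batch-norm arithmetic. A minimal standalone C++ sketch of the per-element math (a toy helper written for this note, not part of the pass):

#include <cmath>
#include <cstdio>

// Per-element fusedBatchNorm math traced by the CHECK lines above:
//   y = (x - mean) * rsqrt(variance + epsilon) * scale + offset
float BatchNormElement(float x, float mean, float variance, float epsilon,
                       float scale, float offset) {
  float rsqrt = 1.0f / std::sqrt(variance + epsilon);  // rsqrt(variance + epsilon)
  return (x - mean) * rsqrt * scale + offset;
}

int main() {
  // Toy values only; epsilon matches the dense<1.000000e-03> constant above.
  std::printf("%f\n", BatchNormElement(2.0f, 1.0f, 4.0f, 1e-3f, 0.5f, 0.1f));
}
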
@ -96,7 +96,7 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
  return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>

// CHECK-LABEL: fusedBatchNormV3
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<1.000000e-03>
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
@ -155,7 +155,7 @@ func @fakeQuantFolded() -> (tensor<8xf32>) {
  %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %rst : tensor<8xf32>

// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
@ -262,7 +262,7 @@ func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>)
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<16x3x3x3xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
@ -282,7 +282,7 @@ func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<1x3x3x48xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
@ -348,3 +348,11 @@ func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
// CHECK-LABEL: stop_gradient
// CHECK: return %arg0 : tensor<3xi32>
}

func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  %0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
// Should be converted to Identity and then from Identity to value
// CHECK-LABEL: CheckNumerics
// CHECK: return %arg0 : tensor<3xf32>
}

223
tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
Normal file
@ -0,0 +1,223 @@
// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s

func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
  return %0 : tensor<2x3x4x6xf32>

// CHECK-LABEL: batchMatMulV2TwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>

// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>

// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>

// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>

// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>

// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
}

func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
  return %0 : tensor<3x4x6xf32>

// CHECK-LABEL: batchMatMulV2FlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>

// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>

// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>

// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>

// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>

// CHECK: return %[[v18]] : tensor<3x4x6xf32>
}

func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
  return %0 : tensor<4x6xf32>

// CHECK-LABEL: batchMatMulV2Matrix
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}

func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
  return %0 : tensor<2x3x4x6xf32>

// CHECK-LABEL: batchMatMulTwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>

// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>

// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>

// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>

// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>

// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
}

func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
  return %0 : tensor<3x4x6xf32>

// CHECK-LABEL: batchMatMulFlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>

// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>

// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>

// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>

// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>

// CHECK: return %[[v18]] : tensor<3x4x6xf32>
}

func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
  return %0 : tensor<4x6xf32>

// CHECK-LABEL: batchMatMulMatrix
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}

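The tests above all exercise the same recipe: flatten the batch dims with a reshape, slice out each batch, run plain 2-D matmuls, then pack and reshape back. A minimal standalone C++ sketch of that unrolling on nested vectors (toy types written for this note, not the pass itself):

#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // [rows][cols]

// Plain 2-D matmul: the per-slice "tf.MatMul" step.
Matrix MatMul(const Matrix& a, const Matrix& b) {
  size_t m = a.size(), k = b.size(), n = b[0].size();
  Matrix c(m, std::vector<float>(n, 0.0f));
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < n; ++j) c[i][j] += a[i][p] * b[p][j];
  return c;
}

// Unrolled batch matmul: the batch dims are assumed already flattened into one
// vector, mirroring the leading tf.Reshape; the per-batch loop plays the role
// of the tf.Slice/tf.MatMul sequence, and the returned vector that of tf.Pack
// (a trailing tf.Reshape would restore the original [2, 3, ...] batch shape).
std::vector<Matrix> UnrolledBatchMatMul(const std::vector<Matrix>& lhs,
                                        const std::vector<Matrix>& rhs) {
  std::vector<Matrix> out;
  out.reserve(lhs.size());
  for (size_t b = 0; b < lhs.size(); ++b) out.push_back(MatMul(lhs[b], rhs[b]));
  return out;
}

int main() {
  // Mirrors the 2x3x4x5 * 2x3x5x6 test: 6 flattened batches of 4x5 and 5x6.
  std::vector<Matrix> lhs(6, Matrix(4, std::vector<float>(5, 1.0f)));
  std::vector<Matrix> rhs(6, Matrix(5, std::vector<float>(6, 1.0f)));
  auto result = UnrolledBatchMatMul(lhs, rhs);  // 6 slices of 4x6, each entry 5
  return result.size() == 6 ? 0 : 1;
}
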
@ -149,7 +149,12 @@ int main(int argc, char **argv) {
                                 lower_tensor_list_ops, &result, &pm);
  if (!status.ok()) return kTrFailure;

  auto output = mlir::openOutputFile(output_file_name);
  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
#include <queue>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
@ -353,6 +355,127 @@ struct OphintCompositeOp {
  std::map<int, AggregatedOperand> outputs;
};

// Preprocesses the graph for topo sort (each operation is a node, while
// inputs/outputs indicate edges). Assumes the graph is acyclic. The preprocess
// does the following:
//   Computes each operation's in-degree (how many distinct input nodes it
//   takes).
//   Collects all consumer operations of every operation (operation_to_ouputs).
//   Builds the init_queue (those operations will be processed first).
void PreprocessTopoSortGraph(
    Block* block, std::queue<Operation*>* init_queue,
    llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
    llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
  for (auto& op : *block) {
    if (&op == block->getTerminator()) continue;
    if (op.getNumOperands() == 0) {
      init_queue->push(&op);
    } else {
      // The operands of an op are not a direct indication of distinct "edges":
      // we can have a pack op after an unpack op (they share multiple
      // operands), and that should only count as one edge.
      llvm::DenseSet<Operation*> input_ops;
      for (int i = 0; i < op.getNumOperands(); ++i) {
        Operation* input_op = op.getOperand(i)->getDefiningOp();
        if (input_op) input_ops.insert(input_op);
      }
      if (input_ops.empty()) {
        init_queue->push(&op);
        continue;
      }
      operation_to_in_degrees->try_emplace(&op, input_ops.size());
      for (auto* input_op : input_ops) {
        auto preceeding_op_it = operation_to_ouputs->find(input_op);
        if (preceeding_op_it == operation_to_ouputs->end()) {
          auto result = operation_to_ouputs->try_emplace(
              input_op, llvm::DenseSet<Operation*>());
          preceeding_op_it = result.first;
        }
        preceeding_op_it->second.insert(&op);
      }
    }
  }
}

bool IsSideEffectOp(Operation* op) {
  if (op->hasNoSideEffect()) return false;

  // The Identity op has no side effect. Checking the OperationName might be
  // more elegant here.
  auto tf_identity_op = dyn_cast_or_null<TF::IdentityOp>(op);
  if (tf_identity_op) return false;
  return true;
}

// Other transformations might benefit from this utility function, but since
// there are currently none, we limit it to the ophint extraction pass. We may
// refactor it for wider use in the future.
//
// Assumes the graph is disconnected from the outside.
// Also assumes the block has no arguments.
LogicalResult TopoSortOperations(OpBuilder* builder) {
  std::queue<Operation*> init_queue;
  llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
  llvm::DenseMap<Operation*, int> operation_to_in_degrees;
  std::vector<Operation*> sorted_ops;

  PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
                          &operation_to_ouputs, &operation_to_in_degrees);
  while (!init_queue.empty()) {
    Operation* current_op = init_queue.front();
    init_queue.pop();
    sorted_ops.push_back(current_op);

    auto current_op_to_output_it = operation_to_ouputs.find(current_op);
    if (current_op_to_output_it == operation_to_ouputs.end()) {
      continue;
    }
    for (Operation* output_op : current_op_to_output_it->second) {
      auto output_op_it = operation_to_in_degrees.find(output_op);
      if (output_op_it == operation_to_in_degrees.end()) return failure();

      output_op_it->second -= 1;
      if (output_op_it->second == 0) {
        init_queue.push(output_op);
        operation_to_in_degrees.erase(output_op_it);
      }
    }
    operation_to_ouputs.erase(current_op_to_output_it);
  }

  // Before we perform the sort, we need to make sure we didn't mess up the
  // ordering of the original side-effect operations. It's possible those
  // side-effect operations have no topological relation at all!
  std::vector<Operation*> original_side_effect_ops;
  std::vector<Operation*> after_sort_side_effect_ops;
  for (auto& op : *builder->getBlock()) {
    if (IsSideEffectOp(&op) && (&op != builder->getBlock()->getTerminator()))
      original_side_effect_ops.push_back(&op);
  }
  for (auto* op : sorted_ops) {
    if (IsSideEffectOp(op)) after_sort_side_effect_ops.push_back(op);
  }
  if (original_side_effect_ops.size() != after_sort_side_effect_ops.size())
    return failure();
  for (int i = 0; i < original_side_effect_ops.size(); ++i) {
    if (original_side_effect_ops[i] != after_sort_side_effect_ops[i])
      return failure();
  }

  // Perform the sort. Ideally it would be nice to just clear the block and
  // write back the sorted ops, but unfortunately that's hard to do.
  for (int i = sorted_ops.size() - 1; i > 0; --i) {
    Operation* current_op = sorted_ops[i];
    for (int j = i - 1; j >= 0; --j) {
      Operation* prev_op = sorted_ops[j];
      prev_op->moveBefore(current_op);
    }
  }

  return success();
}
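The two functions above are Kahn's algorithm specialized to MLIR blocks. A minimal standalone C++ sketch of the same in-degree/queue scheme over a toy adjacency map (hypothetical integer node ids in place of Operation*):

#include <map>
#include <queue>
#include <set>
#include <vector>

// Kahn's algorithm over a toy DAG: `consumers` maps a node to the nodes that
// consume it (mirroring operation_to_ouputs), `in_degrees` mirrors
// operation_to_in_degrees, and `init_queue` holds the zero-in-degree nodes.
std::vector<int> TopoSort(const std::map<int, std::set<int>>& consumers,
                          std::map<int, int> in_degrees,
                          std::queue<int> init_queue) {
  std::vector<int> sorted;
  while (!init_queue.empty()) {
    int node = init_queue.front();
    init_queue.pop();
    sorted.push_back(node);
    auto it = consumers.find(node);
    if (it == consumers.end()) continue;
    for (int consumer : it->second) {
      // Releasing the last incoming edge makes the consumer ready to emit.
      if (--in_degrees[consumer] == 0) init_queue.push(consumer);
    }
  }
  return sorted;  // Fewer nodes than the graph has means a cycle was found.
}

int main() {
  // 0 -> 1 -> 2 and 0 -> 2, so the only valid order is 0, 1, 2.
  std::map<int, std::set<int>> consumers = {{0, {1, 2}}, {1, {2}}};
  std::map<int, int> in_degrees = {{1, 1}, {2, 2}};
  std::queue<int> init;
  init.push(0);
  return TopoSort(consumers, in_degrees, init).size() == 3 ? 0 : 1;
}
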

Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
                            Operation* insert_before_op,
                            const std::map<int, Value*>& inputs,
@ -360,10 +483,12 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
                            OpBuilder* builder, ModuleOp* module_op) {
  SmallVector<Type, 4> input_types;
  SmallVector<Value*, 4> input_values;
  SmallVector<int, 4> input_indexes;
  for (const auto& kv : inputs) {
    Value* input = kv.second;
    input_types.push_back(input->getType());
    input_values.push_back(input);
    input_indexes.push_back(kv.first);
  }

  SmallVector<Type, 4> func_output_types;
@ -378,6 +503,8 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
  SmallVector<NamedAttribute, 4> attrs;
  attrs.push_back(builder->getNamedAttr(
      kTfLiteFunctionName, builder->getStringAttr(fused_func_type)));
  attrs.push_back(builder->getNamedAttr(
      kTfLiteFunctionInputIndex, builder->getI32ArrayAttr(input_indexes)));
  FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name,
                                  function_type, llvm::makeArrayRef(attrs));
  module_op->push_back(func_op);
@ -507,6 +634,10 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
  };

  builder->getBlock()->walk(removeRemovableOps);

  // Step 8: Topo sort to fix any invalid temporary IRs.
  if (failed(TopoSortOperations(builder))) return failure();

  return success();
}

@ -20,6 +20,10 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

def NonOpaqueElementsAttr : ElementsAttrBase<
  CPred<"!$_self.isa<OpaqueElementsAttr>()">,
  "non-opaque constant tensor">;

def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;

@ -56,8 +60,13 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//

def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
          [(AnyStaticShapeTensor $res)], (addBenefit 10)>;

//===----------------------------------------------------------------------===//
// Unary ops patterns.
//===----------------------------------------------------------------------===//
@ -157,7 +166,8 @@ def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
// The following two rules can both match a tf.Placeholder.input node with
// min/max/type attributes, so we increase the benefit of the first rule by one
// so that the tfl.quantize and tfl.dequantize ops are inserted when it matches.
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
def : Pat<(TF_PlaceholderInputOp TensorOf<[F16, F32, F64]>:$inputs,
          $min, $max, $type),
          (TFL_DequantizeOp
            (TFL_QuantizeOp
              (TFL_InputOp $inputs),
@ -191,7 +201,8 @@ def : Pat<(TF_GatherV2Op $params, $indices,

def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;

def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue),
          (TFL_NotEqualOp $l, $r)>;

def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;

@ -252,7 +263,7 @@ def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),

def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;

def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>;

def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;

@ -308,3 +319,11 @@ def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>;
def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;

def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>;

def : Pat<
  (TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size),
  (TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>;

def : Pat<
  (TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size),
  (TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>;

@ -0,0 +1,228 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass prepares the TFLite fused ops for quantization.

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

//===----------------------------------------------------------------------===//
// The LoadQuantizationRecipe Pass.
//
namespace mlir {
namespace TFL {

namespace {

// This pass loads the quantization recipe for the TFLite ops to be quantized.
// Specifically, it extends the fused ops with their internal implementation as
// op regions. Each op in the region produces results with element type
// AnyQuantizedType, so bitwidth, narrow_range, etc. are included. The op also
// defines the op quantization traits, which are used to propagate the
// quantization parameters by the following passes.
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
  void runOnFunction() override;

 private:
  void Initialize(LSTMOp lstm, OpBuilder* builder);

  // Create LSTM gates with different weights for input, recurrent and
  // cell state, and also the layer normalization parameters.
  Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
                        Value* rec_w,
                        llvm::Optional<std::pair<Value*, Value*>> cell,
                        Value* ln_w, Value* ln_bias, OpBuilder* builder);

  Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
                             Value* ln_bias, OpBuilder* builder);

  // Add the internal implementation of the LSTM to its regions.
  void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);

  StringAttr none_af;
  StringAttr fc_format;
  BoolAttr keep_dims;
  Type int8;
  Type int16;
  ConstantOp none_cst;
};

void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
  Type expressed_type =
      lstm.input()->getType().cast<ShapedType>().getElementType();
  Type int8_storage_type = builder->getIntegerType(8);
  Type int16_storage_type = builder->getIntegerType(16);
  auto flag = quant::QuantizationFlags::FlagValue::Signed;
  int64_t int8_min = quant::QuantizedType::getDefaultMininumForInteger(
      flag, /*integralWidth=*/8);
  int64_t int8_max = quant::QuantizedType::getDefaultMaxinumForInteger(
      flag, /*integralWidth=*/8);
  int64_t int16_min = quant::QuantizedType::getDefaultMininumForInteger(
      flag, /*integralWidth=*/16);
  int64_t int16_max = quant::QuantizedType::getDefaultMaxinumForInteger(
      flag, /*integralWidth=*/16);
  auto any_int8 = quant::AnyQuantizedType::get(
      flag, int8_storage_type, expressed_type, int8_min, int8_max);
  auto any_int16 = quant::AnyQuantizedType::get(
      flag, int16_storage_type, expressed_type, int16_min, int16_max);

  int8 = any_int8.castFromExpressedType(lstm.input()->getType());
  int16 = any_int16.castFromExpressedType(lstm.input()->getType());
}
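Initialize relies on the default signed ranges for the two storage widths. A minimal standalone C++ sketch of what those defaults work out to for 8- and 16-bit signed storage (my own helper for illustration, not the MLIR API):

#include <cstdint>
#include <cstdio>

// Default signed range for storage width w: [-(2^(w-1)), 2^(w-1) - 1].
int64_t DefaultMin(int width) { return -(int64_t{1} << (width - 1)); }
int64_t DefaultMax(int width) { return (int64_t{1} << (width - 1)) - 1; }

int main() {
  std::printf("int8:  [%lld, %lld]\n", (long long)DefaultMin(8),
              (long long)DefaultMax(8));   // [-128, 127]
  std::printf("int16: [%lld, %lld]\n", (long long)DefaultMin(16),
              (long long)DefaultMax(16));  // [-32768, 32767]
}
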

Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
                                                   Value* ln_w, Value* ln_bias,
                                                   OpBuilder* builder) {
  // Note that the l2_normalization and add ops here are not the execution
  // kernel implementation of layer_normalization; we just use them to model
  // the quantization requirement.
  auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
  auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
  return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
                                           none_af, fc_format, keep_dims);
}

Operation* LoadQuantizationRecipe::CreateGate(
    Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
    llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
    OpBuilder* builder) {
  auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
                                              none_af, fc_format, keep_dims);
  auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
                                              none_af, fc_format, keep_dims);

  AddNOp s4;
  if (cell.hasValue()) {
    auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
                                     cell.getValue().second, none_af);
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value*>(
            {*s1.output().begin(), *s2.output().begin(), s3.output()}));

  } else {
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
  }

  auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);

  if (cell.hasValue()) {
    return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
  } else {
    return builder->create<TanhOp>(loc, int16, s5->getResult(0));
  }
}
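In conventional LSTM notation, the graph CreateGate builds has the shape act(LayerNorm(W_x·x + W_h·h + w_c⊙c)), with a logistic activation for the peephole gates and tanh for the cell candidate. A toy scalar C++ sketch of that shape, written only to make the dataflow legible (the tanh stand-in for layer norm is my simplification, not what the op region models):

#include <cmath>

// Scalar shape of CreateGate: two fully-connected terms, an optional
// peephole term, a placeholder "layer norm", then the gate activation.
float Gate(float in, float in_w, float rec, float rec_w, bool has_cell,
           float cell, float cell_w, bool is_peephole_gate) {
  float pre = in * in_w + rec * rec_w;  // s1 + s2
  if (has_cell) pre += cell * cell_w;   // s3, the peephole product
  float normed = std::tanh(pre);        // toy stand-in for CreateLayerNorm
  return is_peephole_gate ? 1.0f / (1.0f + std::exp(-normed))  // LogisticOp
                          : std::tanh(normed);                 // TanhOp
}

int main() { return Gate(1.0f, 0.5f, 0.2f, 0.3f, true, 0.1f, 0.4f, true) > 0.0f ? 0 : 1; }
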

void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
  Initialize(lstm, builder);

  Region region;
  region.push_back(new Block);
  builder->setInsertionPointToEnd(&region.front());
  Location loc = lstm.getLoc();
  Type int32_type = builder->getIntegerType(32);
  Type int32_tensor = builder->getTensorType(int32_type);
  none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
                                         builder->getUnitAttr());

  auto input_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_input_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
      llvm::Optional<std::pair<Value*, Value*>>(
          {lstm.input_cell_state(), lstm.cell_to_input_weights()}),
      lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);

  auto forget_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_forget_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
      llvm::Optional<std::pair<Value*, Value*>>(
          {lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
      lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);

  auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
                              lstm.input_activation_state(),
                              lstm.recurrent_to_cell_weights(), llvm::None,
                              lstm.cell_layer_norm_coefficients(),
                              lstm.cell_bias(), builder);

  auto forget_cell_state = builder->create<MulOp>(
      loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
  auto input_cell_state = builder->create<MulOp>(
      loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
  auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
                                         input_cell_state.output(), none_af);

  auto output_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_output_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
      llvm::Optional<std::pair<Value*, Value*>>(
          {new_cell, lstm.cell_to_output_weights()}),
      lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);

  auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
  auto hidden_state = builder->create<MulOp>(
      loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af);
  auto act = builder->create<FullyConnectedOp>(
      loc, int8, hidden_state.output(), lstm.projection_weights(),
      lstm.projection_bias(), none_af, fc_format, keep_dims);

  // TODO(fengliuai): define and register the op in the QuantOps Dialect.
  OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
                              {int8}, {});
  builder->createOperation(return_state);

  lstm.internal().takeBody(region);
}

void LoadQuantizationRecipe::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);
  none_af = builder.getStringAttr("NONE");
  fc_format = builder.getStringAttr("DEFAULT");
  keep_dims = builder.getBoolAttr(false);

  func.walk([&](Operation* op) {
    if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
      LoadForLSTMOp(lstm, &builder);
    }
    // Handles other ops.
  });
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<FunctionPassBase> CreateLoadQuantizationRecipePass() {
  return absl::make_unique<LoadQuantizationRecipe>();
}

static PassRegistration<LoadQuantizationRecipe> pass(
    "tfl-load-recipe", "Load TFL op quantization recipe");

}  // namespace TFL
}  // namespace mlir
@ -429,12 +429,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
  } else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
    if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
          tf_op.element_dtype().isF64() ||
          tf_op.element_dtype().isInteger(1) ||
          tf_op.element_dtype().isInteger(8) ||
          tf_op.element_dtype().isInteger(16) ||
          tf_op.element_dtype().isInteger(32) ||
          tf_op.element_dtype().isInteger(64))) {
      return tf_op.emitError(
          "requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer "
          "or 16-bit/32-bit/64-bit "
          "float type during TF Lite transformation pass");
    }
@ -461,6 +463,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
    auto c = ConvertTFTensorListPushBack(context);
    rewriter->setInsertionPoint(op);
    c.matchAndRewrite(op, *rewriter);
  } else if (auto tf_op = llvm::dyn_cast<TF::TensorListLengthOp>(op)) {
    auto c = TFL::ConvertTFTensorListLength(context);
    rewriter->setInsertionPoint(op);
    c.matchAndRewrite(op, *rewriter);
  } else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
    if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
    UpdateWhileFunctionType(tf_op);

@ -122,6 +122,8 @@ class OperandHasRank<int n> : Constraint<
// Mul->Rsqrt->Sum->Square
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. the constraints on axis and depth.
def : Pat<(TFL_MulOp $operand1,
            (TFL_RsqrtOp
              (TFL_SumOp
@ -130,13 +132,14 @@ def : Pat<(TFL_MulOp $operand1,
            $keep_dims)),
          TFL_AF_None),
          (TFL_L2NormalizationOp $operand1, TFL_AF_None),
          [(EqualOperands $operand1, $square_operand),
           (OperandHasRank<1> $operand1)]>;
          [(EqualOperands $operand1, $square_operand)]>;

// This pattern constructs L2NormalizationOp from
// Div->sqrt->Sum->Square
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. the constraints on axis and depth.
def : Pat<(TFL_DivOp $operand1,
            (TFL_SqrtOp
              (TFL_SumOp
@ -145,5 +148,4 @@ def : Pat<(TFL_DivOp $operand1,
            $keep_dims)),
          TFL_AF_None),
          (TFL_L2NormalizationOp $operand1, TFL_AF_None),
          [(EqualOperands $operand1, $square_operand),
           (OperandHasRank<1> $operand1)]>;
          [(EqualOperands $operand1, $square_operand)]>;

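Both patterns recognize the textbook L2 normalization, x / sqrt(sum(x^2)): the Mul form matches the rsqrt variant, the Div form the sqrt variant. A minimal standalone C++ sketch of the math being folded into TFL_L2NormalizationOp:

#include <cmath>
#include <vector>

// x / sqrt(sum(x^2)): the Square -> Sum -> Sqrt/Rsqrt -> Div/Mul chain the
// patterns above fold into a single TFL_L2NormalizationOp.
std::vector<float> L2Normalize(const std::vector<float>& x) {
  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;          // tfl.square + tfl.sum
  float inv_norm = 1.0f / std::sqrt(sum_sq);  // tfl.rsqrt (or 1 / tfl.sqrt)
  std::vector<float> y;
  y.reserve(x.size());
  for (float v : x) y.push_back(v * inv_norm);  // tfl.mul (or tfl.div)
  return y;
}

int main() {
  auto y = L2Normalize({3.0f, 4.0f});  // -> {0.6, 0.8}
  return (y[0] > 0.59f && y[0] < 0.61f) ? 0 : 1;
}
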
@ -18,6 +18,14 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"

def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;

def NonOpaqueElementsAttr : ElementsAttrBase<
  CPred<"!$_self.isa<OpaqueElementsAttr>()">,
  "non-opaque constant tensor">;

// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
          [(AnyStaticShapeTensor $res)]>;

// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
@ -81,8 +89,8 @@ class TFi32<int v> : ConstantAttr<I32ElementsAttr, !cast<string>(v)>;
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
          (TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp
             /*start=*/(TF_RankOp $b),
             /*limit=*/(ConstantOp TFi32<0>),
             /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))),
             /*limit=*/(TF_ConstOp TFi32<0>),
             /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))),
          $at, ConstBoolAttrTrue)>;

// Matmul with transpose on a to matmul with explicit transpose op and a not
@ -90,10 +98,12 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
          (TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp
             /*start=*/(TF_RankOp $a),
             /*limit=*/(ConstantOp TFi32<0>),
             /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
             /*limit=*/(TF_ConstOp TFi32<0>),
             /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $b,
          ConstBoolAttrFalse, $bt)>;
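Both rewrites build the transpose permutation as Range(rank, 0, -1) minus 1, i.e. the reversed permutation [rank-1, ..., 1, 0]. A minimal standalone C++ sketch of that small computation:

#include <vector>

// Range(start=rank, limit=0, delta=-1) gives [rank, rank-1, ..., 1];
// subtracting 1 yields the reversed permutation [rank-1, ..., 1, 0].
std::vector<int> ReversedPerm(int rank) {
  std::vector<int> perm;
  for (int v = rank; v > 0; --v) perm.push_back(v - 1);
  return perm;
}

int main() {
  auto perm = ReversedPerm(2);  // {1, 0}: swaps the two axes of a matrix
  return (perm[0] == 1 && perm[1] == 0) ? 0 : 1;
}
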

// Partially supported in TFLite, treated as passthrough IdentityOp
def : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>;
def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>;
def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;

@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -246,7 +247,8 @@ struct ConvertTFConvOp : public RewritePattern {
                                  filter_type.getShape());
    auto bias_type = rewriter.getTensorType({bias_dim}, elem_type);
    auto bias_attr = rewriter.getZeroAttr(bias_type);
    auto bias = rewriter.create<ConstantOp>(op->getLoc(), bias_type, bias_attr);
    auto bias =
        rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);

    auto *conv_state = static_cast<ConvertTFConvOpMatchState *>(state.get());
    auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
@ -297,7 +299,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
                                     rewriter.getIntegerType(32));
    auto perm_attr =
        DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
    auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
    auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);

    // Create tensor type for the transpose result.
    auto filter_type = filter->getType().cast<RankedTensorType>();
@ -366,7 +368,7 @@ class ConvertTFDepthwiseConv2dNative
    auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64));
    auto shape_attr =
        DenseElementsAttr::get(shape_type, llvm::makeArrayRef(result_shape));
    auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
    auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);

    return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
  }
@ -377,6 +379,11 @@ class ConvertTFDepthwiseConv2dNative
void PrepareTFPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
  applyPatternsGreedily(func, patterns);

  // This pattern was intended to use TFL QDQs to preserve the quantization
  // parameters from the TF Quant ops, so this pattern should run with the
  // first `applyPatternsGreedily` call, which would otherwise remove the

@ -14,9 +14,13 @@ limitations under the License.
==============================================================================*/

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

def CreateTFShapeOp : NativeCodeCall<
  "$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">;

//===----------------------------------------------------------------------===//
// TensorList transformation patterns.
// Note that the pattern below rewrites `TensorList` tensors (which have type DT_VARIANT)
@ -34,3 +38,11 @@ def ConvertTFTensorListStack : Pat<
def ConvertTFTensorListGetItem : Pat<
  (TF_TensorListGetItemOp $input, $index, $element_shape),
  (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;

// TensorListLength is equivalent to the size of the first dimension of the
// input tensorlist; rewrite it to a combination of Gather and Shape ops.
def ConvertTFTensorListLength: Pat<
  (TF_TensorListLengthOp:$old_value $input),
  (TF_GatherOp
    (CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue),
    (ConstantOp ConstantAttr<I32ElementsAttr, "0">), ConstBoolAttrTrue)>;
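The rewrite above computes the length as element 0 of the shape vector, i.e. Gather(Shape(list), 0). A minimal standalone C++ sketch of that equivalence (toy shape vector standing in for the TF ops):

#include <cstdint>
#include <vector>

// TensorListLength == Gather(Shape(list), 0): the first shape dimension of
// the stacked tensor list is its element count.
int32_t TensorListLength(const std::vector<int32_t>& stacked_shape) {
  return stacked_shape[0];  // the Gather of index 0 from the Shape result
}

int main() {
  // A list of 5 elements, each 2x3, stacked into shape [5, 2, 3].
  return TensorListLength({5, 2, 3}) == 5 ? 0 : 1;
}
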
309
tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
Normal file
@ -0,0 +1,309 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
namespace {
|
||||
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
|
||||
// of the inputs, matmul them individually, then stack them all back together at
|
||||
// the end.
|
||||
struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void UnrollBatchMatMulPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
|
||||
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
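To make the slicing and packing above concrete, here is a minimal standalone sketch of what the unrolling computes on plain row-major buffers. It is illustrative only (not part of the pass) and assumes equal batch sizes with no broadcasting; the helper name UnrolledBatchMatMul is hypothetical.

#include <cstddef>
#include <vector>

// lhs is [batch, rows, inner] and rhs is [batch, inner, cols], both flattened
// row-major; the result is [batch, rows, cols], also flattened row-major.
std::vector<float> UnrolledBatchMatMul(const std::vector<float>& lhs,
                                       const std::vector<float>& rhs,
                                       size_t batch, size_t rows, size_t inner,
                                       size_t cols) {
  std::vector<float> out(batch * rows * cols, 0.0f);
  for (size_t b = 0; b < batch; ++b) {
    // "Slice" one batch out of each input.
    const float* a = lhs.data() + b * rows * inner;
    const float* w = rhs.data() + b * inner * cols;
    // "Pack" the result of this batch into the output.
    float* o = out.data() + b * rows * cols;
    // Plain single-batch MatMul.
    for (size_t i = 0; i < rows; ++i)
      for (size_t k = 0; k < inner; ++k)
        for (size_t j = 0; j < cols; ++j)
          o[i * cols + j] += a[i * inner + k] * w[k * cols + j];
  }
  return out;
}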

template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value* value, ArrayRef<int64_t> shape, Type elementType, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shapeSpecType =
      rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64));
  Type resultType = rewriter.getTensorType(shape, elementType);
  auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
  auto shapeTensor =
      rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                        /*shape=*/shapeTensor);
}

template <typename BatchMatMulOpType>
std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
    Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
  RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
  Type elementType = tensorType.getElementType();

  int rank = tensorType.getShape().size();
  int num_rows = tensorType.getShape()[rank - 2];
  int num_cols = tensorType.getShape()[rank - 1];

  // Reshape to rank-3 Tensor with first dimension as the batch size.
  auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols},
                                   elementType, loc, rewriter);

  SmallVector<int64_t, 3> sliceSize = {1, num_rows, num_cols};

  std::vector<Value*> sliced;
  Type int64Type = rewriter.getIntegerType(64);
  Type sliceResultType = rewriter.getTensorType(sliceSize, elementType);

  // Slice along each batch index and remember the slice output for future
  // use.
  for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    auto vector3Type = rewriter.getTensorType({3}, int64Type);

    auto begin_attr =
        DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
    auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
    auto begin = rewriter.create<ConstantOp>(loc, vector3Type, begin_attr);
    auto size = rewriter.create<ConstantOp>(loc, vector3Type, size_attr);
    auto sliceOp =
        rewriter.create<TF::SliceOp>(loc, sliceResultType,
                                     /*input=*/reshapeOp.output(), begin, size);

    // Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
    // num_cols]
    auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols},
                                     elementType, loc, rewriter);

    sliced.emplace_back(squeezeOp.output());
  }
  return sliced;
}

template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
    Value* value, Location loc, PatternRewriter& rewriter) {
  auto valueType = value->getType().cast<RankedTensorType>();
  auto shape = valueType.getShape();
  int dims = shape.size();

  std::vector<int32_t> perm(dims);
  for (int i = 0; i < dims - 2; i++) {
    perm[i] = i;
  }
  perm[dims - 2] = dims - 1;
  perm[dims - 1] = dims - 2;

  auto perm_type = rewriter.getTensorType({static_cast<int32_t>(perm.size())},
                                          rewriter.getIntegerType(32));

  auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
  auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

  std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
  int64_t r = transposed_shape[dims - 1];
  int64_t c = transposed_shape[dims - 2];

  transposed_shape[dims - 1] = c;
  transposed_shape[dims - 2] = r;

  auto transposed_type =
      rewriter.getTensorType(transposed_shape, valueType.getElementType());
  return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}

template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
    const std::vector<Value*>& sliced_lhs,
    const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
    int rows, int cols, Type elementType, Location loc,
    PatternRewriter& rewriter) {
  auto matmulType = rewriter.getTensorType({rows, cols}, elementType);

  std::vector<Value*> matmuls;
  for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
    int lhs_batch_idx, rhs_batch_idx;
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto false_attr = rewriter.getBoolAttr(false);
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmulType,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/false_attr,
                                                /*transpose_b=*/false_attr);
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  Type packedType = rewriter.getTensorType(
      {bcast.output_batch_size(), rows, cols}, elementType);

  auto N = rewriter.getI64IntegerAttr(matmuls.size());
  auto axis = rewriter.getI64IntegerAttr(0);
  return rewriter.create<TF::PackOp>(loc, packedType,
                                     /*values=*/matmuls, N, axis);
}

template <typename BatchMatMulOpType>
PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value* input_lhs = op.x();
  Value* input_rhs = op.y();

  if (!input_lhs->getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return this->matchFailure();
  }
  if (!input_rhs->getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return this->matchFailure();
  }

  auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs->getType().cast<RankedTensorType>();

  auto elementType = lhs_type.getElementType();

  if (elementType != rhs_type.getElementType()) {
    // The element type of LHS must be the same as the element type of RHS.
    return this->matchFailure();
  }

  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Transpose LHS input if necessary.
  if (op.adj_x()) {
    input_lhs = createTransposeOp(input_lhs, loc, rewriter);

    lhs_type = input_lhs->getType().cast<RankedTensorType>();
    lhs_shape = lhs_type.getShape();
  }

  // Transpose RHS input if necessary.
  if (op.adj_y()) {
    input_rhs = createTransposeOp(input_rhs, loc, rewriter);

    rhs_type = input_rhs->getType().cast<RankedTensorType>();
    rhs_shape = rhs_type.getShape();
  }

  // Ensure that input ranks are at least 2 and batch shapes are
  // broadcastable.
  const int dims_a = lhs_shape.size();
  const int dims_b = rhs_shape.size();
  if (dims_a < 2 || dims_b < 2) {
    // Both inputs must have rank >= 2.
    return this->matchFailure();
  }

  if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
    return this->matchFailure();
  }

  if (dims_a == 2 && dims_b == 2) {
    // When both inputs are matrices, just replace the op with a matmul op.
    Type resultType =
        rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
    auto false_attr = rewriter.getBoolAttr(false);
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, resultType,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/false_attr,
                                              /*transpose_b=*/false_attr);
    return this->matchSuccess();
  }

  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
                                absl::InlinedVector<tensorflow::int64, 4>(
                                    rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return this->matchFailure();
  }
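  // Worked example (illustrative): LHS batch shape [2, 1] and RHS batch shape
  // [3] broadcast to [2, 3], so output_batch_size() is 6; x_batch_indices()
  // is {0, 0, 0, 1, 1, 1} and y_batch_indices() is {0, 1, 2, 0, 1, 2}, i.e.
  // each input slice computed below may feed several of the per-batch MatMuls.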

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value*> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value*> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute (single batch) MatMul for each output batch. The MatMul outputs
  // are then packed together into one output Tensor.
  auto packOp =
      createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
                      rhs_shape[dims_b - 1], elementType, loc, rewriter);

  // Reshape the rank-3 Tensor into the correct output shape.
  const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> resultShape(resultBatchShape.begin(),
                                   resultBatchShape.end());
  resultShape.push_back(lhs_shape[dims_a - 2]);
  resultShape.push_back(rhs_shape[dims_b - 1]);

  auto reshapeOp =
      createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter);
  rewriter.replaceOp(op, reshapeOp.output());
  return this->matchSuccess();
}

static PassRegistration<UnrollBatchMatMulPass> pass(
    "tfl-unroll-batch-matmul",
    "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");

}  // namespace TFL
}  // namespace mlir
@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TFL {

// Unrolls a tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does
// not support the BatchMatMul operation, it unrolls a BatchMatMul op into
// tf.Reshape, tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops.
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
  using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;

  static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
                                       Type elementType, Location loc,
                                       PatternRewriter& rewriter);

  static std::vector<Value*> sliceInput(Value* value, int batch_size,
                                        Location loc,
                                        PatternRewriter& rewriter);

  static TF::TransposeOp createTransposeOp(Value* value, Location loc,
                                           PatternRewriter& rewriter);

  static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
                                    const std::vector<Value*>& sliced_rhs,
                                    const tensorflow::MatMulBCast& bcast,
                                    int rows, int cols, Type elementType,
                                    Location loc, PatternRewriter& rewriter);

  PatternMatchResult matchAndRewrite(BatchMatMulOpType op,
                                     PatternRewriter& rewriter) const override;
};

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
456
tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
Normal file
@ -0,0 +1,456 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

namespace {

Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           int32_t val, mlir::Location location) {
  auto type = builder->getTensorType(shape, builder->getIntegerType(32));
  auto attr = DenseElementsAttr::get(type, val);
  return builder->create<ConstantOp>(location, type, attr);
}

Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           float val, mlir::Location location) {
  auto type = builder->getTensorType(shape, builder->getF32Type());
  auto attr = DenseElementsAttr::get(type, val);
  return builder->create<ConstantOp>(location, type, attr);
}

Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           ArrayRef<int64_t> values, mlir::Location location) {
  auto type = builder->getTensorType(static_cast<int>(shape.size()),
                                     builder->getIntegerType(64));
  auto attr = DenseElementsAttr::get(type, values);
  return builder->create<ConstantOp>(location, type, attr);
}

Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) {
  return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
                                           builder->getUnitAttr());
}

Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose,
                   RankedTensorType type, mlir::Location location) {
  // Create a constant op for transpose permutation.
  SmallVector<int64_t, 2> perm = {1, 0};
  auto perm_op = CreateI64DenseConst(builder, perm, perm, location);

  // Create tensor type for the transpose result.
  auto transpose_type = type;
  auto transpose_shape = functional::map(
      [transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
      perm);
  auto elem_type = transpose_type.getElementType();
  auto result_type = builder->getTensorType(transpose_shape, elem_type);

  return builder->create<TF::TransposeOp>(location, result_type,
                                          value_to_transpose, perm_op);
}

Value* SliceRankedTensor(OpBuilder* builder, Value* input,
                         ArrayRef<int64_t> begin_shape,
                         ArrayRef<int64_t> begin_values,
                         ArrayRef<int64_t> size_shape,
                         ArrayRef<int64_t> size_values,
                         mlir::Location location) {
  // Create a dense constant op for slice's begin.
  auto slice_i2c_begin =
      CreateI64DenseConst(builder, begin_shape, begin_values, location);

  // Create a dense constant op for slice's size.
  auto slice_i2c_size =
      CreateI64DenseConst(builder, size_shape, size_values, location);

  return builder->create<TF::SliceOp>(
      location,
      builder->getTensorType(
          size_values,
          input->getType().cast<RankedTensorType>().getElementType()),
      input, slice_i2c_begin, slice_i2c_size);
}

}  // namespace
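As a reading aid for the gate setters that follow, this illustrative helper (it does not exist in the file) restates the row offsets they slice out of the transposed weight, whose gate blocks are laid out in the order cell, input, forget, output:

// Start row of each gate block inside the transposed weight. When the input
// and forget gates are coupled (CIFG), the input-gate block is absent and the
// forget/output blocks shift down by n_cell.
int GateStartRow(int gate /* 0=cell, 1=input, 2=forget, 3=output */, int n_cell,
                 bool couple_input_forget_gates) {
  switch (gate) {
    case 0:
      return 0;
    case 1:
      return n_cell;  // unused when coupled; the slice becomes none_
    case 2:
      return couple_input_forget_gates ? n_cell : 2 * n_cell;
    default:
      return couple_input_forget_gates ? 2 * n_cell : 3 * n_cell;
  }
}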

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
  SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
  input2cell_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
      weight_slice_shape_, weight_slice_size_input_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
  SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
  input2input_ = couple_input_forget_gates_
                     ? none_
                     : SliceRankedTensor(&builder_, weight_transposed_,
                                         weight_slice_shape_, begin_i2i_values,
                                         weight_slice_shape_,
                                         weight_slice_size_input_values_,
                                         fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
  int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
  SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
  input2forget_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
      weight_slice_shape_, weight_slice_size_input_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
  int input_output_start =
      couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
  SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
  input2output_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
      weight_slice_shape_, weight_slice_size_input_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
  SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
  rec2cell_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
      weight_slice_shape_, weight_slice_size_recurrent_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
  SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
  rec2input_ = couple_input_forget_gates_
                   ? none_
                   : SliceRankedTensor(&builder_, weight_transposed_,
                                       weight_slice_shape_, begin_rec2i_values,
                                       weight_slice_shape_,
                                       weight_slice_size_recurrent_values_,
                                       fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
  int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
  SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
  rec2forget_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
      weight_slice_shape_, weight_slice_size_recurrent_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
  int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
  SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
  rec2output_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
      weight_slice_shape_, weight_slice_size_recurrent_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
  SmallVector<int64_t, 1> begin_bias2c_values = {0};
  bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                                 begin_bias2c_values, bias_slice_shape_,
                                 bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
  SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
  bias2input_ =
      couple_input_forget_gates_
          ? none_
          : SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                              begin_bias2i_values, bias_slice_shape_,
                              bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
  int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
  SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
  bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                                   begin_bias2f_values, bias_slice_shape_,
                                   bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
  int bias_output_start =
      couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
  SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
  bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                                   begin_bias2o_values, bias_slice_shape_,
                                   bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
  SmallVector<int64_t, 2> projection_slice_shape = {
      1, num_cols_projection_transposed_};
  SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
  SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
  proj_weight_ =
      !projection_
          ? none_
          : SliceRankedTensor(
                &builder_, projection_transposed_, projection_slice_shape,
                projection_slice_begin_values, projection_slice_shape,
                projection_slice_size_values, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
  proj_bias_ = !projection_type_
                   ? none_
                   : CreateF32SplatConst(&builder_, {n_output_}, 0,
                                         fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
  input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
                                                fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
  input_cell_state_ =
      CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
  cell_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
  input_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
  forget_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
  output_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
  // Transpose both weight and projection.
  weight_transposed_ =
      Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
  projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
                                       fused_func_op_.getLoc());

  none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
  // Extract input to cifg gates via slicing the weight tensor.
  SetWeightForInputToCellGate();
  SetWeightForInputToInputGate();
  SetWeightForInputToForgetGate();
  SetWeightForInputToOutputGate();

  // Extract recurrent to cifg gates via slicing the weight tensor.
  SetWeightForRecurrentToCellGate();
  SetWeightForRecurrentToInputGate();
  SetWeightForRecurrentToForgetGate();
  SetWeightForRecurrentToOutputGate();

  // Extract bias to cifg gates via slicing the bias tensor.
  SetBiasToCellGate();
  SetBiasToInputGate();
  SetBiasToForgetGate();
  SetBiasToOutputGate();

  // Extract projection and set an empty projection bias.
  SetProjection();
  SetProjectionBias();

  // Set the variable tensors.
  SetInputActivationState();
  SetInputCellState();

  // Extract the layer norm coefficients.
  SetCellLayerNormCoefficients();
  SetInputLayerNormCoefficients();
  SetForgetLayerNormCoefficients();
  SetOutputLayerNormCoefficients();
}

void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
  // https://github.com/tensorflow/community/pull/113
  auto attr = fused_func_op_.getAttrOfType<StringAttr>("tf._implements");
  if (!attr) {
    fused_func_op_.setAttr("tf._implements",
                           builder_.getStringAttr(GetCompositeOpName()));
  }
  SmallVector<int64_t, 2> output_shape{1, n_output_};
  auto input_types = fused_func_op_.getType().getInputs();
  auto output_type = builder_.getTensorType(
      output_shape,
      input_->getType().cast<RankedTensorType>().getElementType());
  fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
                                                 fused_func_op_.getContext()));
}

void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
  // Update the func signature, based on output shape.
  // The func will ultimately return the output of the fused
  // LSTM op.
  UpdateFuncSignature();

  // Transform the weights, projection, bias and layer norm coefficients
  // to generate operands for the TFL fused LSTM op.
  GenerateFusedOpOperands();

  // Create the fused LSTM op.
  SmallVector<int64_t, 2> output_shape = {1, n_output_};
  auto result_type = builder_.getTensorType(
      output_shape,
      input_->getType().cast<RankedTensorType>().getElementType());
  lstm_ = builder_.create<mlir::TFL::LSTMOp>(
      fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
      input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
      rec2output_, /*cell_to_input_weights*/ none_,
      /*cell_to_forget_weights*/ none_,
      /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
      bias2output_, proj_weight_, proj_bias_, input_activation_state_,
      input_cell_state_, input_layer_norm_coefficients_,
      forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
      output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
      builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
      builder_.getStringAttr("FULL"));

  builder_.create<mlir::ReturnOp>(fused_func_op_.getLoc(), lstm_.getResult());
}

LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
  num_gates_ = couple_input_forget_gates_ ? 3 : 4;

  input_ = fused_func_op_.getArgument(0);
  bias_ = fused_func_op_.getArgument(2);

  weight_ = fused_func_op_.getArgument(1);
  weight_type_ = weight_->getType().cast<RankedTensorType>();

  if (weight_type_.getRank() != 2) {
    return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
  }

  if (weight_type_.getDimSize(1) % num_gates_ != 0) {
    return fused_func_op_.emitError()
           << "Invalid dimension 1 of weight tensor, "
              "should be divisible by the number of gates";
  }
  n_cell_ = weight_type_.getDimSize(1) / num_gates_;

  projection_ = fused_func_op_.getArgument(3);
  projection_type_ = projection_->getType().cast<RankedTensorType>();
  if (projection_type_.getRank() != 2) {
    n_output_ = n_cell_;
  } else {
    n_output_ = projection_type_.getDimSize(1);
  }
  n_input_ = weight_type_.getDimSize(0) - n_output_;
  num_cols_weight_transposed_ = weight_type_.getDimSize(0);
  num_cols_projection_transposed_ = projection_type_.getDimSize(0);

  bias_slice_shape_ = {n_cell_};
  bias_size_values_ = {n_cell_};
  weight_slice_shape_ = {1, num_cols_weight_transposed_};
  weight_slice_size_input_values_ = {n_cell_, n_input_};
  weight_slice_size_recurrent_values_ = {n_cell_, n_output_};

  return success();
}

LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
  if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
    return fused_func_op_.emitError()
           << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
              "interface and cannot be converted to the fused LSTM op";
  }

  layer_norm_scale_ = fused_func_op_.getArgument(4);
  layer_norm_scale_type_ =
      layer_norm_scale_->getType().cast<RankedTensorType>();
  if (layer_norm_scale_type_.getRank() != 1) {
    return fused_func_op_.emitError()
           << "The layer_norm_scale tensor was not of rank 1";
  }
  layer_norm_slice_shape_ = {n_cell_};
  layer_norm_size_values_ = {n_cell_};

  return success();
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetCellLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
  cell_layer_norm_coefficients_ =
      SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
                        begin_cell_layer_norm_values, layer_norm_slice_shape_,
                        layer_norm_size_values_, fused_func_op_.getLoc());
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetInputLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
  input_layer_norm_coefficients_ =
      couple_input_forget_gates_
          ? none_
          : SliceRankedTensor(
                &builder_, layer_norm_scale_, layer_norm_slice_shape_,
                begin_input_layer_norm_values, layer_norm_slice_shape_,
                layer_norm_size_values_, fused_func_op_.getLoc());
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetForgetLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
  forget_layer_norm_coefficients_ =
      SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
                        begin_forget_layer_norm_values, layer_norm_slice_shape_,
                        layer_norm_size_values_, fused_func_op_.getLoc());
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetOutputLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
  output_layer_norm_coefficients_ =
      SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
                        begin_output_layer_norm_values, layer_norm_slice_shape_,
                        layer_norm_size_values_, fused_func_op_.getLoc());
}

}  // namespace TFL
}  // namespace mlir
214
tensorflow/compiler/mlir/lite/utils/lstm_utils.h
Normal file
@ -0,0 +1,214 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {

constexpr char kLstmCellSimple[] = "LSTMCellSimple";
constexpr char kLayerNormalizedLstmCellSimple[] =
    "LayerNormalizedLstmCellSimple";

// A utility class that enables the conversion of the LSTMCellSimple composite
// op into a fused TFL LSTM op. The fused op is contained within a FuncOp
// that also contains other supporting ops needed to construct the operands for
// the fused op. The caller provides the containing FuncOp as input with
// arguments specifying the input, weight, projection and bias.
// The weight, projection, bias and layer norm scale all need to be
// RankedTensorType.
// This class sets the layer norm coefficients to NoneType.
class ConvertLSTMCellSimpleToFusedLSTM {
 public:
  // TODO(b/140053256): The couple_input_forget_gates should be specified on
  // FuncOp as an attribute.
  explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op,
                                            bool couple_input_forget_gates)
      : fused_func_op_(fused_func_op),
        couple_input_forget_gates_(couple_input_forget_gates),
        builder_(fused_func_op.getBody()) {}

  // not copyable.
  ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
      delete;
  ConvertLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
  virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}

  // Verify input func op arguments and initialize internal state.
  virtual LogicalResult Initialize();

  virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }

  // Rewrite the func body with constructed fused lstm.
  void RewriteFunc();

 protected:
  void UpdateFuncSignature();
  void GenerateFusedOpOperands();

  void SetWeightForInputToCellGate();
  void SetWeightForInputToInputGate();
  void SetWeightForInputToForgetGate();
  void SetWeightForInputToOutputGate();

  void SetWeightForRecurrentToCellGate();
  void SetWeightForRecurrentToInputGate();
  void SetWeightForRecurrentToForgetGate();
  void SetWeightForRecurrentToOutputGate();

  void SetBiasToCellGate();
  void SetBiasToInputGate();
  void SetBiasToForgetGate();
  void SetBiasToOutputGate();

  void SetProjection();
  void SetProjectionBias();

  void SetInputActivationState();
  void SetInputCellState();

  virtual void SetCellLayerNormCoefficients();
  virtual void SetInputLayerNormCoefficients();
  virtual void SetForgetLayerNormCoefficients();
  virtual void SetOutputLayerNormCoefficients();

  // specified state
  FuncOp fused_func_op_;
  Value* input_;
  Value* weight_;
  Value* bias_;
  Value* projection_;
  bool couple_input_forget_gates_;

  // internal state
  Value* weight_transposed_;
  Value* projection_transposed_;
  RankedTensorType weight_type_;
  RankedTensorType projection_type_;
  int num_gates_;
  int n_cell_;
  int n_output_;
  int n_input_;
  int num_cols_weight_transposed_;
  int num_cols_projection_transposed_;

  // input -> cifg
  Value* input2input_;
  Value* input2forget_;
  Value* input2cell_;
  Value* input2output_;

  // recurrent -> cifg
  Value* rec2input_;
  Value* rec2forget_;
  Value* rec2cell_;
  Value* rec2output_;

  // bias -> cifg
  Value* bias2input_;
  Value* bias2forget_;
  Value* bias2cell_;
  Value* bias2output_;

  // projection
  Value* proj_weight_;
  Value* proj_bias_;

  // state
  Value* input_activation_state_;
  Value* input_cell_state_;

  // layer norm coefficients
  Value* input_layer_norm_coefficients_;
  Value* forget_layer_norm_coefficients_;
  Value* cell_layer_norm_coefficients_;
  Value* output_layer_norm_coefficients_;

  mlir::TFL::LSTMOp lstm_;

  Value* none_;
  SmallVector<int64_t, 1> bias_slice_shape_;
  SmallVector<int64_t, 1> bias_size_values_;
  SmallVector<int64_t, 2> weight_slice_shape_;
  SmallVector<int64_t, 2> weight_slice_size_input_values_;
  SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
  OpBuilder builder_;
};
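A typical call sequence, mirroring the unit tests in lstm_utils_test.cc (the wrapper function below is a hypothetical sketch, not part of this header):

// Verify the func op's arguments, then rewrite its body with the fused op.
void FuseLSTMCellSimple(mlir::FuncOp fused_func_op) {
  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(
      fused_func_op, /*couple_input_forget_gates=*/false);
  if (failed(convert.Initialize())) return;  // arguments did not match
  convert.RewriteFunc();  // func now returns the fused TFL LSTM output
}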

// A utility class that enables the conversion of the
// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
// fused op is contained within a FuncOp that also contains other supporting ops
// needed to construct the operands for the fused op. The caller provides the
// containing FuncOp as input with arguments specifying the input, weight,
// projection, bias and layer norm scale. The weight, projection, bias and
// layer norm scale all need to be RankedTensorType.
// This class overrides the layer norm coefficient setters from the base class.
class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
    : public ConvertLSTMCellSimpleToFusedLSTM {
 public:
  // TODO(b/140053256): The couple_input_forget_gates should be specified on
  // FuncOp as an attribute.
  explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      mlir::FuncOp fused_func_op, bool couple_input_forget_gates)
      : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op,
                                         couple_input_forget_gates) {}

  // not copyable.
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}

  llvm::StringRef GetCompositeOpName() override {
    return kLayerNormalizedLstmCellSimple;
  }

  LogicalResult Initialize() override;

 protected:
  void SetCellLayerNormCoefficients() override;
  void SetInputLayerNormCoefficients() override;
  void SetForgetLayerNormCoefficients() override;
  void SetOutputLayerNormCoefficients() override;

 private:
  // specified state
  Value* layer_norm_scale_;

  // internal state
  RankedTensorType layer_norm_scale_type_;
  SmallVector<int64_t, 1> layer_norm_slice_shape_;
  SmallVector<int64_t, 1> layer_norm_size_values_;
};

}  // end namespace TFL
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
222
tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
Normal file
@ -0,0 +1,222 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

#include <memory>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/core/platform/test.h"

namespace mlir {
namespace TFL {

FuncOp createFusedFunc(mlir::Builder* builder) {
  SmallVector<int64_t, 2> input_shape{1, 2};
  SmallVector<int64_t, 2> weight_shape{3, 12};
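  // With 4 gates, a [3, 12] weight gives n_cell = 12 / 4 = 3; the transpose
  // verified below flips it to [12, 3].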
  SmallVector<int64_t, 1> bias_shape{2};
  SmallVector<int64_t, 2> projection_shape{1, 2};
  SmallVector<int64_t, 1> layer_norm_scale{4};
  SmallVector<int64_t, 2> output_shape{1, 2};
  auto input_type = builder->getTensorType(input_shape, builder->getF32Type());
  auto weight_type =
      builder->getTensorType(weight_shape, builder->getF32Type());
  auto bias_type = builder->getTensorType(bias_shape, builder->getF32Type());
  auto projection_type =
      builder->getTensorType(projection_shape, builder->getF32Type());
  auto layer_norm_scale_type =
      builder->getTensorType(layer_norm_scale, builder->getF32Type());
  auto output_type =
      builder->getTensorType(output_shape, builder->getF32Type());
  SmallVector<mlir::Type, 4> input_types{input_type, weight_type, bias_type,
                                         projection_type,
                                         layer_norm_scale_type};
  auto func_type = builder->getFunctionType(input_types, output_type);

  auto func =
      FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
                                        builder->getContext()),
                     "fused_func", func_type, {});
  func.addEntryBlock();
  return func;
}

// TODO(ashwinm): Revisit if this test should be moved to a test pass
// with FileCheck test after the pass that consumes the lstm_utils to stack
// the layers.
class LstmUtilsTest : public ::testing::Test {
 protected:
  LstmUtilsTest() {}

  void SetUp() override {
    builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
    fused_lstm_func_ = createFusedFunc(builder_.get());
  }

  void TearDown() override {
    fused_lstm_func_.erase();
    builder_.reset();
  }
  FuncOp fused_lstm_func_;
  mlir::MLIRContext context_;
  std::unique_ptr<mlir::Builder> builder_;
};

TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, false);

  auto result = convert.Initialize();
  EXPECT_FALSE(failed(result));

  convert.RewriteFunc();
  fused_lstm_func_.dump();

  // verify transpose
  EXPECT_EQ(
      fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
      convert.GetCompositeOpName());
  EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);

  auto transpose_op = fused_lstm_func_.getBody().front().begin();
  transpose_op++;
  EXPECT_EQ(transpose_op->getOperand(0)
                ->getType()
                .cast<RankedTensorType>()
                .getDimSize(0),
            3);
  EXPECT_EQ(transpose_op->getOperand(0)
                ->getType()
                .cast<RankedTensorType>()
                .getDimSize(1),
            12);
  EXPECT_EQ(
      transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
          0),
      12);
  EXPECT_EQ(
      transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
          1),
      3);

  auto return_op = fused_lstm_func_.getBody().back().rbegin();
  EXPECT_EQ(return_op->getName().getStringRef(),
            mlir::ReturnOp::getOperationName());
  return_op++;
  EXPECT_EQ(return_op->getName().getStringRef(),
            mlir::TFL::LSTMOp::getOperationName());
  EXPECT_EQ(return_op->getNumOperands(), 24);
  EXPECT_EQ(return_op->getNumResults(), 1);
  // cifg = false, so input2input is not None.
  EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
  // input layer norm is None
  EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());
  // proj_bias is F32
  EXPECT_TRUE(return_op->getOperand(17)
                  ->getType()
                  .cast<RankedTensorType>()
                  .getElementType()
                  .isF32());

  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
  auto output_types = fused_lstm_func_.getType().getResults();
  SmallVector<int64_t, 2> output_shape{1, 2};
  EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
            output_shape.size());
  for (int i = 0; i < output_shape.size(); i++) {
    EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
              output_shape[i]);
  }
}

TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, true);

  auto result = convert.Initialize();
  EXPECT_FALSE(failed(result));

  convert.RewriteFunc();
  fused_lstm_func_.dump();

  auto it = fused_lstm_func_.getBody().back().rbegin();
  EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
  it++;
  EXPECT_EQ(it->getName().getStringRef(),
            mlir::TFL::LSTMOp::getOperationName());
  EXPECT_EQ(it->getNumOperands(), 24);
  EXPECT_EQ(it->getNumResults(), 1);
  // cifg = true, so input2input is None.
  EXPECT_TRUE(it->getOperand(1)->getType().isa<NoneType>());
}

TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
  mlir::TFL::ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM convert(
      fused_lstm_func_, false);

  auto result = convert.Initialize();
  EXPECT_FALSE(failed(result));

  convert.RewriteFunc();
  fused_lstm_func_.dump();

  EXPECT_EQ(
      fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
      convert.GetCompositeOpName());
  EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);

  auto it = fused_lstm_func_.getBody().back().rbegin();
  EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
  it++;
  EXPECT_EQ(it->getName().getStringRef(),
            mlir::TFL::LSTMOp::getOperationName());
  EXPECT_EQ(it->getNumOperands(), 24);
  EXPECT_EQ(it->getNumResults(), 1);
  // cifg = false, so input2input is not None.
  EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>());

  // input layer norm
  EXPECT_FALSE(it->getOperand(20)->getType().isa<NoneType>());
  EXPECT_EQ(
      it->getOperand(20)->getType().cast<RankedTensorType>().getShape().size(),
      1);
  EXPECT_EQ(
      it->getOperand(20)->getType().cast<RankedTensorType>().getDimSize(0), 3);

  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
  auto output_types = fused_lstm_func_.getType().getResults();
  SmallVector<int64_t, 2> output_shape{1, 2};
  EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
            output_shape.size());
  for (int i = 0; i < output_shape.size(); i++) {
    EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
              output_shape[i]);
  }
}

}  // namespace TFL
}  // namespace mlir
11
tensorflow/compiler/mlir/python/BUILD
Normal file
@ -0,0 +1,11 @@
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(
    ["mlir.i"],
    visibility = [
        "//tensorflow/python:__subpackages__",
    ],
)
74
tensorflow/compiler/mlir/python/mlir.i
Normal file
@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include "tensorflow/python/platform/base.i"

%{

#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"

namespace tensorflow {
namespace swig {

// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbtxt, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
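// Example from Python (via the import_graphdef wrapper emitted at the bottom
// of this file, which serializes the proto to its text form):
//   mlir_text = import_graphdef(graph_def)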
||||
string ImportGraphDef(const string &proto, TF_Status* status) {
|
||||
GraphDef graphdef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
NodeSpecs specs;
|
||||
mlir::MLIRContext context;
|
||||
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
||||
if (!module.ok()) {
|
||||
Set_TF_Status_from_Status(status, module.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return MlirModuleToString(*module.ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::swig;
|
||||
%unignore tensorflow::swig::ImportGraphDef;
|
||||
|
||||
// Wrap this function
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
static string ImportGraphDef(const string &graphdef, TF_Status* status);
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%insert("python") %{
|
||||
def import_graphdef(graphdef):
|
||||
return str(ImportGraphDef(str(graphdef).encode('utf-8')));
|
||||
%}
|
||||
|
||||
%unignoreall
|
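As the comment in the wrapper notes, this surfaces as `tf.mlir.experimental.convert_graph_def`. A minimal sketch of driving it from Python, assuming a TF build that ships this experimental API; the wrapper stringifies its argument, so passing a GraphDef proto (whose `str()` is the pbtxt form) is expected to work:

```python
import tensorflow as tf

@tf.function
def add(x, y):
    return x + y

# Serialize a trivial function graph to a GraphDef.
graph_def = add.get_concrete_function(
    tf.TensorSpec([], tf.float32),
    tf.TensorSpec([], tf.float32)).graph.as_graph_def()

# Round-trip the GraphDef through the importer and print the MLIR module.
mlir_text = tf.mlir.experimental.convert_graph_def(graph_def)
print(mlir_text)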
@ -1,5 +1,5 @@
load("@local_config_mlir//:tblgen.bzl", "gentbl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary")

package(
    default_visibility = [":friends"],
@ -123,21 +123,6 @@ cc_library(
        "ir/tf_ops.cc.inc",
        "ir/tf_ops.h.inc",
        "ir/tf_types.cc",
        "transforms/bridge.cc",
        "transforms/bridge_pass.cc",
        "transforms/cluster_formation.cc",
        "transforms/cluster_outlining.cc",
        "transforms/executor_island_coarsening.cc",
        "transforms/functional_control_flow_to_cfg.cc",
        "transforms/generated_canonicalize.inc",
        "transforms/generated_optimize.inc",
        "transforms/graph_pruning.cc",
        "transforms/optimize.cc",
        "transforms/raise_control_flow.cc",
        "transforms/tpu_cluster_formation.cc",
        "transforms/tpu_rewrite_pass.cc",
        "translate/control_to_executor_dialect.cc",
        "translate/executor_to_control_dialect.cc",
    ],
    hdrs = [
        "ir/control_flow_ops.h",
@ -153,11 +138,11 @@ cc_library(
    includes = ["include"],
    deps = [
        ":error_util",
        ":mlir_passthrough_op",
        ":tensorflow_canonicalize_inc_gen",
        ":tensorflow_device_ops_inc_gen",
        ":tensorflow_executor_inc_gen",
        ":tensorflow_ops_inc_gen",
        ":tensorflow_optimize_inc_gen",
        "//tensorflow/compiler/mlir/lite:validators",
        "//tensorflow/core:lib",
        "@llvm//:support",
@ -175,12 +160,74 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "tensorflow_passes",
    srcs = [
        "transforms/bridge.cc",
        "transforms/bridge_pass.cc",
        "transforms/cluster_formation.cc",
        "transforms/cluster_outlining.cc",
        "transforms/executor_island_coarsening.cc",
        "transforms/fold_switch.cc",
        "transforms/functional_control_flow_to_cfg.cc",
        "transforms/generated_canonicalize.inc",
        "transforms/generated_optimize.inc",
        "transforms/graph_pruning.cc",
        "transforms/materialize_mlir_passthrough_op.cc",
        "transforms/optimize.cc",
        "transforms/raise_control_flow.cc",
        "transforms/sink_constant.cc",
        "transforms/tpu_cluster_formation.cc",
        "transforms/tpu_rewrite_pass.cc",
        "translate/control_to_executor_dialect.cc",
        "translate/executor_to_control_dialect.cc",
    ],
    hdrs = [
        "transforms/bridge.h",
        "transforms/passes.h",
    ],
    includes = ["include"],
    deps = [
        ":error_util",
        ":tensorflow",
        ":tensorflow_optimize_inc_gen",
        "//tensorflow/compiler/mlir/lite:validators",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:TransformUtils",
        "@local_config_mlir//:Transforms",
    ],
    # TODO(jpienaar): Merge in the dialect registration.
    alwayslink = 1,
)

cc_library(
    name = "tensorflow_test_passes",
    srcs = [
        "transforms/lower_tf_pass.cc",
    ],
    deps = [
        ":lower_tf_lib",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
    ],
    alwayslink = 1,
)

# Library with TensorFlow dialect static initialization.
cc_library(
    name = "tensorflow_dialect_registration",
    srcs = ["ir/dialect_registration.cc"],
    deps = [
        ":tensorflow",
        ":tensorflow_passes",
        "@local_config_mlir//:IR",
    ],
    alwayslink = 1,
@ -204,6 +251,7 @@ cc_library(
        ":mangling_util",
        ":mlir_roundtrip_flags",
        ":tensorflow",
        ":tensorflow_passes",
        "//tensorflow/cc/saved_model:loader_lite",
        "//tensorflow/compiler/jit:shape_inference_helpers",
        "//tensorflow/compiler/xla:status_macros",
@ -252,6 +300,8 @@ cc_library(
        "utils/export_utils.cc",
    ],
    hdrs = [
        "ir/tf_types.def",
        "ir/tf_types.h",
        "utils/export_utils.h",
    ],
    deps = [
@ -639,26 +689,73 @@ gentbl(
)

cc_library(
    name = "tensorflow_fold_switch",
    srcs = [
        "transforms/fold_switch.cc",
    ],
    hdrs = [
        "transforms/passes.h",
    ],
    copts = ["-std=c++14"],
    name = "compile_mlir_util",
    srcs = ["utils/compile_mlir_util.cc"],
    hdrs = ["utils/compile_mlir_util.h"],
    deps = [
        ":tensorflow",
        ":convert_type",
        ":error_util",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
        "//tensorflow/compiler/mlir/xla:type_to_shape",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/types:span",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:Parser",
    ],
)

tf_cc_test(
    name = "compile_mlir_util_test",
    size = "small",
    srcs = ["utils/compile_mlir_util_test.cc"],
    deps = [
        ":compile_mlir_util",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:test_main",
        "//tensorflow/stream_executor/lib",
    ],
)

cc_library(
    name = "mlir_passthrough_op",
    srcs = ["ops/mlir_passthrough_op.cc"],
    deps = [
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

tf_gen_op_wrapper_py(
    name = "gen_mlir_passthrough_op_py",
    out = "gen_mlir_passthrough_op.py",
    deps = [":mlir_passthrough_op"],
)

# Library of rewrite patterns for lowering within TensorFlow.
#
# This is a separate library so that external passes can link only this library
# without linking any of the other tensorflow passes.
cc_library(
    name = "lower_tf_lib",
    srcs = [
        "transforms/lower_tf.cc",
    ],
    hdrs = [
        "transforms/lower_tf.h",
    ],
    deps = [
        ":tensorflow",
        "@local_config_mlir//:IR",
    ],
    alwayslink = 1,
)

@ -40,6 +40,7 @@ limitations under the License.
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Transforms/FoldUtils.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
@ -48,31 +49,65 @@ namespace {

// If the given tensor has elements of type variant, then returns a new type
// after dropping subtypes info. Otherwise, returns the original type as is.
Type DropVariantSubTypes(Type ty) {
  ShapedType shaped_ty = ty.cast<ShapedType>();
  Type element_ty = shaped_ty.getElementType();
ShapedType DropVariantSubTypes(ShapedType ty) {
  Type element_ty = ty.getElementType();
  if (!element_ty.isa<TF::VariantType>()) return ty;

  Type variant_ty = TF::VariantType::get(ty.getContext());
  if (shaped_ty.hasRank()) {
    return RankedTensorType::get(shaped_ty.getShape(), variant_ty);
  if (ty.hasRank()) {
    return RankedTensorType::get(ty.getShape(), variant_ty);
  }

  return UnrankedTensorType::get(variant_ty);
}

// If the given tensor has elements of type ref, then returns a new type of the
// same shape but with the corresponding non-ref type as element type.
// Otherwise, returns the original type as is.
ShapedType DropRefType(ShapedType type) {
  Type element_ty = type.getElementType();
  TF::TensorFlowRefType ref_type = element_ty.dyn_cast<TF::TensorFlowRefType>();
  if (!ref_type) return type;

  if (type.hasRank()) {
    return RankedTensorType::get(type.getShape(), ref_type.RemoveRef());
  }
  return UnrankedTensorType::get(ref_type.RemoveRef());
}

}  // namespace

//===----------------------------------------------------------------------===//
// TF Executor Dialect
//===----------------------------------------------------------------------===//

namespace {

struct TensorFlowExecutorOpFolderDialectInterface
    : public OpFolderDialectInterface {
  using OpFolderDialectInterface::OpFolderDialectInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. its regions may
  // reference values defined in an enclosing region), should be used when
  // materializing constants.
  // In the executor dialect we materialize inside an island.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<tf_executor::IslandOp>(region->getParentOp());
  }
};

}  // namespace

TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_executor", context) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
      >();

  addInterfaces<TensorFlowExecutorOpFolderDialectInterface>();

  addTypes<ControlType, TokenType>();
}

@ -296,6 +331,23 @@ void Print(IslandOp op, OpAsmPrinter *p) {
    p->printOperands(op.getOperands());
    *p << ')';
  }

  // Check if we can print the short "wraps" form: that is, if the island
  // contains a single operation and the results of this operation are
  // perfectly forwarded to the yield.
  if (op.getAttrs().empty() &&
      std::next(op.GetBody().begin(), 2) == op.GetBody().end()) {
    Operation &wrapped_op = op.GetBody().front();
    Operation &yield_op = op.GetBody().back();
    if (wrapped_op.getNumResults() == yield_op.getNumOperands() &&
        std::equal(wrapped_op.getResults().begin(),
                   wrapped_op.getResults().end(),
                   yield_op.getOperands().begin())) {
      *p << " wraps ";
      p->printGenericOp(&op.GetBody().front());
      return;
    }
  }
  p->printRegion(op.getOperation()->getRegion(0));
  p->printOptionalAttrDict(op.getAttrs());
}
@ -316,17 +368,22 @@ ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) {
  // Parse the body region.
  Region &body = *result->addRegion();

  // TODO(b/134773778): the custom parser is missing support to implement the
  // short syntax right now.
  // if (!parser->parseOptionalKeyword("wraps")) {
  //   body.push_back(new Block);
  //   Block &block = body.back();
  //   parser->getBuilder().setInsertionPointToEnd(&block);
  //   if (parser->parseOperation())
  //     return failure();
  // }

  if (parser->parseRegion(body, llvm::None, llvm::None)) return failure();
  if (succeeded(parser->parseOptionalKeyword("wraps"))) {
    // If we parse the short version of the island, we have an operation in
    // the generic form that follows the "wraps" keyword. Parse it inside the
    // region and forward all of its results as-is to the yield operation.
    body.push_back(new Block);
    Block &block = body.back();
    Operation *wrapped_op =
        parser->parseGenericOperation(&block, block.begin());
    if (!wrapped_op) return failure();
    OpBuilder builder(parser->getBuilder().getContext());
    builder.setInsertionPointToEnd(&block);
    builder.create<YieldOp>(result->location,
                            llvm::to_vector<8>(wrapped_op->getResults()));
  } else if (parser->parseRegion(body, llvm::None, llvm::None)) {
    return failure();
  }

  IslandOp::ensureTerminator(body, parser->getBuilder(), result->location);

@ -536,35 +593,43 @@ LogicalResult Verify(MergeOp merge) {
  if (data_type.isa<ControlType>())
    return merge.emitOpError() << "expects a non-control input";

  // Check that all operands can be broadcasted to a common type compatible with
  // the result type.
  Type broadcasted_type = merge.output()->getType();
  // Check that each operand can be individually broadcasted to the output type.
  Type output_type = merge.output()->getType();
  TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
  if (!output_tensor_ty) {
    return merge.emitOpError()
           << "expects output to have tensor type but got " << output_type;
  }
  bool is_output_ref =
      output_tensor_ty.getElementType().isa<TF::TensorFlowRefType>();
  for (Type operand_type : merge.getOperandTypes()) {
    if (operand_type.isa<ControlType>()) break;

    // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
    // constraint.
    if (!operand_type.isa<TensorType>())
      return merge.emitOpError("expects data operands to have tensor type");

    // Variant types may have opaque subtypes information that need not match
    // between the two types so drop them before computing the broadcasted type.
    Type new_broadcasted_type =
        OpTrait::util::getBroadcastedType(DropVariantSubTypes(broadcasted_type),
                                          DropVariantSubTypes(operand_type));
    if (!new_broadcasted_type)
    TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
    if (!operand_tensor_ty)
      return merge.emitOpError()
             << "expects all operands to be broadcastable"
             << " but got " << broadcasted_type << " vs " << operand_type;
    // Use the broadcasted type unless we're losing the rank information here.
    // This is because for example starting with a result of tensor<4xf32>, if
    // the first operand is unranked, the broadcasted type will be unranked.
    // Then any tensor operand will be broadcastable to this unranked type.
    if (!broadcasted_type.cast<TensorType>().hasRank() ||
        new_broadcasted_type.cast<TensorType>().hasRank())
      broadcasted_type = new_broadcasted_type;
  }
             << "expects data operands to have tensor type but got "
             << operand_type;

    // If output type is a ref type then all operand types should also be of
    // the same ref type. However, if the output type is a non-ref type T,
    // operands can be tensor of type T or T_REF.
    if (is_output_ref &&
        !operand_tensor_ty.getElementType().isa<TF::TensorFlowRefType>()) {
      return merge.emitOpError()
             << "expects same operand and output element type but got "
             << operand_tensor_ty << " vs " << output_tensor_ty;
    }
    Type broadcasted_type = OpTrait::util::getBroadcastedType(
        DropRefType(DropVariantSubTypes(output_tensor_ty)),
        DropRefType(DropVariantSubTypes(operand_tensor_ty)));
    if (!broadcasted_type)
      return merge.emitOpError()
             << "expects all operands to be broadcastable with output type"
             << " but got " << operand_tensor_ty << " vs " << output_tensor_ty;
  }
  return success();
}

@ -1088,6 +1153,35 @@ void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
              DropEmptyIslandNoOperandOneDataResult>(context);
}

//===----------------------------------------------------------------------===//
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and removes ControlTriggerOps with no control operands.
// Control result users will have their relevant operands removed.
struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
  using OpRewritePattern<ControlTriggerOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(ControlTriggerOp op,
                                     PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0) return matchFailure();

    for (auto &use : llvm::make_early_inc_range(op.control()->getUses()))
      use.getOwner()->eraseOperand(use.getOperandNumber());

    rewriter.replaceOp(op, {nullptr});

    return matchSuccess();
  }
};
}  // anonymous namespace

void ControlTriggerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<DropEmptyControlTrigger>(context);
}

//===----------------------------------------------------------------------===//
// Folders
//===----------------------------------------------------------------------===//

@ -594,6 +594,8 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger",

  let verifier = ?;

  let hasCanonicalizer = 1;

  let builders = [OpBuilder<
    "Builder *builder, OperationState *result, "
    "ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",

@ -88,13 +88,13 @@ Inputs must be of same size and shape.
  }];

  let arguments = (ins
    Variadic<TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Variant]>>:$inputs,
    Variadic<TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>>:$inputs,

    Confined<I64Attr, [IntMinValue<1>]>:$N
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Variant]>:$sum
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -110,12 +110,12 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -123,6 +123,32 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
  let hasCanonicalizer = 1;
}

def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
  let summary = [{
Computes the "logical and" of elements across dimensions of a tensor.
  }];

  let description = [{
Reduces `input` along the dimensions given in `axis`. Unless
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
`axis`. If `keep_dims` is true, the reduced dimensions are
retained with length 1.
  }];

  let arguments = (ins
    I1Tensor:$input,
    TF_I32OrI64Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    I1Tensor:$output
  );

  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}

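At the Python level, `TF_AllOp` corresponds to `tf.reduce_all`; a minimal sketch of the axis/`keep_dims` semantics described above (assuming a standard TF 2.x build):

```python
import tensorflow as tf

x = tf.constant([[True, True], [False, True]])

# Reduce along axis 1: rank drops by one for the reduced dimension.
print(tf.reduce_all(x, axis=1))                 # [ True False], shape (2,)

# With keepdims=True the reduced dimension is retained with length 1.
print(tf.reduce_all(x, axis=1, keepdims=True))  # shape (2, 1)
```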
def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> {
  let summary = [{
Computes the "logical or" of elements across dimensions of a tensor.
@ -169,7 +195,7 @@ Usage:
  }];

  let arguments = (ins
    TF_NumberTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    TF_I32OrI64Tensor:$dimension
  );

@ -202,7 +228,7 @@ Usage:
  }];

  let arguments = (ins
    TF_NumberTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    TF_I32OrI64Tensor:$dimension
  );

@ -261,6 +287,88 @@ window in `value`.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
  let summary = "Multiplies slices of two tensors in batches.";

  let description = [{
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.

The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.

The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:

    r_o = c_x if adj_x else r_x
    c_o = r_y if adj_y else c_y

It is computed as:

    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,

    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
    DefaultValuedAttr<BoolAttr, "false">:$adj_y
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> {
  let summary = "Multiplies slices of two tensors in batches.";

  let description = [{
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.

The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.

The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:

    r_o = c_x if adj_x else r_x
    c_o = r_y if adj_y else c_y

It is computed as:

    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])

*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,

    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
    DefaultValuedAttr<BoolAttr, "false">:$adj_y
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

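The batch-dimension broadcasting noted for `BatchMatMulV2` is easy to see from Python; in TF 2.x, `tf.matmul` lowers to this op for batched inputs (a minimal illustration):

```python
import tensorflow as tf

x = tf.ones([5, 1, 2, 3])  # batch dims [5, 1]
y = tf.ones([4, 3, 7])     # batch dims [4], broadcast against [5, 1]

# Batch dimensions broadcast like NumPy; the matrix dims contract as usual.
z = tf.matmul(x, y)
print(z.shape)  # (5, 4, 2, 7)
```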
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
  let summary = "BatchToSpace for N-D tensors of type T.";

@ -297,14 +405,14 @@ Broadcasting is supported, so `value` may have any number of dimensions.
  }];

  let arguments = (ins
    TF_NumberTensor:$value,
    TF_NumberTensor:$bias,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias,

    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format
  );

  let results = (outs
    TF_NumberTensor:$output
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -332,21 +440,23 @@ gives module error.
For example,

Example 1:
```python

>>> a = [1., 2., 3.]
>>> equality_bitcast = tf.bitcast(a,tf.complex128)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast]
>>> equality_cast = tf.cast(a,tf.complex128)
>>> equality_bitcast = tf.bitcast(a, tf.complex128)
Traceback (most recent call last):
...
InvalidArgumentError: Cannot bitcast from 1 to 18 [Op:Bitcast]
>>> equality_cast = tf.cast(a, tf.complex128)
>>> print(equality_cast)
tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128)
```

Example 2:
```python

>>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8)
<tf.Tensor: ... shape=(4,), dtype=uint8, numpy=array([255, 255, 255, 255], dtype=uint8)>
```

Example 3:
```python

>>> x = [1., 2., 3.]
>>> y = [0., 2., 3.]
>>> equality= tf.equal(x,y)
@ -358,10 +468,9 @@ tf.Tensor([False True True], shape=(3,), dtype=bool)
tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32)
>>> print(equality_bitcast)
tf.Tensor(
[[ 0 0 0 0]
[ 0 0 128 63]
[ 0 0 128 63]], shape=(3, 4), dtype=uint8)
```
[[  0   0   0   0]
 [  0   0 128  63]
 [  0   0 128  63]], shape=(3, 4), dtype=uint8)

*NOTE*: Bitcast is implemented as a low-level cast, so machines with different
endian orderings will give different results.
@ -393,14 +502,13 @@ and works its way forward.

For example,

```python
>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> sess.run(y)
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]], dtype=int32)
```
>>> print(y)
tf.Tensor(
[[1 2 3]
 [1 2 3]
 [1 2 3]], shape=(3, 3), dtype=int32)

In the above example, the input Tensor with shape `[3]`
is broadcast to an output Tensor with shape `[3, 3]`.
@ -462,6 +570,27 @@ def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [SameOperandsAndResultType]> {
  let summary = "Checks a tensor for NaN and Inf values.";

  let description = [{
When run, reports an `InvalidArgument` error if `tensor` has any values
that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
  }];

  let arguments = (ins
    TF_FpTensor:$tensor,

    StrAttr:$message
  );

  let results = (outs
    TF_FpTensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

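`CheckNumerics` is reachable from Python as `tf.debugging.check_numerics`; a small sketch of the pass-through/error behavior described above:

```python
import tensorflow as tf

good = tf.constant([1.0, 2.0])
print(tf.debugging.check_numerics(good, message="bad value"))  # passes as-is

bad = tf.constant([1.0, float("nan")])
try:
    tf.debugging.check_numerics(bad, message="bad value")
except tf.errors.InvalidArgumentError as e:
    # The op reports InvalidArgument when a NaN or Inf is present.
    print("raised:", type(e).__name__)
```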
def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
  let summary = "Concatenates tensors along one dimension.";

@ -480,6 +609,10 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;

  let verifier = [{
    return Verify(*this);
  }];
}

def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> {
@ -501,6 +634,10 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> {

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;

  let verifier = [{
    return Verify(*this);
  }];
}

def TF_ConjOp : TF_Op<"Conj", [NoSideEffect]> {
@ -771,12 +908,12 @@ def TF_DivOp : TF_Op<"Div", [Broadcastable, NoSideEffect]>,
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -805,8 +942,7 @@ See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_EqualOp : TF_Op<"Equal", [Broadcastable, Commutative, NoSideEffect]>,
                 WithBroadcastableCmpOpBuilder {
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
  let summary = "Returns the truth value of (x == y) element-wise.";

  let description = [{
@ -825,8 +961,10 @@ tf.math.equal(x, y) ==> array([True, True])
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$x,
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$y
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y,

    DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
  );

  let results = (outs
@ -834,6 +972,15 @@ tf.math.equal(x, y) ==> array([True, True])
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let builders = [
    OpBuilder<"Builder* builder, OperationState* result, Value* x, "
              "Value* y, BoolAttr incompatible_shape_error">
  ];

  let verifier = [{
    return Verify(*this);
  }];
}

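A short illustration of the `Equal` semantics this definition now models, including the new `incompatible_shape_error` attribute; the raw-op call is an assumption about builds recent enough to carry the attribute:

```python
import tensorflow as tf

# Operands are broadcast before comparison.
print(tf.math.equal(tf.constant([2, 4]), tf.constant(2)))  # [ True False]

# With incompatible_shape_error=False, non-broadcastable shapes yield a
# plain False instead of raising (assumption: TF build with this attribute).
print(tf.raw_ops.Equal(x=tf.constant([1, 2]), y=tf.constant([1, 2, 3]),
                       incompatible_shape_error=False))
```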
def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> {
@ -1017,6 +1164,52 @@ values.
  }];
}

def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> {
  let summary = [{
Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
  }];

  let description = [{
`[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
to 'outputs' tensor of same shape as `inputs`.

`[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in the `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.

Before quantization, `min` and `max` values are adjusted with the following
logic.
It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
the behavior can be unexpected:
If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1)`,
`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.

This operation has a gradient and thus allows for training `min` and `max`
values.
  }];

  let arguments = (ins
    F32Tensor:$inputs,
    F32Tensor:$min,
    F32Tensor:$max,

    DefaultValuedAttr<I64Attr, "8">:$num_bits,
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range
  );

  let results = (outs
    F32Tensor:$outputs
  );

  let verifier = [{
    return Verify(*this);
  }];
}

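The min/max nudging logic above is easy to sanity-check by writing it out directly (values chosen purely for illustration):

```python
# Nudging rules from the FakeQuant description, for num_bits=8, min=0.1, max=0.9.
num_bits = 8
mn, mx = 0.1, 0.9

if 0 < mn < mx:          # 0 not in [min, max]: shift the range to include 0
    min_adj, max_adj = 0.0, mx - mn
elif mn < mx < 0:
    min_adj, max_adj = mn - mx, 0.0
else:                    # min <= 0 <= max: snap min onto the quantization grid
    scale = (mx - mn) / (2 ** num_bits - 1)
    min_adj = scale * round(mn / scale)
    max_adj = mx + min_adj - mn

print(min_adj, max_adj)  # 0.0 0.8
```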
def TF_FillOp : TF_Op<"Fill", [NoSideEffect]> {
  let summary = "Creates a tensor filled with a scalar value.";

@ -1082,12 +1275,12 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [Broadcastable, NoSideEffect]>,
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1780,14 +1973,14 @@ retained with length 1.
  }];

  let arguments = (ins
    TF_NumberTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    TF_I32OrI64Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    TF_NumberTensor:$output
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1801,7 +1994,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
  }];

  let arguments = (ins
    TF_IntOrFpTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$input,

    Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksize,
    Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
@ -1810,7 +2003,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
  );

  let results = (outs
    TF_IntOrFpTensor:$output
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1850,14 +2043,14 @@ retained with length 1.
  }];

  let arguments = (ins
    TF_NumberTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    TF_I32OrI64Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    TF_NumberTensor:$output
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1931,6 +2124,57 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
  TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}

def TF_MlirPassthroughOpOp : TF_Op<"MlirPassthroughOp", [NoSideEffect]> {
  let summary = [{
Wraps an arbitrary MLIR computation expressed as a module with a main() function.
  }];

  let description = [{
This operation does not have an associated kernel and is not intended to be
executed in a regular TensorFlow session. Instead it is intended to be used for
testing or for special cases where a user intends to pass a custom MLIR
computation through a TensorFlow graph with the intent of having custom tooling
process it downstream (when targeting a different environment, like TensorFlow
Lite for example).
The MLIR module is expected to have a main() function that will be used as an
entry point. The inputs to the operation will be passed as arguments to the
main() function and the returned values of the main function mapped to the
outputs.
Example usage:

```
import tensorflow as tf
from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op

mlir_module = '''
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
  %add = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
  return %add : tensor<10x10xf32>
}
'''

@tf.function
def foo(x, y):
  return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])

graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
```
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    StrAttr:$mlir_module
  );

  let results = (outs
    Variadic<TF_Tensor>:$outputs
  );

  TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}

def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>,
               WithBroadcastableBinOpBuilder {
  let summary = "Returns x * y element-wise.";
@ -1941,12 +2185,12 @@ def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>,
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2006,8 +2250,101 @@ def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> {
  let results = (outs);
}

def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]>,
                    WithBroadcastableCmpOpBuilder {
def TF_NonMaxSuppressionV4Op : TF_Op<"NonMaxSuppressionV4", [NoSideEffect]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
  }];

  let arguments = (ins
    TensorOf<[F16, F32]>:$boxes,
    TensorOf<[F16, F32]>:$scores,
    I32Tensor:$max_output_size,
    TensorOf<[F16, F32]>:$iou_threshold,
    TensorOf<[F16, F32]>:$score_threshold,

    DefaultValuedAttr<BoolAttr, "false">:$pad_to_max_output_size
  );

  let results = (outs
    I32Tensor:$selected_indices,
    I32Tensor:$valid_outputs
  );

  TF_DerivedOperandTypeAttr T_threshold = TF_DerivedOperandTypeAttr<3>;
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_NonMaxSuppressionV5Op : TF_Op<"NonMaxSuppressionV5", [NoSideEffect]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f.
Bodla et al., https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
  }];

  let arguments = (ins
    TensorOf<[F16, F32]>:$boxes,
    TensorOf<[F16, F32]>:$scores,
    I32Tensor:$max_output_size,
    TensorOf<[F16, F32]>:$iou_threshold,
    TensorOf<[F16, F32]>:$score_threshold,
    TensorOf<[F16, F32]>:$soft_nms_sigma,

    DefaultValuedAttr<BoolAttr, "false">:$pad_to_max_output_size
  );

  let results = (outs
    I32Tensor:$selected_indices,
    TensorOf<[F16, F32]>:$selected_scores,
    I32Tensor:$valid_outputs
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
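The Soft-NMS mode described above is reachable from Python via `tf.image.non_max_suppression_with_scores`, which in recent TF builds is backed by `NonMaxSuppressionV5` (an assumption about the binding); a minimal sketch:

```python
import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
                     [0.0, 0.1, 1.0, 1.1],
                     [0.0, 2.0, 1.0, 3.0]])
scores = tf.constant([0.9, 0.8, 0.7])

# soft_nms_sigma > 0 enables the Gaussian Soft-NMS mode: overlapping boxes
# get down-weighted scores instead of being pruned outright.
selected, selected_scores = tf.image.non_max_suppression_with_scores(
    boxes, scores, max_output_size=3, iou_threshold=0.5,
    score_threshold=0.0, soft_nms_sigma=0.5)
print(selected.numpy(), selected_scores.numpy())
```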

def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
  let summary = "Returns the truth value of (x != y) element-wise.";

  let description = [{
@ -2016,8 +2353,10 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$x,
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$y
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y,

    DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
  );

  let results = (outs
@ -2025,6 +2364,15 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let builders = [
    OpBuilder<"Builder* builder, OperationState* result, Value* x, "
              "Value* y, BoolAttr incompatible_shape_error">
  ];

  let verifier = [{
    return Verify(*this);
  }];
}

def TF_OneHotOp : TF_Op<"OneHot", [NoSideEffect]> {
@ -2121,7 +2469,7 @@ output =
  }];

  let arguments = (ins
    TensorOf<[I32, I64, I8]>:$indices,
    TensorOf<[I32, I64, TF_Uint8]>:$indices,
    I32Tensor:$depth,
    TF_Tensor:$on_value,
    TF_Tensor:$off_value,
@ -2176,6 +2524,10 @@ This is the opposite of `unpack`.
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let verifier = [{
    return Verify(*this);
  }];
}

def TF_PadOp : TF_Op<"Pad", [NoSideEffect]> {
@ -2303,14 +2655,14 @@ retained with length 1.
  }];

  let arguments = (ins
    TF_NumberTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    TF_I32OrI64Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    TF_NumberTensor:$output
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2406,7 +2758,8 @@ The above round function rounds the value based on the given round_mode.

    DefaultValuedAttr<I64Attr, "8">:$num_bits,
    DefaultValuedAttr<BoolAttr, "false">:$range_given,
    DefaultValuedAttr<TF_AnyStrAttrOf<["HALF_TO_EVEN", "HALF_UP"]>, "HALF_TO_EVEN">:$round_mode,
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
    DefaultValuedAttr<I64Attr, "-1">:$axis
  );

  let results = (outs
@ -2432,7 +2785,8 @@ tensor, so its value can change during training.

    DefaultValuedAttr<BoolAttr, "true">:$signed_input,
    DefaultValuedAttr<BoolAttr, "true">:$range_given,
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range
    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
    DefaultValuedAttr<I64Attr, "-1">:$axis
  );

  let results = (outs
@ -2550,12 +2904,12 @@ If `x` and `y` are reals, this will return the floating-point division.
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2590,11 +2944,11 @@ def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
  }];

  let arguments = (ins
    TF_IntOrFpTensor:$features
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$features
  );

  let results = (outs
    TF_IntOrFpTensor:$activations
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$activations
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2709,7 +3063,7 @@ Input images can be of different types but output images are always float.
  }];

  let arguments = (ins
    TF_IntOrFpTensor:$images,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images,
    I32Tensor:$size,

    DefaultValuedAttr<BoolAttr, "false">:$align_corners,
@ -2732,7 +3086,7 @@ Resize `images` to `size` using nearest neighbor interpolation.
  }];

  let arguments = (ins
    TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$images,
    TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images,
    I32Tensor:$size,

    DefaultValuedAttr<BoolAttr, "false">:$align_corners,
@ -2740,7 +3094,7 @@ Resize `images` to `size` using nearest neighbor interpolation.
  );

  let results = (outs
    TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$resized_images
    TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$resized_images
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -2875,12 +3229,12 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11],
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$tensor,
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$tensor,
    TF_I32OrI64Tensor:$axis
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$output
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -3031,6 +3385,10 @@ shape(t) ==> [2, 2, 3]
    return Verify(*this);
  }];

  let builders = [
    OpBuilder<"Builder* builder, OperationState* result, Value* input, BoolAttr use32Bit">
  ];

  let hasFolder = 1;
}

@ -3643,12 +4001,12 @@ def TF_SubOp : TF_Op<"Sub", [Broadcastable, NoSideEffect]>,
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -3667,20 +4025,36 @@ retained with length 1.
  }];

  let arguments = (ins
    TF_NumberTensor:$input,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    TF_I32OrI64Tensor:$reduction_indices,

    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
  );

  let results = (outs
    TF_NumberTensor:$output
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}

def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
  let summary = "Returns the result of a TPU compilation.";

  let description = [{
This operation returns the result of a TPU compilation as a serialized
CompilationResultProto, which holds a status and an error message if an error
occurred during compilation.
  }];

  let arguments = (ins);

  let results = (outs
    TF_StrTensor:$output
  );
}

def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes hyperbolic tangent of `x` element-wise.";

@ -3750,6 +4124,23 @@ def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> {
  TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
}

def TF_TensorListLengthOp : TF_Op<"TensorListLength", [NoSideEffect]> {
  let summary = "Returns the number of tensors in the input tensor list.";

  let description = [{
input_handle: the input list
length: the number of tensors in the list
  }];

  let arguments = (ins
    TF_VariantTensor:$input_handle
  );

  let results = (outs
    I32Tensor:$length
  );
}

def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> {
  let summary = [{
Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`.
@ -3921,12 +4312,12 @@ Python Semantics.
  }];

  let arguments = (ins
    TF_NumberTensor:$x,
    TF_NumberTensor:$y
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TF_NumberTensor:$z
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -4068,7 +4459,7 @@ where(input) ==> [[0, 0, 0],
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input
    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input
  );

  let results = (outs
|
||||
|
||||
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
|
||||
|
||||
def TF_NumberOrStrTensor : TensorOf<[TF_AnyNumber, TF_Str]>;
|
||||
def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>;
|
||||
def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow attribute definitions
|
||||
|
@ -19,13 +19,16 @@ limitations under the License.
#include <functional>
#include <numeric>
#include <string>
#include <type_traits>

#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/Dialect/Traits.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
@ -34,6 +37,7 @@ limitations under the License.
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
@ -71,10 +75,19 @@ static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) {
// Returns true if the given `value` has at least the specified rank or has
// unranked type.
static inline bool HasRankAtLeast(Value *value, int64_t rank) {
  auto type = value->getType();
  Type type = value->getType();
  if (auto ranked_type = type.dyn_cast<RankedTensorType>())
    return ranked_type.getRank() >= rank;
  return type.isa<UnrankedTensorType>();
  return true;
}

// Returns true if the given `value` has at most the specified rank or has
// unranked type.
static inline bool HasRankAtMost(Value *value, int64_t rank) {
  Type type = value->getType();
  if (auto ranked_type = type.dyn_cast<RankedTensorType>())
    return ranked_type.getRank() <= rank;
  return true;
}

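// Illustrative note (added for clarity, not part of the original change):
// for a value of type tensor<2x3xf32>, HasRankAtLeast(v, 1) and
// HasRankAtMost(v, 2) both hold, while an unranked tensor<*xf32> satisfies
// either predicate conservatively because its rank is unknown.
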
// Returns true if the given pair of TensorFlow types can be cast to one
@ -95,6 +108,85 @@ static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
  return dim_or_rank == -1;
}

// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` inputs. If
// `incompatible_shape_error` is true, reports an error if `x` and `y` have
// incompatible shapes. Otherwise, returns a tensor type with unknown rank.
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value *x,
                                 Value *y, BoolAttr incompatible_shape_error) {
  auto result_type =
      OpTrait::util::getBroadcastedType(x->getType(), y->getType());
  if (!result_type) {
    if (incompatible_shape_error.getValue()) {
      mlir::emitError(loc, "non-broadcastable operands");
    } else {
      result_type = builder->getTensorType(builder->getI1Type());
    }
  }
  return result_type;
}

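// Illustrative example (added for clarity, not part of the original change):
// for inputs of type tensor<2x1xf32> and tensor<2x3xf32> the broadcasted
// shape 2x3 determines the comparison result, while non-broadcastable inputs
// such as tensor<2xf32> and tensor<3xf32> either report "non-broadcastable
// operands" or fall back to an unranked tensor of i1, depending on
// incompatible_shape_error.
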
// Verifies that the given types are cast compatible. If not, emits appropriate
// error for the given op. If mask_one_dim is set to true, then the types are
// allowed to have one mismatching dimension. Masking one of the dimensions is
// useful for ops like Concat that require all ranked inputs to have the same
// rank and match dimension sizes for all but one of the dimensions.
static LogicalResult VerifyTypesCompatibility(
    Operation::operand_type_range types, bool mask_one_dim, Operation *op) {
  constexpr int64_t kUninitialized = -1;
  int64_t common_rank = kUninitialized;
  llvm::SmallVector<int64_t, 4> common_dims;
  int64_t dim_to_mask = kUninitialized;

  // Initialize common_rank with rank of the first ranked type and verify that
  // following ranked types have the same rank.
  // Similarly, initialize each of the dimensions with the first type that has
  // the dimension size available and verify that all following types have the
  // same size for the dimension. However, if mask_one_dim is true, note down
  // the dimension index on the first mismatch and ignore dimension at that
  // index in following types.
  for (Type ty : types) {
    RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
    if (!ranked_ty) continue;

    int64_t rank = ranked_ty.getRank();
    if (common_rank == kUninitialized) {
      common_rank = rank;
      common_dims.resize(common_rank, kUninitialized);
    } else if (common_rank != rank) {
      return op->emitError()
             << "operand type " << ranked_ty
             << " is not compatible with preceding operands; expected rank: "
             << common_rank;
    }

    for (int64_t i = 0, e = common_rank; i != e; i++) {
      if (i == dim_to_mask) continue;

      int64_t dim = ranked_ty.getDimSize(i);
      if (dim == kUninitialized) continue;

      int64_t &common_dim = common_dims[i];
      if (common_dim == kUninitialized) {
        common_dim = dim;
      } else if (common_dim != dim) {
        // If mask_one_dim is true, do not emit an error if this is the only
        // dimension with mismatches. Note down the dimension to mask it from
        // the following types.
        if (mask_one_dim && dim_to_mask == kUninitialized) {
          dim_to_mask = i;
          continue;
        }

        return op->emitError() << "operand type " << ranked_ty
                               << " is not compatible with preceding operands; "
                                  "expected dimension at index "
                               << i << ": " << common_dim;
      }
    }
  }
  return success();
}

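// Illustrative example (added for clarity, not part of the original change):
// with mask_one_dim = true, operand types tensor<2x3xf32> and tensor<2x5xf32>
// are compatible because the single mismatch at dimension 1 is masked, but a
// third operand of type tensor<4x3xf32> is rejected since dimension 0
// mismatches as well.
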
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace
@ -176,6 +268,36 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  results.insert<CastSameType>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp and ConcatV2Op
//===----------------------------------------------------------------------===//

template <typename OpT, typename = typename std::enable_if_t<
                            llvm::is_one_of<OpT, ConcatOp, ConcatV2Op>::value>>
static LogicalResult Verify(OpT op) {
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  auto num_values = std::distance(values.begin(), values.end());
  int64_t attr_N = op.N().getLimitedValue();
  if (num_values != attr_N) {
    return op.emitOpError()
           << "requires attribute 'N' to match the number of inputs; expected: "
           << num_values << " Found: " << attr_N;
  }

  int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
  Value *axis = *op.getODSOperands(axis_idx).begin();
  if (!HasRankAtMost(axis, 1)) {
    return op.emitOpError(
        "requires axis to be of scalar type (or vector type for older "
        "versions)");
  }

  return VerifyTypesCompatibility(values,
                                  /*mask_one_dim=*/true, op.getOperation());
}

//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
@ -257,6 +379,26 @@ static LogicalResult Verify(EmptyTensorListOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// EqualOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(EqualOp op) {
  // If inputs are allowed to have incompatible types, there is nothing to
  // verify.
  if (!op.incompatible_shape_error()) return success();

  // Otherwise, check that the inputs are broadcastable.
  return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
      op.getOperation());
}

void EqualOp::build(Builder *builder, OperationState *result, Value *x,
                    Value *y, BoolAttr incompatible_shape_error) {
  auto result_type = DeduceEqualCmpOpType(builder, result->location, x, y,
                                          incompatible_shape_error);
  return build(builder, result, result_type, x, y, incompatible_shape_error);
}

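// Example usage (hypothetical, added for illustration): a builder call such
// as b.create<EqualOp>(loc, x, y, b.getBoolAttr(true)) routes through this
// overload, so the result type is inferred from the broadcasted operand
// shapes via DeduceEqualCmpOpType rather than being spelled out by the
// caller.
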
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxArgsOp
//===----------------------------------------------------------------------===//
@ -276,12 +418,6 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
    return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
                          "," + Twine(std::to_string(rmax)) + "]");
  }
  // Range must straddle zero.
  if (rmin > 0.0 || rmax < 0.0) {
    return op.emitOpError("range failed to straddle zero: [" +
                          Twine(std::to_string(rmin)) + "," +
                          Twine(std::to_string(rmax)) + "]");
  }
  int64_t num_bits = op.num_bits().getSExtValue();
  if (num_bits < 2 || num_bits > 16) {
    return op.emitOpError(
@ -308,6 +444,37 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxVarsPerChannelOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
  if (!isOfRankedFloatTensorType(op.min(), 1))
    return op.emitOpError("requires min to be a 1d float tensor");

  if (!isOfRankedFloatTensorType(op.max(), 1))
    return op.emitOpError("requires max to be a 1d float tensor");

  Value *inputs = op.inputs();
  if (!HasRankAtLeast(inputs, 1) ||
      inputs->getType().isa<UnrankedTensorType>()) {
    return op.emitError("requires inputs to be at least 1d float tensor");
  }

  auto inputsType = inputs->getType().cast<ShapedType>();
  int depth = inputsType.getDimSize(inputsType.getRank() - 1);
  if (op.min()->getType().cast<ShapedType>().getDimSize(0) != depth ||
      op.max()->getType().cast<ShapedType>().getDimSize(0) != depth) {
    return op.emitOpError(
        "requires min and max to have same size as last dimension of inputs");
  }
  int64_t num_bits = op.num_bits().getSExtValue();
  if (num_bits < 2 || num_bits > 16) {
    return op.emitOpError(
        "requires num_bits to be between 2 and 16, inclusive");
  }
  return success();
}

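// Illustrative shapes (added for clarity, mirroring the tests later in this
// change): inputs of type tensor<1x2x3x8xf32> with min/max of type
// tensor<8xf32> pass this verifier (depth = 8), while a min of type
// tensor<1xf32> fails the last-dimension size check.
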
//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//
@ -471,6 +638,74 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  results.insert<NegNested>(context);
}

//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(NotEqualOp op) {
  // If inputs are allowed to have incompatible types, there is nothing to
  // verify.
  if (!op.incompatible_shape_error()) return success();

  // Otherwise, check that the inputs are broadcastable.
  return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
      op.getOperation());
}

void NotEqualOp::build(Builder *builder, OperationState *result, Value *x,
                       Value *y, BoolAttr incompatible_shape_error) {
  auto result_type = DeduceEqualCmpOpType(builder, result->location, x, y,
                                          incompatible_shape_error);
  return build(builder, result, result_type, x, y, incompatible_shape_error);
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PackOp op) {
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  auto num_values = std::distance(values.begin(), values.end());
  int64_t attr_N = op.N().getLimitedValue();
  if (num_values != attr_N) {
    return op.emitOpError()
           << "requires attribute 'N' to match the number of inputs; expected: "
           << num_values << " Found: " << attr_N;
  }

  if (failed(VerifyTypesCompatibility(values,
                                      /*mask_one_dim=*/false,
                                      op.getOperation()))) {
    return failure();
  }

  int64_t inputs_rank = -1;
  for (Value *value : values) {
    if (auto ty = value->getType().dyn_cast<RankedTensorType>()) {
      // Exit early as input types are verified to be compatible so all ranked
      // tensors have the same rank.
      inputs_rank = ty.getRank();
      break;
    }
  }
  if (inputs_rank == -1) return success();

  // The values can be packed along any of the dimensions between 0 and
  // inputs rank, inclusive. Also, negative axis values wrap around, so the
  // axis value range is [-(R+1), R+1).
  int64_t range_begin = -inputs_rank - 1;  // Inclusive
  int64_t range_end = inputs_rank + 1;     // Exclusive
  int64_t axis = op.axis().getLimitedValue();
  if (axis < range_begin || axis >= range_end) {
    return op.emitError() << "attribute 'axis' should be within range ["
                          << range_begin << ", " << range_end
                          << "); actual value: " << axis;
  }

  return success();
}

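// Illustrative example (added for clarity, not part of the original change):
// packing two tensor<3x5xf32> values has inputs_rank = 2, so 'axis' must lie
// in [-3, 3); axis = 0 yields tensor<2x3x5xf32>, and axis = -1 packs along
// the trailing dimension of the result.
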
//===----------------------------------------------------------------------===//
// ReciprocalOp
//===----------------------------------------------------------------------===//
@ -731,6 +966,16 @@ OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  return b.getDenseElementsAttr(resultType, dimensions);
}

void ShapeOp::build(Builder *builder, OperationState *result, Value *input,
                    BoolAttr use32Bit) {
  auto rankedTensorType = input->getType().dyn_cast<RankedTensorType>();
  int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
  auto out_type = use32Bit.getValue() ? builder->getIntegerType(32)
                                      : builder->getIntegerType(64);
  return ShapeOp::build(builder, result,
                        builder->getTensorType({rank}, out_type), input);
}

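// Illustrative example (added for clarity, not part of the original change):
// an input of type tensor<3x4x5xf32> produces a result type of tensor<3xi32>
// when use32Bit is true; for an unranked input the rank is unknown (-1), so
// the result has a single dynamic dimension (tensor<?xi32> or tensor<?xi64>).
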
//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//
@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("MlirPassthroughOp")
    .Attr("mlir_module: string")
    .Attr("Tinputs : list(type) >= 0")
    .Input("inputs: Tinputs")
    .Attr("Toutputs : list(type) >= 0")
    .Output("outputs: Toutputs")
    .Doc(R"doc(
Wraps an arbitrary MLIR computation expressed as a module with a main() function.

This operation does not have an associated kernel and is not intended to be
executed in a regular TensorFlow session. Instead it is intended to be used for
testing or for special cases where a user intends to pass a custom MLIR
computation through a TensorFlow graph, with the intent of having custom
tooling process it downstream (when targeting a different environment, like
TensorFlow Lite for example).
The MLIR module is expected to have a main() function that will be used as an
entry point. The inputs to the operation will be passed as arguments to the
main() function, and the returned values of the main function are mapped to
the outputs.
Example usage:

```
import tensorflow as tf
from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op

mlir_module = '''
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
  %ret = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
  return %ret : tensor<10x10xf32>
}
'''

@tf.function
def foo(x, y):
  return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])

graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
```
)doc");

}  // namespace tensorflow
@ -236,8 +236,8 @@ func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) ->
  %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1>
  return %1: tensor<8x16xi1>

// CHECK: %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0
// CHECK: %[[NE:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true}
// CHECK: return %[[NE]]
}

// CHECK-LABEL: testLogicalNotOfNotEqual
@ -246,8 +246,8 @@ func @testLogicalNotOfNotEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>)
  %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1>
  return %1: tensor<8x16xi1>

// CHECK: %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0
// CHECK: %[[NE:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true}
// CHECK: return %[[NE]]
}

// CHECK-LABEL: testLogicalNotOfGreater
@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s
// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s -dump-input-on-failure

// Simple case, single device cluster.

@ -72,11 +72,8 @@ module {
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
func @argliveinotherislands(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = tf_executor.graph {
    // CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island {
    %1:2 = tf_executor.island {
      %3 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
      tf_executor.yield %3 : tensor<?xi32>
    }
    // CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island wraps "tf.D"
    %1:2 = tf_executor.island wraps "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>

    %2:2 = tf_executor.island {
      // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
@ -90,16 +90,13 @@ module {
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
func @multiplelaunches(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = tf_executor.graph {
    %1:2 = tf_executor.island {
      // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func}
      %2 = "tf_device.launch"() ( {
    %1:2 = tf_executor.island wraps
    // CHECK: %[[A_OUTPUT:[0-9]*]]:2 = {{.*}} "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func}
      "tf_device.launch"() ( {
        %3 = "tf.A"() : () -> tensor<?xi32>
        "tf_device.return"(%3) : (tensor<?xi32>) -> ()
      }) {device = "tpu0"} : () -> tensor<?xi32>

    // CHECK: tf_executor.yield %[[A_OUTPUT]]
    tf_executor.yield %2 : tensor<?xi32>
    }
    // CHECK: tf_executor.fetch %[[A_OUTPUT]]#0
    tf_executor.fetch %1#0 : tensor<?xi32>
  }
  return %0 : tensor<?xi32>
@ -11,14 +11,8 @@ func @islands_with_control(tensor<*xf32>) -> tensor<*xf32> {
}

// CHECK-NEXT: %[[GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) wraps "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.fetch %[[ADD]]#0 : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[GRAPH]] : tensor<*xf32>
@ -45,40 +39,19 @@ func @LoopTest() {
}

// CHECK-NEXT: tf_executor.graph {
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[CONST]]#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island {
// CHECK-NEXT: "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
// CHECK-NEXT: tf_executor.yield
// CHECK-NEXT: }
// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island wraps "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
// CHECK-NEXT: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi1>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island wraps "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[LESS:[0-9]*]]#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xi32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island wraps "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
// CHECK-NEXT: tf_executor.fetch
@ -285,9 +285,9 @@ func @empty_island_no_operand_no_data_result() {
  return
}

// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island {
// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island
// CHECK-NEXT: "tf.opA"
// CHECK: tf_executor.island(%[[ISLAND_0]]) {
// CHECK: tf_executor.island(%[[ISLAND_0]])
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: tf_executor.island

@ -313,9 +313,9 @@ func @empty_island_one_operand_no_data_result() {
  return
}

// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island {
// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island
// CHECK-NEXT: "tf.opA"
// CHECK: tf_executor.island(%[[ISLAND_1]]) {
// CHECK: tf_executor.island(%[[ISLAND_1]])
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: tf_executor.island

@ -342,8 +342,34 @@ func @empty_island_no_operand_one_data_no_control_result(%arg0 : tensor<i1>) {
  return
}

// CHECK: tf_executor.island {
// CHECK: tf_executor.island
// CHECK-NEXT: "tf.opA"(%[[ARG_0]])
// CHECK: tf_executor.island {
// CHECK-NEXT: "tf.opB"(%[[ARG_0]])
// CHECK-NOT: tf_executor.island


// Test that an empty control trigger with no operands is removed.
// Control result users should also have their respective operands removed.
// CHECK-LABEL: func @empty_control_trigger
func @empty_control_trigger() {
  tf_executor.graph {
    %0 = tf_executor.ControlTrigger {}
    %1 = tf_executor.island(%0) {
      %3 = "tf.opA"() : () -> tensor<i1>
      tf_executor.yield
    }
    %2 = tf_executor.island(%0, %1) {
      %4 = "tf.opB"() : () -> tensor<i1>
      tf_executor.yield
    }
    tf_executor.fetch
  }
  return
}

// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island
// CHECK-NEXT: "tf.opA"
// CHECK: tf_executor.island(%[[ISLAND_0]])
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: tf_executor.island
@ -89,9 +89,7 @@ func @empty_islands(%arg0 : tensor<i1>, %arg1 : tensor<i1>) -> (tensor<i1>, tens
  return %0#0, %0#1 : tensor<i1>, tensor<i1>
}

// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island {
// CHECK-NEXT: %[[OP_A:[0-9]*]]:2 = "tf.opA"(%[[ARG_1]], %[[ARG_0]])
// CHECK-NEXT: tf_executor.yield %[[OP_A]]#0, %[[OP_A]]#1 : tensor<i1>, tensor<i1>
// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island wraps "tf.opA"(%[[ARG_1]], %[[ARG_0]])
// CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<i1>, tensor<i1>


@ -228,9 +226,7 @@ func @islands_interleaved(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> (tensor<i
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: %{{[0-9]*}} = "tf.opE"(%[[ARG_0]])
// CHECK-NEXT: tf_executor.yield %[[OP_C]] : tensor<i32>
// CHECK: tf_executor.island {
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]])
// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor<i32>
// CHECK: tf_executor.island wraps "tf.opF"(%[[ARG_1]])
// CHECK: tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor<i32>, tensor<i32>


@ -279,13 +275,9 @@ func @merge_islands_only() {
  return
}

// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %[[OP_A:.*]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<i32>
// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island wraps "tf.opA"
// CHECK: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[ISLAND_0]]#0
// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island {
// CHECK-NEXT: "tf.opB"()
// CHECK-NEXT: tf_executor.yield
// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island wraps "tf.opB"()
// CHECK: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0
// CHECK-NEXT: %[[ISLAND_2:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
@ -322,9 +314,7 @@ func @simple_potential_cycle() {
  return
}

// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island wraps "tf.opA"
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1
// CHECK-NEXT: tf_executor.island(%[[CT]]) {
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"
@ -384,9 +374,7 @@ func @merge_into_nested_data_result() {
// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA"
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: [[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) {
// CHECK-NEXT: [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor<1xf32>
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) wraps "tf.opB"(%[[OP_A]])
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
// CHECK: tf_executor.yield

@ -422,18 +410,14 @@ func @merge_islands_inner_graph() {
  return
}

// CHECK: tf_executor.island {
// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
// CHECK: tf_executor.island {
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK: tf_executor.island wraps "tf.opA"
// CHECK: tf_executor.island wraps "tf_executor.graph"() ( {
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: "tf.opB"
// CHECK-NEXT: [[OP_C:[0-9]*]] = "tf.opC"
// CHECK-NEXT: [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<1xf32>
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
// CHECK: tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32>


// Test merging islands with control island operands and island results only if
@ -454,7 +438,7 @@ func @merge_islands_closest_control() {
  return
}

// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island {
// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island
// CHECK: tf_executor.ControlTrigger %[[ISLAND]]
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]]) {
// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]])
@ -0,0 +1,23 @@
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input=fail

// Test that a constant stays inside an island after canonicalization

// CHECK-LABEL: func @constant_in_island
func @constant_in_island(%arg0 : tensor<i1>) -> tensor<f32> {
  %0 = tf_executor.graph {
    // CHECK: tf_executor.island
    // CHECK: tf.Const{{.*}}2.0
    %1:2 = tf_executor.island {
      %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
      tf_executor.yield %0 : tensor<f32>
    }
    // Uses two islands for no other reason than preventing canonicalization from
    // eliminating the graph entirely.
    %2:2 = tf_executor.island(%1#1) {
      %4 = "tf.opB"(%1#0) : (tensor<f32>) -> tensor<f32>
      tf_executor.yield %4 : tensor<f32>
    }
    tf_executor.fetch %2#0 : tensor<f32>
  }
  return %0 : tensor<f32>
}
@ -97,3 +97,24 @@ func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  }
  return %fetches : tensor<*xf32>
}

// Test that tf_executor dialect ops with Ref types are mapped correctly to the
// ops in the control dialect.
// CHECK-LABEL: func @ref_tf_executor_ops
func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> {
  %result = tf_executor.graph {
    // CHECK: _tf.Enter
    %0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control)
    // CHECK: _tf.Exit
    %1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref>
    // CHECK: _tf.Switch
    %2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control)
    // CHECK: _tf.Merge
    %3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
    // CHECK: _tf.NextIteration.source
    %4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref>
    // CHECK: _tf.NextIteration.sink
    tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref>
    tf_executor.fetch %0#0 : tensor<4x!tf.f32ref>
  }
  return %result : tensor<4x!tf.f32ref>
}
@ -39,13 +39,10 @@ versions {
# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "Add"}} {

# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island
# CHECK-NEXT: "tf.Placeholder.input"(%arg0)
# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0)

# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island
# CHECK-NEXT: "tf.Placeholder.input"(%arg1)
# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg1)

# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island
# CHECK-NEXT: "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0)
# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island wraps "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0)

# CHECK: fetch %[[add]]#0
@ -41,8 +41,7 @@ library {
}
# Drop the control dependency on arg for the node "test"
# CHECK-LABEL: func @foo
# CHECK: tf_executor.island {
# CHECK-NEXT: "tf.Const"()
# CHECK: tf_executor.island wraps "tf.Const"()
node_def {
  name: "test"
  op: "Const"
@ -6,12 +6,9 @@

# CHECK: func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> (tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "args_0, args_1", outputs = "rets_0_RetVal, rets_1_RetVal"}} {
# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island {
# CHECK: "tf.Const"
# CHECK: %[[ISLAND_1:[0-9]]]:2 = tf_executor.island {
# CHECK: "tf.Identity"(%[[ISLAND_0]]#0)
# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island {
# CHECK: "tf.StatefulPartitionedCall"
# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island wraps "tf.Const"
# CHECK: %[[ISLAND_1:[0-9]]]:2 = tf_executor.island wraps "tf.Identity"(%[[ISLAND_0]]#0)
# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island wraps "tf.StatefulPartitionedCall"
# CHECK-SAME: f = @[[FUNC:[a-z0-9]*]]
# CHECK: tf_executor.fetch %[[ISLAND_1]]#0, %[[ISLAND_2]]#0 : tensor<f32>, tensor<f32>
# CHECK: func @[[FUNC]](%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>
@ -5,7 +5,7 @@
# FetchOp.

# Match the island containing the "tf.Neg", capture the output
# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg"
# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island wraps "tf.Neg"

# Check that the tf.Neg control is passed to the fetch
# CHECK: tf_executor.fetch {{.*}} %[[ISLAND_0]]#1 : tensor<*xf32>, !tf_executor.control
@ -5,7 +5,7 @@
# FetchOp.

# Match the island containing the "tf.Neg", capture the output
# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg"
# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island wraps "tf.Neg"

# Check that the tf.Neg data output and control are passed to the fetch
# CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<*xf32>, !tf_executor.control
@ -0,0 +1,110 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s

# Verify that the _input_shapes attribute of the FunctionDef is respected.
# This also checks that the output type is correctly inferred based on
# that.
# CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  experimental_debug_info {
  }
}
node {
  name: "Placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  experimental_debug_info {
  }
}
node {
  name: "If"
  op: "If"
  input: "Placeholder"
  input: "Placeholder_1"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "identity_function"
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "identity_function"
      }
    }
  }
  experimental_debug_info {
  }
}
library {
  function {
    signature {
      name: "identity_function"
      input_arg {
        name: "identity_input"
        type: DT_INT32
      }
      output_arg {
        name: "identity_output"
        type: DT_INT32
      }
    }
    ret {
      key: "identity_output"
      value: "identity_input"
    }
    attr {
      key: "_input_shapes"
      value {
        list {
          shape {
          }
        }
      }
    }
  }
}
versions {
  producer: 29
  min_consumer: 12
}
@ -0,0 +1,177 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s

# Verify that the _output_shapes attributes of ReadVariableOps are used to get
# variable types.
# This also checks that the output type is correctly inferred based on
# that.
# CHECK: func @__inference_some_function_130(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
# CHECK: tf.ReadVariableOp"(%arg0) {{.*}} : (tensor<*x!tf.resource>) -> tensor<f32>


node {
  name : "Variable"
  op : "VarHandleOp"
  attr {
    key : "shape"
    value {
      shape {
      }
    }
  }
  attr {
    key : "dtype"
    value {
      type : DT_FLOAT
    }
  }
  attr {
    key : "shared_name"
    value {
      s: "Variable"
    }
  }
  attr {
    key : "_output_shapes"
    value {
      list {
        shape {
        }
      }
    }
  }
}
node {
  name : "StatefulPartitionedCall"
  op : "StatefulPartitionedCall"
  input : [ "Variable" ]
  attr {
    key : "f"
    value {
      func {
        name: "__inference_some_function_13"
      }
    }
  }
  attr {
    key : "config_proto"
    value {
      s: "\n\x07\n\x03GPU\x10\x00\n\x07\n\x03\x43PU\x10\x01\x32\x02J\x00\x38\x01"
    }
  }
  attr {
    key : "Tout"
    value {
      list {
        type : [ DT_FLOAT ]
      }
    }
  }
  attr {
    key : "_gradient_op_type"
    value {
      s: "PartitionedCall-29"
    }
  }
  attr {
    key : "_output_shapes"
    value {
      list {
        shape {
        }
      }
    }
  }
  attr {
    key : "Tin"
    value {
      list {
        type : [ DT_RESOURCE ]
      }
    }
  }
}
library {
  function {
    signature {
      name: "__inference_some_function_13"
      input_arg {
        name : "readvariableop_resource"
        type : DT_RESOURCE
      }
      output_arg {
        name : "identity"
        type : DT_FLOAT
      }
      is_stateful : true
      control_output: [ "ReadVariableOp" ]
    }
    node_def {
      name : "ReadVariableOp"
      op : "ReadVariableOp"
      input : [ "readvariableop_resource" ]
      device: "/job:localhost/replica:0/task:0/device:CPU:0"
      attr {
        key : "dtype"
        value {
          type : DT_FLOAT
        }
      }
      attr {
        key : "_output_shapes"
        value {
          list {
            shape {
            }
          }
        }
      }
    }
    node_def {
      name : "Identity"
      op : "Identity"
      input : [ "ReadVariableOp:value:0", "^ReadVariableOp" ]
      attr {
        key : "T"
        value {
          type : DT_FLOAT
        }
      }
      attr {
        key : "_output_shapes"
        value {
          list {
            shape {
            }
          }
        }
      }
    }
    ret {
      key : "identity"
      value: "Identity:output:0"
    }
    attr {
      key : "_input_shapes"
      value {
        list {
          shape {
            unknown_rank: true
          }
        }
      }
    }
    control_ret {
      key : "ReadVariableOp"
      value: "ReadVariableOp"
    }
    arg_attr {
      key : 0x00000000
      value {
      }
    }
  }
}
versions {
  producer : 148
  min_consumer : 12
}
@ -7,8 +7,7 @@
# CHECK: "tf.Placeholder.input"(%arg0)

# CHECK: tf.Relu
# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island
# CHECK-NEXT: tf.Identity
# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island wraps "tf.IdentityN"
# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor<f32>, tensor<f32>

node {
@ -0,0 +1,101 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s

# CHECK: "tf.MlirPassthroughOp"
# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>

node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 10
        }
      }
    }
  }
}
node {
  name: "y"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "y"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 10
        }
      }
    }
  }
}
node {
  name: "MlirPassthroughOp"
  op: "MlirPassthroughOp"
  input: "x"
  input: "y"
  attr {
    key: "Tinputs"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
      }
    }
  }
  attr {
    key: "Toutputs"
    value {
      list {
        type: DT_FLOAT
      }
    }
  }
  attr {
    key: "mlir_module"
    value {
      s: "\nfunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\n %add = \"tf.Add\"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\n %ret = \"magic.op\"(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\n return %ret : tensor<10x10xf32>\n}\n"
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "MlirPassthroughOp"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 148
}
@ -13,15 +13,12 @@ func @foo(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// The IsolatePlacerInspectionRequiredOpsPass adds Identities for each input/output of function-calling ops.

// Capture the result of input to function call.
// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island
// CHECK-NEXT: "tf.VarHandleOp"()
// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.VarHandleOp"()

// Test for the presence of Identity op between input and function call.
// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island
// CHECK-NEXT: "tf.Identity"([[VARIABLE_REG]]#0)
// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"([[VARIABLE_REG]]#0)

// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island
// CHECK-NEXT: "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0)
// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0)
// CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]]

// Match the inserted Identity op for call output.
25
tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
Normal file
@ -0,0 +1,25 @@
// RUN: tf-opt %s -test-tf-lower-tf | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: simple_pack
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32>, %[[ARG1:.*]]: tensor<3x5xf32>
func @simple_pack(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<2x3x5xf32> {
  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>}
  // CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<3x5xf32>, tensor<i64>) -> tensor<1x3x5xf32>
  // CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<3x5xf32>, tensor<i64>) -> tensor<1x3x5xf32>
  // CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) {N = 2 : i64} : (tensor<1x3x5xf32>, tensor<1x3x5xf32>, tensor<i64>) -> tensor<2x3x5xf32>

  %0 = "tf.Pack"(%arg0, %arg1) {N = 2 : i64} : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<2x3x5xf32>
  return %0 : tensor<2x3x5xf32>
}

// CHECK-LABEL: pack_with_unranked
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>, %[[ARG1:.*]]: tensor<*xf32>
func @pack_with_unranked(%arg0: tensor<?x5xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<-2> : tensor<i64>}
  // CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<?x5xf32>, tensor<i64>) -> tensor<?x1x5xf32>
  // CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
  // CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) {N = 2 : i64} : (tensor<?x1x5xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>

  %0 = "tf.Pack"(%arg0, %arg1) {axis = -2 : i64, N = 2 : i64} : (tensor<?x5xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
@ -0,0 +1,15 @@
// RUN: tf-opt -tf-materialize-passthrough-op %s | FileCheck %s --dump-input=fail


// Check that the MlirPassthroughOp is eliminated and replaced by its attached
// MLIR module.

// CHECK-LABEL: func @main
func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {
  // CHECK-SAME: (%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<10xf32>)
  // CHECK-NEXT: %[[ADD:.*]] = "tf.Add"(%[[ARG0]], %[[ARG1]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  // CHECK-NEXT: %[[MAGIC:.*]] = "magic.op"(%[[ADD]], %[[ADD]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
  // CHECK-NEXT: return %[[MAGIC]]
  %0 = "tf.MlirPassthroughOp"(%arg0, %arg1) {Tinputs = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Toutputs = ["tfdtype$DT_FLOAT"], device = "", mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>
  return %0 : tensor<10x10xf32>
}
@ -0,0 +1,23 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s

// Verify the ops generated when Ref type is used in a while loop.
func @main() {
  // CHECK: op: "RefEnter"
  // CHECK: op: "RefMerge"
  // CHECK: op: "RefSwitch"
  // CHECK: op: "RefExit"
  // CHECK: op: "RefNextIteration"
  %0:2 = "_tf.NextIteration.source"() {device = "", T = "tfdtype$DT_INT32"} : () -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/NextIteration")
  %1:2 = "_tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<i32>} : () -> (tensor<!tf.int32ref>, !_tf.control) loc("Ref_Variable")
  %2:2 = "_tf.Enter"(%1#0) {device = "", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<!tf.int32ref>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Enter")
  %3:3 = "_tf.Merge"(%2#0, %0#0) {device = "", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<*x!tf.int32ref>) -> (tensor<*x!tf.int32ref>, tensor<i32>, !_tf.control) loc("while/Merge")
  %4:2 = "_tf.Const"(%3#2) {device = "", dtype = "tfdtype$DT_INT32", value = dense<10> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) loc("while/Less/y")
  %5:2 = "_tf.Less"(%3#0, %4#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<i32>) -> (tensor<*xi1>, !_tf.control) loc("while/Less")
  %6:2 = "_tf.LoopCond"(%5#0) {device = ""} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control) loc("while/LoopCond")
  %7:3 = "_tf.Switch"(%3#0, %6#0) {device = "", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*x!tf.int32ref>, tensor<i1>) -> (tensor<*x!tf.int32ref>, tensor<*x!tf.int32ref>, !_tf.control) loc("while/Switch")
  %8:2 = "_tf.Exit"(%7#1) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Exit")
  %10:2 = "_tf.Const"(%7#2) {device = "", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) loc("while/Add/y")
  %11:2 = "_tf.AssignAdd"(%7#0, %10#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<i32>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Add")
  %12 = "_tf.NextIteration.sink"(%11#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>) -> !_tf.control loc("while/NextIteration")
  return
}
41
tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir
Normal file
@ -0,0 +1,41 @@
|
||||
// RUN: tf-opt %s -tf-device-constant-sinking | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @sink_const
func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor<f32>) {
// Verify that the constants are sunk in the tf_device.launch region using them
// and removed if no other use is left.

// Only the 2.0 and 3.0 constants are removed; the 4.0 has a use in the return.
// CHECK-NOT: "tf.Const"{{.*}}2.0
// CHECK-NOT: "tf.Const"{{.*}}3.0
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
%3 = tf_executor.graph {
%res, %ctl = tf_executor.island {
%3 = "tf_device.launch"() ({

// In the device region, check that the 3 constants are materialized and
// remapped to the uses.
// CHECK: tf_device.launch
// CHECK-DAG: %[[CST2:.*]] = "tf.Const"{{.*}}2.0
// CHECK-DAG: %[[CST3:.*]] = "tf.Const"{{.*}}3.0
// CHECK-DAG: %[[CST4:.*]] = "tf.Const"{{.*}}4.0
// CHECK-NOT: "tf.Const"
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg0, %[[CST2]])
// CHECK-NEXT: %[[MUL2:.*]] = "tf.Mul"(%[[MUL1]], %[[CST2]])
// CHECK-NEXT: %[[MUL3:.*]] = "tf.Mul"(%[[MUL2]], %[[CST3]])
// CHECK-NEXT: = "tf.Mul"(%[[MUL3]], %[[CST4]])
%3 = "tf.Mul"(%arg0, %0) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
%4 = "tf.Mul"(%3, %0) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
%5 = "tf.Mul"(%4, %1) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
%6 = "tf.Mul"(%5, %2) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
"tf_device.return"(%6) : (tensor<16xf32>) -> ()
}) {device = "tpu0"} : () -> tensor<16xf32>
tf_executor.yield %3 : tensor<16xf32>
}
tf_executor.fetch %res : tensor<16xf32>
}
return %3, %2 : tensor<16xf32>, tensor<f32>
}
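For reference, here is a hand-reconstructed sketch of the IR shape this test expects after the -tf-device-constant-sinking pass, derived only from the CHECK lines above (the SSA value names are invented for readability; this is not verbatim pass output):

func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor<f32>) {
  // 4.0 is kept outside the region: it still feeds the function return.
  %cst4_out = "tf.Const"() {value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %g = tf_executor.graph {
    %res, %ctl = tf_executor.island {
      %launch = "tf_device.launch"() ({
        // All three constants are materialized inside the device region.
        %cst2 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
        %cst3 = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
        %cst4 = "tf.Const"() {value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
        %m1 = "tf.Mul"(%arg0, %cst2) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
        %m2 = "tf.Mul"(%m1, %cst2) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
        %m3 = "tf.Mul"(%m2, %cst3) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
        %m4 = "tf.Mul"(%m3, %cst4) : (tensor<16xf32>, tensor<f32>) -> tensor<16xf32>
        "tf_device.return"(%m4) : (tensor<16xf32>) -> ()
      }) {device = "tpu0"} : () -> tensor<16xf32>
      tf_executor.yield %launch : tensor<16xf32>
    }
    tf_executor.fetch %res : tensor<16xf32>
  }
  return %g, %cst4_out : tensor<16xf32>, tensor<f32>
}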
@ -83,6 +83,15 @@ func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> {

// -----

// CHECK-LABEL: func @testReverseV2
func @testReverseV2(%arg0: tensor<2x4x3x!tf.uint8>, %arg1: tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> {
// CHECK: tf.ReverseV2
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3x!tf.uint8>, tensor<1xi32>) -> tensor<2x4x3x!tf.uint8>
return %0 : tensor<2x4x3x!tf.uint8>
}

// -----

func @testIdentityWrongType(%arg0: tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> {
// expected-error @+1 {{requires all operands to be either same as or ref type of results}}
%0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref>
@ -459,6 +468,37 @@ func @testInvalidFakeQuantWithMinMaxVarsWrongMaxType(tensor<8x8x8x8xf32>, tensor

// -----

// Test valid tf.FakeQuantWithMinMaxVarsPerChannel
// CHECK-LABEL: func @FakeQuantWithMinMaxVarsPerChannel
func @FakeQuantWithMinMaxVarsPerChannel(tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> {
^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>):
// CHECK: "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
return %0 : tensor<1x2x3x8xf32>
}

// -----

// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel
func @FakeQuantWithMinMaxVarsPerChannel_ranked_inputs(tensor<f32>, tensor<8xf32>, tensor<8xf32>) -> tensor<f32> {
^bb0(%arg0: tensor<f32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>):
// expected-error @+1 {{requires inputs to be at least 1d float tensor}}
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<8xf32>, tensor<8xf32>) -> tensor<f32>
return %0 : tensor<f32>
}

// -----

// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel
func @FakeQuantWithMinMaxVarsPerChannel_mismatch_min_max(tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> {
^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<8xf32>):
// expected-error @+1 {{requires min and max to have same size as last dimension of inputs}}
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
return %0 : tensor<1x2x3x8xf32>
}

// -----

// Test valid tf.FusedBatchNorm
// CHECK-LABEL: func @testFusedBatchNorm
func @testFusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> {
@ -944,25 +984,25 @@ func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> {
// -----

// Test valid tf.ConcatV2
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32> {
// CHECK: %0 = "tf.ConcatV2"(%arg0, %arg0, %arg1) {N = 2 : i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
// CHECK: %0 = "tf.ConcatV2"(%arg0, %arg0, %arg1) {N = 2 : i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// tf.ConcatV2 with wrong 'axis' element type
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xf32>) -> tensor<?xf32> {
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<f32>) -> tensor<?xf32> {
// expected-error @+1 {{operand #2 must be tensor of 32/64-bit integer values}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xf32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// tf.ConcatV2 missing required 'axis' operand
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32> {
func @testConcatV2() -> tensor<?xf32> {
// expected-error @+1 {{expected 1 or more operands}}
%0 = "tf.ConcatV2"() {N = 0: i64} : () -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -971,9 +1011,165 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32
// -----

// tf.ConcatV2 with less than required number of values for the variadic operand
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor<?xf32> {
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
// expected-error @+1 {{attribute 'N' failed to satisfy constraint: 64-bit integer attribute whose minimal value is 2}}
%0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor<1xi32>) -> tensor<?xf32>
%0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32> {
// expected-error @+1 {{requires attribute 'N' to match the number of inputs; expected: 2 Found: 3}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 3: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

func @testAll(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>

// CHECK-LABEL: testAll
// CHECK: %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
}

// -----

func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor<i64>) -> tensor<i1> {
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i64>) -> tensor<i1>
return %0 : tensor<i1>

// CHECK-LABEL: testAll64
// CHECK: %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i64>) -> tensor<i1>
}

// -----

func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor<f32>) -> tensor<i1> {
// expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit integer values}}
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<f32>) -> tensor<i1>
return %0 : tensor<i1>
}

// -----

func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor<f32>) -> tensor<i32> {
// expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit integer values}}
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor<f32>) -> tensor<i32>
return %0 : tensor<i32>
}

// -----

func @testEqualOpIncompatibleShapeTrue(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<5xi1> {
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tf.Equal"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<4xf32>) -> tensor<5xi1>
return %0 : tensor<5xi1>
}

// -----

// CHECK-LABEL: testEqualOpIncompatibleShapeFalse
func @testEqualOpIncompatibleShapeFalse(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<*xi1> {
// CHECK: tf.Equal
%0 = "tf.Equal"(%x, %y) {incompatible_shape_error = false} : (tensor<5xf32>, tensor<4xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}

// -----

func @testNotEqualOpIncompatibleShapeTrue(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<5xi1> {
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tf.NotEqual"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<4xf32>) -> tensor<5xi1>
return %0 : tensor<5xi1>
}

// -----

// CHECK-LABEL: testNotEqualOpIncompatibleShapeFalse
func @testNotEqualOpIncompatibleShapeFalse(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<*xi1> {
// CHECK: tf.NotEqual
%0 = "tf.NotEqual"(%x, %y) {incompatible_shape_error = false} : (tensor<5xf32>, tensor<4xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}

// -----

func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1x1xi32>) -> tensor<*xf32> {
// expected-error @+1 {{requires axis to be of scalar type (or vector type for older versions)}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1x1xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1x1xi32>) -> tensor<*xf32> {
|
||||
// expected-error @+1 {{requires axis to be of scalar type (or vector type for older versions)}}
|
||||
%0 = "tf.Concat"(%axis, %arg, %arg) {N = 2: i64} : (tensor<1x1xi32>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
|
||||
// expected-error @+1 {{operand type 'tensor<8xf32>' is not compatible with preceding operands; expected rank: 2}}
|
||||
%0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Valid Concat operation with concat axis 1 or -1.
|
||||
func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8x8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
|
||||
%0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x8xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
|
||||
// expected-error @+1 {{operand type 'tensor<16x8xf32>' is not compatible with preceding operands; expected dimension at index 1: 16}}
|
||||
%0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<16x8xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Valid Concat operation with concat axis 1 or -1.
|
||||
func @testConcatV2(%arg0: tensor<8x8xf32>, %arg1: tensor<?x4xf32>, %arg2: tensor<*xf32>, %arg3: tensor<8x?xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
|
||||
%0 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %arg3, %axis) {N = 4: i64} : (tensor<8x8xf32>, tensor<?x4xf32>, tensor<*xf32>, tensor<8x?xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Valid Pack operation.
|
||||
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> {
|
||||
// expected-error @+1 {{requires attribute 'N' to match the number of inputs; expected: 2 Found: 1}}
|
||||
%0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 1: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x2xf32>) -> tensor<*xf32> {
|
||||
// expected-error @+1 {{operand type 'tensor<4x2xf32>' is not compatible with preceding operands; expected dimension at index 1: 8}}
|
||||
%0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x2xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor<i32>) -> tensor<*xf32> {
|
||||
// expected-error @+1 {{attribute 'axis' should be within range [-3, 3); actual value: 3}}
|
||||
%0 = "tf.Pack"(%arg0, %arg1) {axis = 3 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
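The range in the diagnostic above follows from tf.Pack adding one result dimension: for rank-2 inputs, axis must lie in [-3, 3). A minimal in-range counterpart, written here purely for illustration (the function name is invented and this test is not part of the diff):

func @testPackValidAxis(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> {
  // axis = 2 appends the packed dimension last; any value in [-3, 3) verifies.
  %0 = "tf.Pack"(%arg0, %arg1) {axis = 2 : i64, N = 2 : i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}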
|
||||
|
@ -271,6 +271,39 @@ func @merge_with_variant_type(%arg0: tensor<!tf.variant>, %arg1: tensor<!tf.vari
return %result : tensor<!tf.variant<tensor<8xf32>>>
}

// CHECK-LABEL: func @merge_with_ref_type
func @merge_with_ref_type(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%result = tf_executor.graph {

// CHECK: tf_executor.Merge{{.*}}(tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4xf32>, tensor<i32>, !tf_executor.control)
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4xf32>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<4xf32>
}
return %result : tensor<4xf32>
}

// CHECK-LABEL: func @merge_with_dynamic_shape
func @merge_with_dynamic_shape(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> tensor<?xf32> {
%result = tf_executor.graph {

// CHECK: tf_executor.Merge{{.*}}(tensor<2xf32>, tensor<3xf32>) -> (tensor<?xf32>, tensor<i32>, !tf_executor.control)
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xf32>) -> (tensor<?xf32>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<?xf32>
}
return %result : tensor<?xf32>
}

// CHECK-LABEL: func @merge_with_unranked_shape
func @merge_with_unranked_shape(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> tensor<*xf32> {
%result = tf_executor.graph {

// CHECK: tf_executor.Merge{{.*}}(tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
tf_executor.fetch %value : tensor<*xf32>
}
return %result : tensor<*xf32>
}

// CHECK-LABEL: func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%result = tf_executor.graph {
@ -490,7 +490,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>

%value, %idx, %ctlMerge = "tf_executor.Merge"(%true, %false, %arg1) : (tensor<*xf32>, tensor<*xf32>, tensor<i1>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<*xf32>' vs 'tensor<i1>'}}
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<i1>' vs 'tensor<*xf32>'}}
tf_executor.fetch %value : tensor<*xf32>
}
return %result : tensor<*xf32>
@ -502,7 +502,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xf32>, tensor<4xf32>) -> (tensor<8xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8xf32>' vs 'tensor<4xf32>'}}
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<4xf32>' vs 'tensor<8xf32>'}}
tf_executor.fetch %value : tensor<8xf32>
}
return %result : tensor<8xf32>
@ -514,7 +514,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32>
func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>) -> tensor<8x!tf.variant> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*x!tf.variant>, tensor<4x!tf.variant>) -> (tensor<8x!tf.variant>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8x!tf.variant>' vs 'tensor<4x!tf.variant>'}}
// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<4x!tf.variant>' vs 'tensor<8x!tf.variant>'}}
tf_executor.fetch %value : tensor<8x!tf.variant>
}
return %result : tensor<8x!tf.variant>
@ -522,6 +522,18 @@ func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>)

// -----

// Check that if the result is a ref type, all operands need to be ref too.
func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
tf_executor.fetch %value : tensor<4x!tf.f32ref>
}
return %result : tensor<4x!tf.f32ref>
}
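Taken together with @merge_with_ref_type earlier in this file, the diagnostic above suggests the rule: a non-ref result may mix ref and non-ref operands, but a ref result requires every operand to carry the matching ref element type. A sketch of a form the verifier should then accept, added here for illustration only (not part of the diff):

func @merge_all_ref(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>) -> tensor<4x!tf.f32ref> {
  %result = tf_executor.graph {
    // Both operands and the result use the same ref element type.
    %value, %idx, %ctl = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
    tf_executor.fetch %value : tensor<4x!tf.f32ref>
  }
  return %result : tensor<4x!tf.f32ref>
}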

// -----

// Check that merge data inputs can't appear after control input.
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%result = tf_executor.graph {
@ -11,9 +11,9 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
@ -68,9 +68,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @nested_func
@ -112,9 +111,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @referenced_func
@ -155,9 +153,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: @referenced_func1
@ -206,9 +203,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-COUNT-2: call @referenced_func
@ -251,9 +247,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func0} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func0
@ -263,9 +258,8 @@ module {

%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func1} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster1"
// CHECK-SAME: module
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func1
@ -303,9 +297,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
@ -315,9 +308,8 @@ module {

%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster1"
// CHECK-SAME: module
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
@ -351,9 +343,8 @@ module {

%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: _tpu_replicate = "cluster0"
// CHECK-SAME: module
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @referenced_func2
@ -404,3 +395,44 @@ module {
}

// -----

// Tests that TPUCompilationResult operations are properly rewritten

// CHECK-LABEL: func @tpu_compilation_result
func @tpu_compilation_result(%arg0: tensor<?xi32>) -> (tensor<?xi32>, tensor<!tf.string>, tensor<!tf.string>) {

// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"
%1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>

%compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
%compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>

// CHECK: return %[[EXECUTE_OUTPUT]], %[[COMPILE_OUTPUT]]#0, %[[COMPILE_OUTPUT]]#0
return %1, %compile_result, %compile_result2 : tensor<?xi32>, tensor<!tf.string>, tensor<!tf.string>
}

func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}

// -----

// Tests that TPUReplicatedInput and TPUReplicatedOutput operations are properly rewritten

func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> {
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%arg0, %arg1
%0 = "tf.TPUReplicatedInput"(%arg0) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
%1 = "tf.TPUReplicatedInput"(%arg1) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
%2 = "tf_device.launch_func"(%0, %1) {device = "", _tpu_replicate = "cluster", func = @_func} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
%3 = "tf.TPUReplicatedOutput"(%2) {num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
return %3 : tensor<0xf32>
}
func @_func(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
%0 = "tf.Const"() {value = dense<3.000000e+00> : tensor<0xf32>} : () -> tensor<0xf32>
return %0 : tensor<0xf32>
}
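Reconstructed from the CHECK lines only, the rewritten @main would invoke tf.TPUExecute directly on %arg0 and %arg1, with both N = 1 replication ops folded away. A rough sketch under that assumption, for illustration only; the compilation-key plumbing between tf._TPUCompileMlir and tf.TPUExecute is deliberately elided here because the diff does not show it:

// Hypothetical post-pass shape of @main (operand list intentionally incomplete):
func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> {
  %0 = "tf.TPUExecute"(%arg0, %arg1) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
  return %0 : tensor<0xf32>
}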
Some files were not shown because too many files have changed in this diff.